攻击客户机1
这段代码是用于进行攻击的部分。它试图通过使用客户端0的信息(实体嵌入和关系嵌入)来破解客户端1的信息(部分实体和关系的嵌入)。攻击的过程包括以下步骤:
-
加载训练得到的模型参数:通过
torch.load()
函数加载之前训练得到的模型参数,其中ent_embed
和rel_embed
分别表示实体嵌入和关系嵌入。 -
创建客户端0的字典信息:从数据中提取客户端0的实体,并将其与对应的实体嵌入组成字典
c0_ent_embed_dict
。 -
对客户端0的实体进行映射:由于在客户端1的数据中,实体的索引可能与客户端0的数据中不同,因此需要建立映射关系
c0_mapping
来将客户端0的实体索引映射到客户端1的实体索引。 -
在客户端1上执行攻击:对客户端1进行攻击,通过在客户端1的实体池中选择一部分实体(由
p
参数控制选择比例),然后计算这些实体与客户端0的实体的嵌入之间的余弦距离,并选择距离最近的客户端0的实体作为对应的伪造实体,形成伪造实体列表syc_ent_list
。 -
计算攻击成功率:计算成功破解的实体的比例和关系的比例,即伪造实体列表中与客户端1的实体池中实体相同的实体数量与客户端1数据中所有三元组数量之间的比值。
请注意,这段代码是为了演示攻击方法,并且使用余弦距离来测量实体之间的相似性。在实际应用中,可能需要更复杂的攻击策略和更准确的相似性度量来实现更高效的攻击。
已知客户机0的信息,攻击客户机1的信息
import torch
import pickle
import numpy as np
import random
from scipy import spatial
emb = torch.load('./state/fb15k237_fed3_fed_TransE.best', map_location=torch.device('cpu'))
ent_embed = emb['ent_embed']
rel_embed = emb['rel_embed']
data = pickle.load(open("Fed_data/FB15K237-Fed3.pkl", "rb" ))
#生成第一个客户端的字典信息
c0_ent = np.unique(data[0]['train']['edge_index'])
c0_ent_embed_dict = {}
value = ent_embed[0]
for idx,ent in enumerate(c0_ent):
c0_ent_embed_dict[ent] = value[idx]
c0_mapping = dict(zip(data[0]['train']['edge_index'][0], data[0]['train']['edge_index_ori'][0]))
c0_mapping.update(dict(zip(data[0]['train']['edge_index'][1], data[0]['train']['edge_index_ori'][1])))
c0_ent_embed_dict_mapped = dict((c0_mapping[key], value) for (key, value) in c0_ent_embed_dict.items())
c0_ent_pool_mapped = [c0_mapping[i] for i in c0_ent]
# map local to global
c1_mapping = dict(zip(data[1]['train']['edge_index'][0], data[1]['train']['edge_index_ori'][0]))
c1_mapping.update(dict(zip(data[1]['train']['edge_index'][1], data[1]['train']['edge_index_ori'][1])))
c1_ent = np.unique(data[1]['train']['edge_index'])
random.seed(10)
np.random.seed(10)
p = 1
c1_ent_pool = np.random.choice(c1_ent, int(p * len(c1_ent)), replace = False)
c1_ent_embed = ent_embed[1][[c1_ent_pool]]
c1_ent_pool_mapped = [c1_mapping[i] for i in c1_ent_pool]
syn_ent_list = [] # synthetic entity label
for i in c1_ent_pool:
c1_ent_embed = ent_embed[1][i]
count = 0
loss_bound = 0
ent_idx = []
for j in c0_ent_embed_dict_mapped:
loss = spatial.distance.cosine(c1_ent_embed.detach().numpy(), c0_ent_embed_dict_mapped[j].detach().numpy())
if count == 0: # first round
loss_bound = loss
ent_idx.append(j)
count += 1
else:
if loss < loss_bound:
loss_bound = loss
ent_idx.append(j)
syn_ent_list.append(ent_idx[-1]) # global index of the entity
tru_ent_list = [c1_mapping[i] for i in c1_ent_pool]
# calculate the number of correct reconstruction
sum(first == second for (first, second) in zip(syn_ent_list, tru_ent_list)) / len(c1_ent)
c0_rel = np.unique(data[0]['train']['edge_type_ori'])
# creat relation pool based on selected entities (global relation index)
c1_triple_all = np.array([data[1]['train']['edge_index_ori'][0],
data[1]['train']['edge_type_ori'],
data[1]['train']['edge_index_ori'][1]])
tru_trr_list = []
# the adversary knows all relation embeddings and their corresponding index, so here we use ori directly
len_c1_triple = c1_triple_all[0].shape[0]
for i in range(len_c1_triple):
triple = c1_triple_all[:,i]
h, r, t= triple[0], triple[1], triple[2]
if (h in c1_ent_pool_mapped) and (t in c1_ent_pool_mapped):
if h not in tru_trr_list:
tru_trr_list.append(h)
if t not in tru_trr_list:
tru_trr_list.append(t)
syn_trr_list = []
for (first, second) in zip(syn_ent_list, tru_ent_list):
if first == second:
syn_trr_list.append(first)
# calculate the number of correct reconstruction
len(list(set(syn_trr_list).intersection(tru_trr_list))) / len_c1_triple
标签:攻击,代码,list,ent,FedR,c1,c0,data,embed
From: https://www.cnblogs.com/csjywu01/p/17588867.html