首页 > 其他分享 >FedR代码学习文档

FedR代码学习文档

时间:2023-07-21 15:33:36浏览次数:27  
标签:... data 代码 args edge 文档 FedR type self

main.py

参数设置,进入主函数

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # parser.add_argument('--data_path', default='Fed_data/WN18RR-Fed3.pkl', type=str)
    parser.add_argument('--data_path', default='Fed_data/DDB14-Fed3.pkl', type=str)
    parser.add_argument('--name', default='wn18rr_fed3_fed_TransE', type=str)
    parser.add_argument('--state_dir', '-state_dir', default='./state', type=str)
    parser.add_argument('--log_dir', '-log_dir', default='./log', type=str)
    parser.add_argument('--tb_log_dir', '-tb_log_dir', default='./tb_log', type=str)
    parser.add_argument('--run_mode', default='FedR', choices=['FedE', 'Single', 'test_pretrain'])
    parser.add_argument('--num_multi', default=3, type=int)

    parser.add_argument('--model', default='TransE', choices=['TransE', 'RotatE', 'DistMult', 'ComplEx'])

    # one task hyperparam
    parser.add_argument('--one_client_idx', default=0, type=int)
    parser.add_argument('--max_epoch', default=10000, type=int)
    parser.add_argument('--log_per_epoch', default=1, type=int)
    parser.add_argument('--check_per_epoch', default=10, type=int)


    parser.add_argument('--batch_size', default=512, type=int)
    parser.add_argument('--test_batch_size', default=16, type=int)
    parser.add_argument('--num_neg', default=256, type=int)
    parser.add_argument('--lr', default=0.001, type=int)

    # for FedE
    parser.add_argument('--num_client', default=3, type=int)
    parser.add_argument('--max_round', default=10000, type=int)
    parser.add_argument('--local_epoch', default=3, type=int)
    parser.add_argument('--fraction', default=1, type=float)
    parser.add_argument('--log_per_round', default=1, type=int)
    parser.add_argument('--check_per_round', default=5, type=int)

    parser.add_argument('--early_stop_patience', default=5, type=int)
    parser.add_argument('--gamma', default=10.0, type=float)
    parser.add_argument('--epsilon', default=2.0, type=float)
    parser.add_argument('--hidden_dim', default=128, type=int)
    parser.add_argument('--gpu', default='0', type=str)
    parser.add_argument('--num_cpu', default=10, type=int)
    parser.add_argument('--adversarial_temperature', default=1.0, type=float)

    # parser.add_argument('--negative_adversarial_sampling', default=True, type=bool)
    parser.add_argument('--seed', default=12345, type=int)

    args = parser.parse_args()
    args_str = json.dumps(vars(args))

    args.gpu = torch.device('cuda:' + args.gpu)
    # args.gpu = torch.device(("cuda:" + args.gpu) if torch.cuda.is_available() else "cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    init_dir(args)
    writer = SummaryWriter(os.path.join(args.tb_log_dir, args.name))
    args.writer = writer
    init_logger(args)
    logging.info(args_str)

    if args.run_mode == 'FedR':
        all_data = pickle.load(open(args.data_path, 'rb'))
        learner = FedR(args, all_data)
        learner.train()
    elif args.run_mode == 'Single':
        all_data = pickle.load(open(args.data_path, 'rb'))
        data = all_data[args.one_client_idx]
        learner = KGERunner(args, data)
        learner.train()

数据导入

.pkl形式的数据 (通过csv的代码可以进行转换)
这里的数据分给三个客户端,每个客户端当中又有train,valid,test

  • edge_index:是一个二维数组,表示第i个三元组的起始节点和终止节点
  • edge_type:表示第i个三元组的relation
  • edge_index_ori:
  • edge_type_ori:
,train,test,valid
0,"{'edge_index': array([[3515, 3614, 3299, ...,  246, 3912, 2853],
       [ 961, 2501,  703, ...,  211, 1904,  442]], dtype=int64), 'edge_type': array([1, 9, 1, ..., 7, 1, 1], dtype=int64), 'edge_index_ori': array([[1796, 4767, 3939, ...,  345, 3215, 4054],
       [1787, 3036,  950, ...,  341, 3204,  537]], dtype=int64), 'edge_type_ori': array([2, 8, 2, ..., 7, 2, 2], dtype=int64)}","{'edge_index': array([[ 392,  822,  331, ..., 1207,  247,  902],
       [ 199,  261,  175, ...,  802,  195,  540]], dtype=int64), 'edge_type': array([1, 1, 1, ..., 1, 2, 1], dtype=int64), 'edge_index_ori': array([[ 424,  655,  373, ..., 1531,  364, 1047],
       [ 416,  530,  366, ..., 1530,  360,  527]], dtype=int64), 'edge_type_ori': array([2, 2, 2, ..., 2, 3, 2], dtype=int64)}","{'edge_index': array([[ 358, 1204, 2395, ...,  210, 1139,  371],
       [2581,  813, 1564, ...,  211,  583,  288]], dtype=int64), 'edge_type': array([1, 1, 6, ..., 8, 1, 1], dtype=int64), 'edge_index_ori': array([[ 393, 1601, 2983, ...,  229, 1223,  572],
       [2692, 1580, 1188, ...,  341,  644,  554]], dtype=int64), 'edge_type_ori': array([2, 2, 5, ..., 6, 2, 2], dtype=int64)}"
1,"{'edge_index': array([[4881, 5080, 2512, ...,  531,  876, 3547],
       [4882,   30,   38, ...,  532,  574,   95]], dtype=int64), 'edge_type': array([10,  0,  1, ...,  1,  4,  0], dtype=int64), 'edge_index_ori': array([[8515, 2880, 4337, ..., 3921,  721, 1996],
       [8391, 2695, 4333, ...,  234, 1556, 2442]], dtype=int64), 'edge_type_ori': array([0, 2, 4, ..., 4, 5, 2], dtype=int64)}","{'edge_index': array([[ 309, 1661, 2880, ..., 1861, 1831,  652],
       [  72, 2083, 2154, ...,  127, 2730, 3940]], dtype=int64), 'edge_type': array([0, 0, 1, ..., 0, 0, 0], dtype=int64), 'edge_index_ori': array([[ 402, 4842,  827, ..., 2229, 2742,  228],
       [ 379, 5955,  826, ..., 2256,  890, 5910]], dtype=int64), 'edge_type_ori': array([2, 2, 4, ..., 2, 2, 2], dtype=int64)}","{'edge_index': array([[1821,   91, 1049, ...,   59,  749,  560],
       [1800,  398, 2353, ...,  511,   34,  381]], dtype=int64), 'edge_type': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'edge_index_ori': array([[4224,   62,  620, ..., 1059, 2398, 1696],
       [5502, 3330, 4833, ...,  266,   13,  125]], dtype=int64), 'edge_type_ori': array([2, 2, 2, ..., 2, 2, 2], dtype=int64)}"
2,"{'edge_index': array([[1048, 5151, 2026, ..., 3552, 1835,  897],
       [ 286,   33, 2180, ..., 4712, 1836,   56]], dtype=int64), 'edge_type': array([2, 2, 0, ..., 8, 6, 2], dtype=int64), 'edge_index_ori': array([[5172, 4261, 1779, ..., 6148,  222, 1803],
       [6817, 6663, 1859, ..., 9069, 2987, 6810]], dtype=int64), 'edge_type_ori': array([ 0,  0,  2, ...,  8, 12,  0], dtype=int64)}","{'edge_index': array([[ 508, 5263, 1230, ...,  577, 1646,  439],
       [ 649, 4329,  649, ..., 1496, 2298,  598]], dtype=int64), 'edge_type': array([0, 0, 0, ..., 0, 0, 0], dtype=int64), 'edge_index_ori': array([[ 630, 1847, 1297, ...,  266,  576,  876],
       [1256, 8628, 1256, ..., 4247, 1734, 3295]], dtype=int64), 'edge_type_ori': array([2, 2, 2, ..., 2, 2, 2], dtype=int64)}","{'edge_index': array([[  91, 1798, 1622, ..., 2358, 4665,  427],
       [ 672, 2482,   82, ...,  506,  136, 1011]], dtype=int64), 'edge_type': array([0, 3, 0, ..., 0, 0, 0], dtype=int64), 'edge_index_ori': array([[ 527, 3702,  566, ..., 1209,  251,   39],
       [ 224, 2336, 4523, ...,  666, 1172, 4816]], dtype=int64), 'edge_type_ori': array([2, 5, 2, ..., 2, 2, 2], dtype=int64)}"

数据分发

1.将隐私数据分发到客户机 (客户拥有),初始化服务器
2.统计客户机测试集、验证集的数量,以及权重数量

class FedR(object):
    def __init__(self, args, all_data):
        self.args = args

        train_dataloader_list, valid_dataloader_list, test_dataloader_list, \
            self.rel_freq_mat, ent_embed_list, nrelation = get_all_clients(all_data, args)

        self.args.nrelation = nrelation # question

        # client
        self.num_clients = len(train_dataloader_list)
        # Create client objects for each client
        self.clients = []
        for i in range(self.num_clients):
            client = Client(args, i, all_data[i], train_dataloader_list[i], valid_dataloader_list[i],
                            test_dataloader_list[i], ent_embed_list[i])
            self.clients.append(client)

        # Create the server object
        self.server = Server(args, nrelation)

        #   统计客户机测试集、验证集的数量,以及权重数量
        # Calculate total test data size and test evaluation weights
        self.total_test_data_size = 0
        for client in self.clients:
            self.total_test_data_size += len(client.test_dataloader.dataset)

        self.test_eval_weights = []
        for client in self.clients:
            weight = len(client.test_dataloader.dataset) / self.total_test_data_size
            self.test_eval_weights.append(weight)

        # Calculate total valid data size and valid evaluation weights
        self.total_valid_data_size = 0
        for client in self.clients:
            self.total_valid_data_size += len(client.valid_dataloader.dataset)

        self.valid_eval_weights = []
        for client in self.clients:
            weight = len(client.valid_dataloader.dataset) / self.total_valid_data_size
            self.valid_eval_weights.append(weight)
对初始数据集进行分发

1.all_rel = np.union1d(all_rel, data['train']['edge_type_ori']).reshape(-1):在这里,通过 np.union1d 函数将当前客户端的训练数据中的关系类型与 all_rel 数组进行合并并去除重复项。最后通过 reshape(-1) 将结果变为一维数组,并更新 all_rel。
2.train_dataloader_list, valid_dataloader_list, test_dataloader_list, ent_embed_list, rel_freq_list 初始化:这里分别初始化了存储训练、验证和测试数据加载器、实体嵌入向量以及关系频率的列表。
3.for data in tqdm(all_data):这个循环遍历所有客户端的数据,并对每个客户端进行处理
4.nentity = len(np.unique(data['train']['edge_index'])): 这行代码计算当前客户端训练数据中的实体数量,通过获取边索引 'edge_index' 并使用 np.unique 函数获取独特的实体索引,然后通过 len 函数计算实体的数量。
5.构建训练、验证和测试数据集:这部分代码通过整理当前客户端的训练、验证和测试数据来创建相应的数据集。训练数据集使用了 TrainDataset 类,而验证和测试数据集则使用了 valid_dataset和TestDataset 类。
6.构建数据加载器:用于将数据分发给不同客户端
7.初始化实体嵌入向量 ent_embed:这部分代码根据模型的不同(args.model)初始化实体嵌入向量 ent_embed,并将其添加到 ent_embed_list 列表中。
8.计算关系频率:计算不同客户机中relation的出现频率,并将其保存在 rel_freq 中。这样可以用于在后续任务中根据关系频率进行权重调整等操作。
9.rel_freq_mat = torch.stack(rel_freq_list).to(args.gpu):将关系频率列表 rel_freq_list 转换为 PyTorch 张量,并将其放置在指定的 GPU 上。
10.返回结果:最后,函数返回所有客户端的数据加载器、关系频率矩阵、实体嵌入向量列表和总关系数量 nrelation

def get_all_clients(all_data, args):
    all_rel = np.array([], dtype=int)
    for data in all_data:
        all_rel = np.union1d(all_rel, data['train']['edge_type_ori']).reshape(-1)
    nrelation = len(all_rel) # all relations of training set in all clients

    train_dataloader_list = []
    test_dataloader_list = []
    valid_dataloader_list = []

    ent_embed_list = []

    rel_freq_list = []

    for data in tqdm(all_data): # in a client
        nentity = len(np.unique(data['train']['edge_index'])) # entities of training in a client

        train_triples = np.stack((data['train']['edge_index'][0],
                                  data['train']['edge_type_ori'],
                                  data['train']['edge_index'][1])).T

        valid_triples = np.stack((data['valid']['edge_index'][0],
                                  data['valid']['edge_type_ori'],
                                  data['valid']['edge_index'][1])).T

        test_triples = np.stack((data['test']['edge_index'][0],
                                 data['test']['edge_type_ori'],
                                 data['test']['edge_index'][1])).T

        client_mask_rel = np.setdiff1d(np.arange(nrelation),
                                       np.unique(data['train']['edge_type_ori'].reshape(-1)), assume_unique=True)

        all_triples = np.concatenate([train_triples, valid_triples, test_triples]) # in a client
        train_dataset = TrainDataset(train_triples, nentity, args.num_neg)
        valid_dataset = TestDataset(valid_triples, all_triples, nentity, client_mask_rel)
        test_dataset = TestDataset(test_triples, all_triples, nentity, client_mask_rel)

        # dataloader,数据划分
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            collate_fn=TrainDataset.collate_fn
        )
        train_dataloader_list.append(train_dataloader)

        valid_dataloader = DataLoader(
            valid_dataset,
            batch_size=args.test_batch_size,
            collate_fn=TestDataset.collate_fn
        )
        valid_dataloader_list.append(valid_dataloader)

        test_dataloader = DataLoader(
            test_dataset,
            batch_size=args.test_batch_size,
            collate_fn=TestDataset.collate_fn
        )
        test_dataloader_list.append(test_dataloader)

        embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim])

        '''use n of entity in train or all (train, valid, test)?'''
        if args.model in ['RotatE', 'ComplEx']:
            ent_embed = torch.zeros(nentity, args.hidden_dim*2).to(args.gpu).requires_grad_()
        else:
            ent_embed = torch.zeros(nentity, args.hidden_dim).to(args.gpu).requires_grad_()
        nn.init.uniform_(
            tensor=ent_embed,
            a=-embedding_range.item(),
            b=embedding_range.item()
        )
        ent_embed_list.append(ent_embed)

        rel_freq = torch.zeros(nrelation)
        for r in data['train']['edge_type_ori'].reshape(-1):
            rel_freq[r] += 1
        rel_freq_list.append(rel_freq)

    rel_freq_mat = torch.stack(rel_freq_list).to(args.gpu)

    return train_dataloader_list, valid_dataloader_list, test_dataloader_list, \
           rel_freq_mat, ent_embed_list, nrelation

客户端的数据分发

每个客户端都有数据,并且拥有自己的模型

class Client(object):
    def __init__(self, args, client_id, data, train_dataloader,
                 valid_dataloader, test_dataloader, ent_embed):
        self.args = args
        self.data = data
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.test_dataloader = test_dataloader
        self.ent_embed = ent_embed
        self.client_id = client_id

        self.score_local = []
        self.score_global = []

        self.kge_model = KGEModel(args, args.model)
        self.rel_embed = None
class KGEModel(nn.Module):
    def __init__(self, args, model_name):
        super(KGEModel, self).__init__()
        self.model_name = model_name
        self.embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim])
        self.gamma = nn.Parameter(
            torch.Tensor([args.gamma]),
            requires_grad=False
        )
服务器的数据分发

1.embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim]):这行代码计算了关系嵌入向量初始化的范围 embedding_range。参数 args.gamma 和 args.epsilon 是模型的一些超参数,用于控制关系嵌入向量初始化范围的大小。args.hidden_dim 是模型中嵌入向量的维度。
2.self.rel_embed = torch.zeros(nrelation, args.hidden_dim2).to(args.gpu).requires_grad_():如果模型类型是 'ComplEx',则创建一个形状为 (nrelation, args.hidden_dim2) 的全零张量 self.rel_embed,用于存储关系嵌入向量。nrelation 是关系的数量,args.hidden_dim*2 是每个关系嵌入向量的维度。通过 .to(args.gpu) 将张量放置在指定的 GPU 上(如果使用了 GPU)。最后,通过 requires_grad_() 方法指定张量需要计算梯度,用于后续的模型训练和优化。
3.nn.init.uniform_(tensor=self.rel_embed, a=-embedding_range.item(), b=embedding_range.item()):这行代码使用均匀分布初始化关系嵌入向量 self.rel_embed。关系嵌入向量的值被随机采样自均匀分布,范围是从 -embedding_range.item() 到 embedding_range.item()。

    def __init__(self, args, nrelation):
        self.args = args
        embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim])
        if args.model in ['ComplEx']:
            self.rel_embed = torch.zeros(nrelation, args.hidden_dim*2).to(args.gpu).requires_grad_()
        else:
            self.rel_embed = torch.zeros(nrelation, args.hidden_dim).to(args.gpu).requires_grad_()
        nn.init.uniform_(
            tensor=self.rel_embed,
            a=-embedding_range.item(),
            b=embedding_range.item()
        )
        self.nrelation = nrelation

模型训练

标签:...,data,代码,args,edge,文档,FedR,type,self
From: https://www.cnblogs.com/csjywu01/p/17571009.html

相关文章

  • swagger文档和 knife4j 文档
    老版本的swagger-bootstrap-ui,可以显示非RestController,可以测试html页面显示,可以和springfox-swagger-ui配合显示<!--swagger--><dependency><groupId>io.springfox</groupId><artifactId>springfox-swagger2</artifactId><vers......
  • sam自动生成mask代码解析
    要自动生成mask,请向“SamAutomaticMaskGenerator”类注入SAM模型(需要先初始化SAM模型)importsyssys.path.append("..")fromsegment_anythingimportsam_model_registry,SamAutomaticMaskGenerator,SamPredictorsam_checkpoint="sam_vit_b_01ec64.pth"model_type......
  • R语言隐马尔可夫模型(HMM)识别不断变化的股市状况股票指数预测实战|附代码数据
    全文下载链接: http://tecdat.cn/?p=1557最近我们被客户要求撰写关于隐马尔可夫模型(HMM)的研究报告,包括一些图形和统计输出。“了解不同的股市状况,改变交易策略,对股市收益有很大的影响。弄清楚何时开始或何时止损,调整风险和资金管理技巧,都取决于股市的当前状况 ( 点击文末“阅......
  • python弹出窗口的代码
    Python弹出窗口的代码弹出窗口是指在图形用户界面(GUI)中弹出一个窗口来与用户进行交互。在Python中,我们可以使用不同的库来创建弹出窗口,其中最常用的是tkinter库。tkinter是Python的标准GUI库,它提供了创建并管理窗口、按钮、标签等GUI元素的功能。安装tkinter在使用tkinter库之......
  • python代码优化 编译cuda
    Python代码优化编译CUDAPython是一种高级编程语言,通常被用于快速开发和原型设计。然而,由于其动态类型和解释执行特性,Python在执行大规模计算密集型任务时可能会变得相对较慢。为了解决这个问题,我们可以使用CUDA编译Python代码。CUDA(ComputeUnifiedDeviceArchitecture)是一种由......
  • 拓端tecdat|R语言贝叶斯Metropolis-Hastings Gibbs 吉布斯采样器估计变点指数分布分析
    原文链接:http://tecdat.cn/?p=26578 原文出处:拓端数据部落公众号最近我们被客户要求撰写关于吉布斯采样器的研究报告,包括一些图形和统计输出。指数分布是泊松过程中事件之间时间的概率分布,因此它用于预测到下一个事件的等待时间,例如,您需要在公共汽车站等待的时间,直到下一班车......
  • 基于R语言股票市场收益的统计可视化分析|附代码数据
    全文链接:http://tecdat.cn/?p=16453 最近我们被客户要求撰写关于股票市场的研究报告,包括一些图形和统计输出。金融市场上最重要的任务之一就是分析各种投资的历史收益要执行此分析,我们需要资产的历史数据。数据提供者很多,有些是免费的,大多数是付费的。在本文中,我们将使用Yahoo......
  • PHP代码练习Demo02
    <!DOCTYPEhtml><html><body><?phpecho"<h2>PHPisfun!</h2>";echo"helloworld"; echo"I'mabouttolearnPHP!<br>";echo"This","string","was&qu......
  • WINUI 后台代码绑定
    以image为例 前端进行绑定时哪下,注意下述代码中用的是x:Bind,用它进行绑定时需要标明其绑定ViewModel的key值;用Bingding时则不需要。<Imagex:Name="CTCoronalCImage"Width="1010"Height="442"HorizontalAlignment="Stretch"VerticalAlignm......
  • vscode python代码提示
    VSCodePython代码提示简介VSCode(VisualStudioCode)是一款轻量级的代码编辑器,具有丰富的扩展功能。通过安装Python扩展,可以在VSCode中进行Python开发,并享受强大的代码提示功能。本文将介绍如何在VSCode中使用Python代码提示。安装Python插件在开始使用Python代码提示之前,......