首页 > 其他分享 >异构图中节点的分类/回归

异构图中节点的分类/回归

时间:2023-05-01 09:55:37浏览次数:37  
标签:异构 graph hetero college user skill 节点 feats 图中

异构图中节点的分类/回归

导入包

import numpy as np
import torch
import dgl
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn

创建一个异构图

设置这个图中的节点个数和边的个数

n_users = 100   #user节点个数
n_jobspre = 500     #jobpre节点的个数
n_uj = 3000     #边'uj'的个数
n_ju = 3000    #边'ju'的个数
n_college = 100     #college节点个数
n_cu = 100
n_uc = 100
n_hetero_features = 20     #相当于要做的嵌入维度
n_user_classes = 20 #假设得出的job种类有这么多
n_skills = 500       #skill节点个数
n_us = 4000
n_su = 4000
n_js = 4000
n_sj = 4000

设置每条边对应的头和尾(如果有数据集,则直接导入就好)这里自动生成 还不是因为没有数据集

uj_src = np.random.randint(0,n_users,n_uj)       #user_jobpre中user的编号
uj_dst = np.random.randint(0,n_jobspre,n_uj)     #user_jobpre中jobpre的编号
uc_src = np.random.randint(0,n_users,n_cu)       #user_college中user的编号
uc_dst = np.random.randint(0,n_college,n_cu)     #user_college中college的编号
us_src = np.random.randint(0,n_users,n_us)      #user_skill中user的编号
us_dst = np.random.randint(0,n_skills,n_us)       #user_skill中skill的编号
js_src = np.random.randint(0,n_jobspre,n_js)
js_dst = np.random.randint(0,n_skills,n_js)

利用dgl构建异构图

hetero_graph = dgl.heterograph({
    ('user', 'uj', 'jobpre'): (uj_src, uj_dst),
    ('jobpre', 'ju', 'user'): (uj_dst, uj_src),
    ('user', 'uc', 'college'): (uc_src, uc_dst),
    ('college', 'cu', 'user'): (uc_dst, uc_src),
    ('user', 'us', 'skill'): (us_src, us_dst),
    ('skill', 'su', 'user'): (us_dst, us_src),
    ('jobpre','js','skill'):(js_src,js_dst),
    ('skill','sj','jobpre'):(js_dst,js_src),
    })

初始化每个节点的嵌入

hetero_graph.nodes['user'].data['feature'] = torch.randn(n_users, n_hetero_features)  #给user node 添加属性,相当于用户嵌入维度为20
hetero_graph.nodes['jobpre'].data['feature'] = torch.randn(n_jobspre, n_hetero_features) 
hetero_graph.nodes['college'].data['feature'] = torch.randn(n_college, n_hetero_features)
hetero_graph.nodes['skill'].data['feature'] = torch.randn(n_skills, n_hetero_features)
hetero_graph.nodes['user'].data['label'] = torch.randint(0, n_user_classes, (n_users,))    #用户的个人标签
hetero_graph.nodes['user'].data['train_mask'] = torch.zeros(n_users, dtype=torch.bool).bernoulli(0.6) #选出一些计算损失函数

定义一个异构卷积层

class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()
        # 实例化HeteroGraphConv,in_feats是输入特征的维度,out_feats是输出特征的维度,aggregate是聚合函数的类型
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):   #inputs: node_features
        # 输入是节点的特征字典
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

定义模型并进行训练

model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)       #['clicked-by', 'disliked-by', 'click', 'dislike', 'follow', 'followed-by']
user_feats = hetero_graph.nodes['user'].data['feature']         #用户的特征嵌入
jobpre_feats = hetero_graph.nodes['jobpre'].data['feature']      #物品的特征嵌入
college_feats = hetero_graph.nodes['college'].data['feature']      #物品的特征嵌入
skill_feats = hetero_graph.nodes['skill'].data['feature']      #物品的特征嵌入
labels = hetero_graph.nodes['user'].data['label']
train_mask = hetero_graph.nodes['user'].data['train_mask']

node_features = {'user': user_feats, 'jobpre': jobpre_feats,'college':college_feats,'skill':skill_feats}        #所有用户的特征嵌入     所有物品的特征嵌入
opt = torch.optim.Adam(model.parameters())

for epoch in range(5):
    model.train()
    # 使用所有节点的特征进行前向传播计算,并提取输出的user节点嵌入
    logits = model(hetero_graph, node_features)['user']
    h_dict = model(hetero_graph, {'user': user_feats, 'jobpre': jobpre_feats,'college':college_feats,'skill':skill_feats} )
    h_user = h_dict['user']     #模型每次迭代后得出的 用户的特征嵌入
    print(h_user[2])  # 第2个用户的特征,输出特征为5,因为用户的分类为5个
    h_jobpre= h_dict['jobpre']       #模型每次迭代后得出的 物品的特征嵌入
    h_college = h_dict['college']
    h_skill = h_dict['skill']
    
    # 计算损失值
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    # 计算验证集的准确度。在本例中省略。
    # 进行反向传播计算
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())
    torch.save(model.state_dict(), "./model/main01" + "_" + str(epoch))   #保存模型

使用(测试)模型

model = RGCN(n_hetero_features, 20, n_user_classes, hetero_graph.etypes)
model.load_state_dict(torch.load("./model/main01_4" ))

user_feats = hetero_graph.nodes['user'].data['feature']
print(user_feats )
print(user_feats[0])

结果截图:

标签:异构,graph,hetero,college,user,skill,节点,feats,图中
From: https://www.cnblogs.com/monster-little/p/17366206.html

相关文章

  • day 60 84. 柱状图中最大的矩形
    给定n个非负整数,用来表示柱状图中各个柱子的高度。每个柱子彼此相邻,且宽度为1。求在该柱状图中,能够勾勒出来的矩形的最大面积。  classSolution{publicintlargestRectangleArea(int[]heights){if(heights==null||heights.length==0){......
  • 基于python实现将AWS-ElastiCache-的Reserved_Cache_Nodes-预留节点及费用的信息统计
    在AWS-ElastiCache中,Reserved_Cache_Nodes-预留节点,也就类似于EC2与RDS的RI(预留实例),都是为了节省成本而选择预付费用的一种方式,当AWS账号有多个时,如何通过编程的方式批量获取所有账号所有区域Region的RN信息呢我们可以通过awscli的方式,也可以通过AWSSDKforPython(Boto3)的......
  • 关于AWS-ElastiCache的Reserved nodes预留节点支付类型-费用说明
    关于AWS-ElastiCache的Reservednodes的购买(类似于EC2的RI),可以节省成本引擎,可以选择Redis或者Memcached,期限一般大多都支持1年或者3年的对于Offeringtype-产品类型,这里分类比其他产品要复杂一点、,分为【标准预留节点产品】与【旧式预留节点产品】这个还与节点类型有关系......
  • el-tree实现树形结构叶子节点和非叶子节点的区分显示的写法
    需求,非叶子节点显示主题名称+主题下的指标;叶子节点显示代码+名称1、设置prop属性<el-tree:data="dimListTree"ref="dimListTree"row-key="getGroup":props="treeProps":allow-drop="al......
  • 动力节点老杜Vue框架教程【五】Vuex
    Vue.js是一个渐进式MVVM框架,目前被广泛使用,也成为前端中最火爆的框架Vue可以按照实际需要逐步进阶使用更多特性,也是前端的必备技能动力节点老杜的Vue2+3全家桶教程已经上线咯!学习地址:https://www.bilibili.com/video/BV17h41137i4/视频将从Vue2开始讲解,一步一个案例,知识点......
  • 关于sap-hana-数据库-在pacemaker集群中迁移主控节点-master节点
    环境介绍,hana数据库的两个节点:azphxxxdb01azphxxxdb02目前master位于azphxxxdb02,现在需要切换回azphxxxdb01 需要确保Pacemaker没有任何失败的操作(通过pcs状态检查)、没有任何意外的位置约束(例如迁移测试的遗留内容),并且HANA处于同步状态,例如,使用systemReplicationStat......
  • Mysql查询父、子节点
    一、概述相信大家在实际的开发工程中,都会遇到需要依据当前节点,查询出其上级节点或下级节点的需求。下面就我在工作过程中的处理方式记录如下,如有片面之处,欢迎批评指正。二、示例表结构初始表数据如图:查看表结构和初始数据脚本DROPTABLEIFEXISTS`t_cfg_region`;CREATE......
  • Chrome devTools--节点监听
    节点监听dombreakpoints:子节点修改/属性修改/节点移除子节点修改适用场景:当鼠标移入下拉框时,下拉选项出现,想要选中下拉选项dom,修改下拉选项的dom时,却又消失了解决方案:选中body节点,监听dom字节点的修改,Breakon---》subtreemodifications,下拉选项触发时进入debugger 节......
  • 一棵广度和深度都未知的树,存储于数据库的表中,节点存储顺序随机...
     publicclassDeleteNode{publicstaticvoidmain(String[]args){Nodenode=newNode(1,1,"aa");Nodenode1=newNode(2,3,"bb");Nodenode2=newNode(3,2,"cc");Nodenode3=ne......
  • 删除链表的倒数第N个节点
    题目:给你一个链表,删除链表的倒数第 n 个结点,并且返回链表的头结点。输入:head=[1,2,3,4,5],n=2输出:[1,2,3,5]本题需要使用双指针,需要注意的点:1、双指针都指向头结点2、快指针提前移动n+1个点3、结束条件:快指针指向空指针4、慢指针指向要删除结点的前一个结点5、删除结点时......