首页 > 其他分享 >一文讲懂图神经网络GNN(剧本杀版)

一文讲懂图神经网络GNN(剧本杀版)

时间:2024-07-18 21:28:44浏览次数:13  
标签:GNN 杀版 torch self 神经网络 3.5 data 节点

1. 图神经网络(GNN)简介

图神经网络(Graph Neural Networks,GNN)是一类专门设计用于处理图结构数据的深度学习模型。它能够有效地捕捉和利用数据中的关系信息,使其在许多领域中展现出强大的潜力。为了便于各位看官易于理解GNN中节点、边和图的概念,小编这里构造一个虚拟的剧本杀游戏如下:

1.1 故事背景:

在一座古老的庄园里,富豪老爷爷突然离奇死亡。他的遗嘱中提到,只有解开庄园的秘密,才能继承他的遗产。六位与老爷爷有关的人被邀请到庄园中解开谜题。

主要人物:
詹姆斯(老管家)
玛丽(老爷爷的女儿)
托马斯(玛丽的丈夫)
莉莉(老爷爷的孙女)
亚瑟(家族律师)
威廉(老爷爷的私人医生)

1.2 人物关系图:

在这里插入图片描述

1.3 故事情节:

参与者需要在庄园中搜寻线索,解开谜题,同时调查老爷爷的死因。每个人物都有自己的秘密和动机,可能是凶手,也可能是无辜的。玩家需要通过对话、搜证和逻辑推理来揭示真相。

1.4 引入GNN图神经网络

那么在这个故事中,节点、边和图分别表示什么呢?

1.4.1 节点:

在这个剧本杀剧情中,每个人物可以被表示为一个节点,每个节点的特征可以包括:【年龄,与死者的关系,财务状况,性格特征,在案发时的不在场证明】。

1.4.2 边:

人物之间的关系可以用边来表示,边的特征可以包括:【关系类型,关系强度,互动频率,潜在冲突】

1.4.3 图:

在这里插入图片描述

1.4.4 GNN任务:

  • 节点分类:预测每个人物是否可能是凶手
  • 边预测:推测人物之间可能存在的隐藏关系
  • 图分类:根据整体图结构判断案件的类型(谋杀、意外等)
    想必看到这里大家对于GNN中节点、边和图这几个词建立了一些概念,那么接下来我们详细讲解一下什么是GNN。

2. 为什么引入 GNN

首先,大家可能会疑惑,为什么要提出GNN网络呢,现有的CNN、RNN和Transformer等网络难道还不够用吗?

2.1 GNN vs 传统神经网络

  • CNN:擅长处理规则网格数据如图像,操作简单。
  • RNN:适用于序列数据,如文本或时间序列。
  • Transformer:善于处理长距离依赖的序列数据,。

然而,这些传统模型在处理图结构数据时存在局限性。图数据的不规则性和复杂的关系结构使得传统方法难以直接应用。且传统的结构难以捕获数据间的拓扑结构。比如对于一个化学分子结构:
在这里插入图片描述
对于该领域的问题,传统的CNN、RNN或者Transformer根本就不适用,相反GNN能够捕捉数据的拓扑结构信息,其中每个原子就是一个节点,每个化学键就是一个边。

2.2 GNN 还有较为广泛的应用领域:

  • 社交网络分析:预测用户行为,检测社区结构。
  • 推荐系统:基于用户-物品图的个性化推荐。
  • 生物信息学:蛋白质结构预测,药物相互作用分析。
  • 交通流量预测:基于道路网络的交通状况分析。
  • 知识图谱:实体关系推理,知识补全。

3. 图神经网络的原理

3.1 GNN 的基本模块

  • 节点(Nodes):图中的实体,如剧本杀中的角色。
  • 边(Edges):节点之间的关系,如角色间的社交关系。
  • 图(Graph):由节点和边组成的整体结构。

3.2 邻接矩阵

在这里插入图片描述

邻接矩阵是表示图结构的数学工具。对于有 N 个节点的图,邻接矩阵 A 是一个 N×N 的矩阵:
A i j = { 1 , 如果节点 i 和节点 j 之间有边  0 , 其他情况 A_{ij} = \begin{cases} 1, & \text{如果节点 i 和节点 j 之间有边} \ 0, & \text{其他情况} \end{cases} Aij​={1,​如果节点 i 和节点 j 之间有边 0,​其他情况​
在这里插入图片描述

3.3 GNN 消息传递

消息传递机制是图神经网络的核心,它允许节点之间交换信息,从而学习到更丰富的特征表示。这个过程通常包括三个主要步骤:消息生成、消息聚合和节点更新。

3.3.1 消息生成

在这个阶段,每个节点会根据自身的特征和与邻居节点的连接关系生成消息:
在这里插入图片描述

3.3.2 消息聚合

节点收集来自其所有邻居的消息,并使用某种聚合函数(如求和、平均或最大值)将这些消息组合起来:
在这里插入图片描述

3.3.3 节点更新

基于聚合的消息和节点自身的当前状态,更新节点的特征表示:
在这里插入图片描述

3.3.4 更新过程

节点 B 和 C 生成发送给 A 的消息
A 聚合来自 B 和 C 的消息
A 基于聚合的消息和自身当前状态更新其特征
在这里插入图片描述

3.4 GNN 实例

考虑以下简单的无向图:
在这里插入图片描述
每个节点初始有一个标量特征值(如A:1表示节点A的初始特征值为1)。

3.4.1 GNN 实例使用的函数

  • 消息函数(MSG):取发送节点和接收节点特征的平均值
  • 聚合函数(AGGREGATE):对所有接收到的消息求和
  • 更新函数(UPDATE):将聚合的消息加到当前节点特征上

3.4.2 第一轮消息传递

  • 消息生成
    对于每条边,计算消息:
    A -> B: MSG(1, 2) = (1 + 2) / 2 = 1.5
    A -> C: MSG(1, 3) = (1 + 3) / 2 = 2
    B -> D: MSG(2, 4) = (2 + 4) / 2 = 3
    C -> D: MSG(3, 4) = (3 + 4) / 2 = 3.5

  • 消息聚合
    每个节点聚合收到的所有消息:

    A: AGGREGATE(1.5, 2) = 3.5
    B: AGGREGATE(1.5) = 1.5
    C: AGGREGATE(2) = 2
    D: AGGREGATE(3, 3.5) = 6.5

  • 节点更新
    更新每个节点的特征:

    A: UPDATE(1, 3.5) = 1 + 3.5 = 4.5
    B: UPDATE(2, 1.5) = 2 + 1.5 = 3.5
    C: UPDATE(3, 2) = 3 + 2 = 5
    D: UPDATE(4, 6.5) = 4 + 6.5 = 10.5

  • 第一轮后的图结构
    在这里插入图片描述

3.4.3 第二轮消息传递

  • 消息生成
    A -> B: MSG(4.5, 3.5) = (4.5 + 3.5) / 2 = 4
    A -> C: MSG(4.5, 5) = (4.5 + 5) / 2 = 4.75
    B -> D: MSG(3.5, 10.5) = (3.5 + 10.5) / 2 = 7
    C -> D: MSG(5, 10.5) = (5 + 10.5) / 2 = 7.75
  • 消息聚合
    A: AGGREGATE(4, 4.75) = 8.75
    B: AGGREGATE(4) = 4
    C: AGGREGATE(4.75) = 4.75
    D: AGGREGATE(7, 7.75) = 14.75
  • 节点更新
    A: UPDATE(4.5, 8.75) = 4.5 + 8.75 = 13.25
    B: UPDATE(3.5, 4) = 3.5 + 4 = 7.5
    C: UPDATE(5, 4.75) = 5 + 4.75 = 9.75
    D: UPDATE(10.5, 14.75) = 10.5 + 14.75 = 25.25
  • 第二轮后的图结构
    在这里插入图片描述

3.4.4 结果分析

  • 信息传播:可以看到,经过两轮消息传递,每个节点的特征值都发生了显著变化。这反映了图中信息的传播过程。
  • 中心性:节点D的特征值增长最快,这反映了它在图中的中心地位(连接度最高)。
  • 特征融合:每个节点的新特征不仅包含了自身的信息,还融合了邻居节点的信息。

4. 实例代码

下面是一个完整的 PyTorch 代码,包含了模型定义、数据加载和训练过程:
首先,我们需要安装必要的库:

pip install torch torch_geometric

完整代码如下:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

class GCN(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, 16)
        self.classifier = torch.nn.Linear(16, num_classes)

    def forward(self, x, edge_index):
        h = F.relu(self.conv1(x, edge_index))
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.conv2(h, edge_index)
        h = F.dropout(h, p=0.5, training=self.training)
        out = self.classifier(h)
        return out

# 加载Cora数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]

# 初始化模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(num_features=dataset.num_features, num_classes=dataset.num_classes).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# 训练模型
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    
    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch+1:3d}, Loss: {loss.item():.4f}')

# 评估模型
model.eval()
pred = model(data.x, data.edge_index).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')

本文只是让各位看官对GNN有个初步认识,便于大家理解更加复杂的结构,如需要更加深刻的认识还是建议去看相关论文。

标签:GNN,杀版,torch,self,神经网络,3.5,data,节点
From: https://blog.csdn.net/m0_59257547/article/details/140529230

相关文章

  • 脑机接口--BP神经网络预测
    BP神经网络的原理见下文,这里主要讲两种方式(代码形式和工具箱形式)             1.BP神经网络的简介和结构参数神经网络是机器学习中一种常见的数学模型,通过构建类似于大脑神经突触联接的结构,来进行信息处理。在应用神经网络的过程中,处理信息的单元......
  • 北京交通大学《深度学习》专业课,实验3卷积、空洞卷积、残差神经网络实验
    一、实验要求1.二维卷积实验(平台课与专业课要求相同)⚫手写二维卷积的实现,并在至少一个数据集上进行实验,从训练时间、预测精度、Loss变化等角度分析实验结果(最好使用图表展示)⚫使用torch.nn实现二维卷积,并在至少一个数据集上进行实验,从训练时间、预测精度、Loss变化等角......
  • Datawhale AI夏令营第二期——机器学习 基于神经网络stack融合策略的多模型融合
    #AI夏令营#Datawhale夏令营基于神经网络stack融合策略的多模型融合改进点:1.数据清洗,异常值替换(板块2)2.基于神经网络的stack模型融合(板块5)根据大佬的提示对Task3所做的改进,大佬链接:http://t.csdnimg.cn/RSC3o1.模型导入导入所需要包:importpandasaspdimportnumpy......
  • 前向反馈神经网络模型训练过程
    https://www.mrdbourke.com/the-unofficial-pytorch-optimization-loop-song/Unofficialpytorchoptimizationsong:ForanepochinarangeCallmodeldottrainDotheforwardpassCalculatethelossOptimizerzerogradLossssssbackwardOptimizerstepstepst......
  • 十五、【机器学习】【监督学习】- 神经网络回归
    系列文章目录第一章【机器学习】初识机器学习第二章【机器学习】【监督学习】-逻辑回归算法(LogisticRegression)第三章【机器学习】【监督学习】-支持向量机(SVM)第四章【机器学习】【监督学习】-K-近邻算法(K-NN)第五章【机器学习】【监督学习】-决策树(Dec......
  • P27-P47构建神经网络进化智能体-构建用于训练强化学习之鞥提的随机环境-构建基于价值
    文章目录构建神经网络进化智能体前期准备实现步骤工作原理参考资料第二章基于价值、策略和行动者-评论家的深度强化学习算法实现技术要求构建用于训练强化学习智能体的随机环境前期准备实现步骤工作原理构建基于价值的强化学习智能体算法前期准备实现步骤工作原理......
  • 【matlab】智能优化算法优化BP神经网络
    目录引言一、BP神经网络简介二、智能优化算法概述三、智能优化算法优化BP神经网络的方法四、蜣螂优化算法案例1、算法来源2、算法描述3、算法性能结果仿真代码实现引言智能优化算法优化BP神经网络是一个重要的研究领域,旨在通过智能算法提高BP神经网络的性能和......
  • Reinforced Causal Explainer for GNN论文笔记
    论文:TPAMI2023 图神经网络的强化因果解释器论文代码地址:代码目录AbstractIntroductionPRELIMINARIESCausalAttributionofaHolisticSubgraph​individualcausaleffect(ICE)​*CausalScreeningofanEdgeSequenceReinforcedCausalExplainer(RC-Explaine......
  • CEEMDAN-VMD-CNN-LSTM二次分解结合卷积双向长短期记忆神经网络多变量时序预测(Matlab完
    CEEMDAN-VMD-CNN-LSTM二次分解结合卷积长短期记忆神经网络多变量时序预测(Matlab完整源码和数据)CEEMDAN分解,计算样本熵,根据样本熵进行kmeans聚类,调用VMD对高频分量Co-IMF1二次分解,VMD分解的高频分量与Co_IMF2;Co_IMF3分量作为卷积长短期记忆神经网络模型的目标输出分别预测......
  • 大白话【卷积神经网络】工作原理
    卷积神经网络(ConvolutionalNeuralNetwork,简称CNN)是一种专门设计用于处理具有网格结构的数据(如图像)的神经网络。想象一下,你正在玩一个游戏,游戏的目标是识别图片上的内容。但是,你不能直接看到整个图片,而只能通过一个小窗口(称为“滤波器”或“卷积核”)来观察图片的一部分。每次......