1. 图神经网络(GNN)简介
图神经网络(Graph Neural Networks,GNN)是一类专门设计用于处理图结构数据的深度学习模型。它能够有效地捕捉和利用数据中的关系信息,使其在许多领域中展现出强大的潜力。为了便于各位看官易于理解GNN中节点、边和图的概念,小编这里构造一个虚拟的剧本杀游戏如下:
1.1 故事背景:
在一座古老的庄园里,富豪老爷爷突然离奇死亡。他的遗嘱中提到,只有解开庄园的秘密,才能继承他的遗产。六位与老爷爷有关的人被邀请到庄园中解开谜题。
主要人物:
詹姆斯(老管家)
玛丽(老爷爷的女儿)
托马斯(玛丽的丈夫)
莉莉(老爷爷的孙女)
亚瑟(家族律师)
威廉(老爷爷的私人医生)
1.2 人物关系图:
1.3 故事情节:
参与者需要在庄园中搜寻线索,解开谜题,同时调查老爷爷的死因。每个人物都有自己的秘密和动机,可能是凶手,也可能是无辜的。玩家需要通过对话、搜证和逻辑推理来揭示真相。
1.4 引入GNN图神经网络
那么在这个故事中,节点、边和图分别表示什么呢?
1.4.1 节点:
在这个剧本杀剧情中,每个人物可以被表示为一个节点,每个节点的特征可以包括:【年龄,与死者的关系,财务状况,性格特征,在案发时的不在场证明】。
1.4.2 边:
人物之间的关系可以用边来表示,边的特征可以包括:【关系类型,关系强度,互动频率,潜在冲突】
1.4.3 图:
1.4.4 GNN任务:
- 节点分类:预测每个人物是否可能是凶手
- 边预测:推测人物之间可能存在的隐藏关系
- 图分类:根据整体图结构判断案件的类型(谋杀、意外等)
想必看到这里大家对于GNN中节点、边和图这几个词建立了一些概念,那么接下来我们详细讲解一下什么是GNN。
2. 为什么引入 GNN
首先,大家可能会疑惑,为什么要提出GNN网络呢,现有的CNN、RNN和Transformer等网络难道还不够用吗?
2.1 GNN vs 传统神经网络
- CNN:擅长处理规则网格数据如图像,操作简单。
- RNN:适用于序列数据,如文本或时间序列。
- Transformer:善于处理长距离依赖的序列数据,。
然而,这些传统模型在处理图结构数据时存在局限性。图数据的不规则性和复杂的关系结构使得传统方法难以直接应用。且传统的结构难以捕获数据间的拓扑结构。比如对于一个化学分子结构:
对于该领域的问题,传统的CNN、RNN或者Transformer根本就不适用,相反GNN能够捕捉数据的拓扑结构信息,其中每个原子就是一个节点,每个化学键就是一个边。
2.2 GNN 还有较为广泛的应用领域:
- 社交网络分析:预测用户行为,检测社区结构。
- 推荐系统:基于用户-物品图的个性化推荐。
- 生物信息学:蛋白质结构预测,药物相互作用分析。
- 交通流量预测:基于道路网络的交通状况分析。
- 知识图谱:实体关系推理,知识补全。
3. 图神经网络的原理
3.1 GNN 的基本模块
- 节点(Nodes):图中的实体,如剧本杀中的角色。
- 边(Edges):节点之间的关系,如角色间的社交关系。
- 图(Graph):由节点和边组成的整体结构。
3.2 邻接矩阵
邻接矩阵是表示图结构的数学工具。对于有 N 个节点的图,邻接矩阵 A 是一个 N×N 的矩阵:
A
i
j
=
{
1
,
如果节点 i 和节点 j 之间有边
0
,
其他情况
A_{ij} = \begin{cases} 1, & \text{如果节点 i 和节点 j 之间有边} \ 0, & \text{其他情况} \end{cases}
Aij={1,如果节点 i 和节点 j 之间有边 0,其他情况
3.3 GNN 消息传递
消息传递机制是图神经网络的核心,它允许节点之间交换信息,从而学习到更丰富的特征表示。这个过程通常包括三个主要步骤:消息生成、消息聚合和节点更新。
3.3.1 消息生成
在这个阶段,每个节点会根据自身的特征和与邻居节点的连接关系生成消息:
3.3.2 消息聚合
节点收集来自其所有邻居的消息,并使用某种聚合函数(如求和、平均或最大值)将这些消息组合起来:
3.3.3 节点更新
基于聚合的消息和节点自身的当前状态,更新节点的特征表示:
3.3.4 更新过程
节点 B 和 C 生成发送给 A 的消息
A 聚合来自 B 和 C 的消息
A 基于聚合的消息和自身当前状态更新其特征
3.4 GNN 实例
考虑以下简单的无向图:
每个节点初始有一个标量特征值(如A:1表示节点A的初始特征值为1)。
3.4.1 GNN 实例使用的函数
- 消息函数(MSG):取发送节点和接收节点特征的平均值
- 聚合函数(AGGREGATE):对所有接收到的消息求和
- 更新函数(UPDATE):将聚合的消息加到当前节点特征上
3.4.2 第一轮消息传递
-
消息生成:
对于每条边,计算消息:
A -> B: MSG(1, 2) = (1 + 2) / 2 = 1.5
A -> C: MSG(1, 3) = (1 + 3) / 2 = 2
B -> D: MSG(2, 4) = (2 + 4) / 2 = 3
C -> D: MSG(3, 4) = (3 + 4) / 2 = 3.5 -
消息聚合:
每个节点聚合收到的所有消息:A: AGGREGATE(1.5, 2) = 3.5
B: AGGREGATE(1.5) = 1.5
C: AGGREGATE(2) = 2
D: AGGREGATE(3, 3.5) = 6.5 -
节点更新:
更新每个节点的特征:A: UPDATE(1, 3.5) = 1 + 3.5 = 4.5
B: UPDATE(2, 1.5) = 2 + 1.5 = 3.5
C: UPDATE(3, 2) = 3 + 2 = 5
D: UPDATE(4, 6.5) = 4 + 6.5 = 10.5 -
第一轮后的图结构:
3.4.3 第二轮消息传递
- 消息生成:
A -> B: MSG(4.5, 3.5) = (4.5 + 3.5) / 2 = 4
A -> C: MSG(4.5, 5) = (4.5 + 5) / 2 = 4.75
B -> D: MSG(3.5, 10.5) = (3.5 + 10.5) / 2 = 7
C -> D: MSG(5, 10.5) = (5 + 10.5) / 2 = 7.75 - 消息聚合:
A: AGGREGATE(4, 4.75) = 8.75
B: AGGREGATE(4) = 4
C: AGGREGATE(4.75) = 4.75
D: AGGREGATE(7, 7.75) = 14.75 - 节点更新:
A: UPDATE(4.5, 8.75) = 4.5 + 8.75 = 13.25
B: UPDATE(3.5, 4) = 3.5 + 4 = 7.5
C: UPDATE(5, 4.75) = 5 + 4.75 = 9.75
D: UPDATE(10.5, 14.75) = 10.5 + 14.75 = 25.25 - 第二轮后的图结构:
3.4.4 结果分析
- 信息传播:可以看到,经过两轮消息传递,每个节点的特征值都发生了显著变化。这反映了图中信息的传播过程。
- 中心性:节点D的特征值增长最快,这反映了它在图中的中心地位(连接度最高)。
- 特征融合:每个节点的新特征不仅包含了自身的信息,还融合了邻居节点的信息。
4. 实例代码
下面是一个完整的 PyTorch 代码,包含了模型定义、数据加载和训练过程:
首先,我们需要安装必要的库:
pip install torch torch_geometric
完整代码如下:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
class GCN(torch.nn.Module):
def __init__(self, num_features, num_classes):
super(GCN, self).__init__()
self.conv1 = GCNConv(num_features, 16)
self.conv2 = GCNConv(16, 16)
self.classifier = torch.nn.Linear(16, num_classes)
def forward(self, x, edge_index):
h = F.relu(self.conv1(x, edge_index))
h = F.dropout(h, p=0.5, training=self.training)
h = self.conv2(h, edge_index)
h = F.dropout(h, p=0.5, training=self.training)
out = self.classifier(h)
return out
# 加载Cora数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())
data = dataset[0]
# 初始化模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(num_features=dataset.num_features, num_classes=dataset.num_classes).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# 训练模型
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
if (epoch + 1) % 10 == 0:
print(f'Epoch {epoch+1:3d}, Loss: {loss.item():.4f}')
# 评估模型
model.eval()
pred = model(data.x, data.edge_index).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Accuracy: {acc:.4f}')
本文只是让各位看官对GNN有个初步认识,便于大家理解更加复杂的结构,如需要更加深刻的认识还是建议去看相关论文。
标签:GNN,杀版,torch,self,神经网络,3.5,data,节点 From: https://blog.csdn.net/m0_59257547/article/details/140529230