文章目录
图神经网络GNN
GNN是一个广义的术语,涵盖了所有能够处理图结构数据的神经网络模型。GNN的主要思想是利用图的结构信息来更新和学习节点的特征。GNN的设计目标是处理图数据中的节点特征和边信息,捕捉图中节点之间的复杂关系。
一、GNN的优势
1、处理非欧几里得数据
- 图结构数据:GNN能够处理图结构数据,而传统的神经网络(如CNN和RNN)主要设计用于处理欧几里得数据(如图像和序列)。
- 灵活性:GNN可以自然地处理社交网络、分子结构、知识图谱等非欧几里得数据,这些数据具有复杂的节点和边关系。
2、捕捉节点间的复杂关系
- 节点间依赖:GNN能够有效地捕捉图中节点之间的依赖关系,通过聚合邻居节点的信息来更新节点特征。
- 全局信息:通过多层传播,GNN可以捕捉到全图的结构信息,从而提供比局部特征更丰富的表示。
3、信息聚合和传递
- 多层信息传播:通过多层GNN,节点特征可以逐层传播和聚合,综合来自不同邻居的信息。
- 自适应性:GNN能够自适应地调整每个节点的特征更新方式,以更好地适应图的结构和节点的特性。
4、适用于各种图相关任务
- 节点分类:在图上进行节点分类(如社交网络中的用户分类)。
- 边预测:预测图中可能存在的边(如推荐系统中的物品推荐)。
- 图分类:对整个图进行分类(如化学分子结构的功能预测)。
- 社区发现:检测图中的社区结构(如社交网络中的群体划分)。
二、GNN基础
1.图的基本组成
- 有向图和相应的邻接矩阵如上图所示,其中左侧部分是六个节点的连接,右侧部分是邻接矩阵
2.Pytorch Geometric
- 这个库实现了图神经网络中的各种方法
- github网址:https://github.com/pyg-team/pytorch_geometric?tab=readme-ov-file
三、pytorch_geometric 基本使用
- 数据集API:https://www.journals.uchicago.edu/doi/abs/10.1086/jar.33.4.3629752
- 在Jupyter notebook环境中实现
from torch_geometric.datasets import KarateClub
#加载数据集
dataset = KarateClub()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of graphs: {dataset.num_classes}')
#结果
Dataset: KarateClub():
======================
Number of graphs: 1
Number of features: 34
Number of graphs: 4
#获取第一个图形对象
data = dataset[0] #这里的dataset中只有一个图
print(data)
#结果
Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])
- 这里图的表示用Data格式,查看参考文档
- x = [34,34]:第一个值表示样本个数,第二个值表示特征维度
- edge_index = [2, 156]:一般都是2*边的个数,2指的是源点和目标点
- y = [34]:标签数
- train_mask = [34]:表示哪些节点用于训练
#查看edge_index
edge_index = data.edge_index
print(edge_index.t())
#部分结果示例
tensor([[ 0, 1],
[ 0, 2],
[ 0, 3],
...
[13, 2],
[13, 3],
...
[33, 32]])
- 由结果可以看出节点之间的关系,例如0节点与1,2,3…相关,13与2,3…相关
- edge_index是稀疏表示的,并不是n*n的稀疏矩阵
1、使用networks可视化展示
import networkx as nx
import matplotlib.pyplot as plt
#可视化图G并对节点进行着色
def visualize_graph(G, color):
plt.figure(figsize=(7, 7))
plt.xticks([])
plt.yticks([])
nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
node_color = color, cmap = "Set2")
plt.show()
from torch_geometric.utils import to_networkx
G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)
结果如下图所示:
2、Graph Neutral Networks网络定义
-
GCN的核心公式:节点的特征更新主要通过以下公式完成
-
其中A是图的邻接矩阵,D是度矩阵,H是特征矩阵
-
PyG文档 GCNConv
-
1)模型的搭建:
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
torch.manual_seed(1234)
self.conv1 = GCNConv(dataset.num_features, 4) #只需要定义好输入特征和输出特征即可
self.conv2 = GCNConv(4, 4)
self.conv3 = GCNConv(4, 2)
self.classifier = Linear(2, dataset.num_classes)
def forward(self, x, edge_index):
h = self.conv1(x, edge_index) #输入特征和邻接矩阵
h = h.tanh()
h = self.conv2(h, edge_index)
h = h.tanh()
h = self.conv3(h, edge_index)
h = h.tanh()
#分类层
out = self.classifier(h)
return out, h
#打印模型
model = GCN()
print(model)
#结果
GCN(
(conv1): GCNConv(34, 4)
(conv2): GCNConv(4, 4)
(conv3): GCNConv(4, 2)
(classifier): Linear(in_features=2, out_features=4, bias=True)
)
- 2)模型输入特征展示:
def visualize_embedding(h, color, epoch = None, loss = None):
plt.figure(figsize=(7, 7))
plt.xticks([])
plt.yticks([])
h = h.detach().cpu().numpy()
plt.scatter(h[:, 0], h[:, 1], s = 140, c = color, cmap = "Set2")
if epoch is not None and loss is not None:
plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize = 16)
plt.show()
model = GCN()
_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')
visualize_embedding(h, color=data.y)
- ‘ _ , h = model(data.x, data.edge_index)’中使用 ‘_, h’表示只对第二个返回值h感兴趣,并忽略第一个返回值。
#结果
Embedding shape: [34, 2]
初始特征如下图所示:
- 3、训练模型及模型输出特征展示:
import time
model = GCN()
criterion = torch.nn.CrossEntropyLoss() # 定义交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # 定义 Adam 优化器
def train(data):
optimizer.zero_grad() # 清除梯度
out, h = model(data.x, data.edge_index) #h是二维向量(主要为了画图)
loss = criterion(out[data.train_mask], data.y[data.train_mask]) # 计算损失
loss.backward() # 反向传播计算梯度
optimizer.step() # 更新参数
return loss, h
for epoch in range(401):
loss, h = train(data) # 训练一个 epoch
if epoch % 100 == 0: # 每隔 100 个 epoch 可视化一次嵌入
visualize_embedding(h, color=data.y, epoch = epoch, loss = loss)
time.sleep(0.3) # 暂停 0.3 秒以便于观察
其中,loss = criterion(out[data.train_mask], data.y[data.train_mask])这行代码的作用是:
- 提取训练集节点的预测结果 out[data.train_mask],形状为 [num_train_nodes, num_classes];
- 提取训练集节点的实际标签 data.y[data.train_mask],形状为 [num_train_nodes];
- 使用交叉熵损失函数计算预测结果和实际标签之间的损失。
输出结果如下图所示:
总结
以上就是对图神经网络的基本介绍以及相关库的简单应用,但实际应用中输入网络的数据集并不一定满足网络输入的要求,还需要进一步学习。
标签:index,入门,GNN,train,神经网络,edge,data,节点 From: https://blog.csdn.net/qq_43798150/article/details/139502463