首页 > 其他分享 >图神经网络GNN实践入门

图神经网络GNN实践入门

时间:2024-06-06 20:33:27浏览次数:24  
标签:index 入门 GNN train 神经网络 edge data 节点

参考视频网址:https://www.bilibili.com/video/BV1MP41187pv/?spm_id_from=333.999.0.0&vd_source=590f4019caa7ed7b4e57c0e869ad0867

文章目录


图神经网络GNN

GNN是一个广义的术语,涵盖了所有能够处理图结构数据的神经网络模型。GNN的主要思想是利用图的结构信息来更新和学习节点的特征。GNN的设计目标是处理图数据中的节点特征和边信息,捕捉图中节点之间的复杂关系。


一、GNN的优势

1、处理非欧几里得数据

  • 图结构数据:GNN能够处理图结构数据,而传统的神经网络(如CNN和RNN)主要设计用于处理欧几里得数据(如图像和序列)。
  • 灵活性:GNN可以自然地处理社交网络、分子结构、知识图谱等非欧几里得数据,这些数据具有复杂的节点和边关系。

2、捕捉节点间的复杂关系

  • 节点间依赖:GNN能够有效地捕捉图中节点之间的依赖关系,通过聚合邻居节点的信息来更新节点特征。
  • 全局信息:通过多层传播,GNN可以捕捉到全图的结构信息,从而提供比局部特征更丰富的表示。

3、信息聚合和传递

  • 多层信息传播:通过多层GNN,节点特征可以逐层传播和聚合,综合来自不同邻居的信息。
  • 自适应性:GNN能够自适应地调整每个节点的特征更新方式,以更好地适应图的结构和节点的特性。

4、适用于各种图相关任务

  • 节点分类:在图上进行节点分类(如社交网络中的用户分类)。
  • 边预测:预测图中可能存在的边(如推荐系统中的物品推荐)。
  • 图分类:对整个图进行分类(如化学分子结构的功能预测)。
  • 社区发现:检测图中的社区结构(如社交网络中的群体划分)。

二、GNN基础

1.图的基本组成

有向图和相应的邻接矩阵的示例,其中左侧部分是六个节点的连接,右侧部分是邻接矩阵

  • 有向图和相应的邻接矩阵如上图所示,其中左侧部分是六个节点的连接,右侧部分是邻接矩阵

2.Pytorch Geometric

三、pytorch_geometric 基本使用

from torch_geometric.datasets import KarateClub
#加载数据集
dataset = KarateClub()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of graphs: {dataset.num_classes}')
#结果
Dataset: KarateClub():
======================
Number of graphs: 1
Number of features: 34
Number of graphs: 4
#获取第一个图形对象
data = dataset[0]	#这里的dataset中只有一个图
print(data)
#结果
Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])
  • 这里图的表示用Data格式,查看参考文档
  • x = [34,34]:第一个值表示样本个数,第二个值表示特征维度
  • edge_index = [2, 156]:一般都是2*边的个数,2指的是源点和目标点
  • y = [34]:标签数
  • train_mask = [34]:表示哪些节点用于训练
#查看edge_index
edge_index = data.edge_index
print(edge_index.t())
#部分结果示例
tensor([[ 0,  1],
        [ 0,  2],
        [ 0,  3],
        ...
        [13,  2],
        [13,  3],
        ...
        [33, 32]])
  • 由结果可以看出节点之间的关系,例如0节点与1,2,3…相关,13与2,3…相关
  • edge_index是稀疏表示的,并不是n*n的稀疏矩阵

1、使用networks可视化展示

import networkx as nx
import matplotlib.pyplot as plt

#可视化图G并对节点进行着色
def visualize_graph(G, color):
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                     node_color = color, cmap = "Set2")
    plt.show()
from torch_geometric.utils import to_networkx

G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

结果如下图所示:
在这里插入图片描述

2、Graph Neutral Networks网络定义

  • GCN的核心公式:节点的特征更新主要通过以下公式完成

    在这里插入图片描述

  • 其中A是图的邻接矩阵,D是度矩阵,H是特征矩阵

  • PyG文档 GCNConv

  • 1)模型的搭建:

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        torch.manual_seed(1234)
        self.conv1 = GCNConv(dataset.num_features, 4)   #只需要定义好输入特征和输出特征即可
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, dataset.num_classes)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)   #输入特征和邻接矩阵
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()

        #分类层
        out = self.classifier(h)

        return out, h

#打印模型
model = GCN()
print(model)
#结果
GCN(
  (conv1): GCNConv(34, 4)
  (conv2): GCNConv(4, 4)
  (conv3): GCNConv(4, 2)
  (classifier): Linear(in_features=2, out_features=4, bias=True)
)
  • 2)模型输入特征展示:
def visualize_embedding(h, color, epoch = None, loss = None):
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:, 0], h[:, 1], s = 140, c = color, cmap = "Set2")
    if epoch is not None and loss is not None:
        plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize = 16)
    plt.show()
model = GCN()

_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')

visualize_embedding(h, color=data.y)
  • ‘ _ , h = model(data.x, data.edge_index)’中使用 ‘_, h’表示只对第二个返回值h感兴趣,并忽略第一个返回值。
#结果
Embedding shape: [34, 2]

初始特征如下图所示:

在这里插入图片描述

  • 3、训练模型及模型输出特征展示:
import time

model = GCN()
criterion = torch.nn.CrossEntropyLoss()     # 定义交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)   # 定义 Adam 优化器

def train(data):
    optimizer.zero_grad()   # 清除梯度
    out, h = model(data.x, data.edge_index) #h是二维向量(主要为了画图)
    loss = criterion(out[data.train_mask], data.y[data.train_mask]) # 计算损失
    loss.backward()     # 反向传播计算梯度
    optimizer.step()    # 更新参数
    return loss, h

for epoch in range(401):
    loss, h = train(data)   # 训练一个 epoch
    if epoch % 100 == 0:     # 每隔 100 个 epoch 可视化一次嵌入
        visualize_embedding(h, color=data.y, epoch = epoch, loss = loss)
        time.sleep(0.3)     # 暂停 0.3 秒以便于观察

其中,loss = criterion(out[data.train_mask], data.y[data.train_mask])这行代码的作用是:

  • 提取训练集节点的预测结果 out[data.train_mask],形状为 [num_train_nodes, num_classes];
  • 提取训练集节点的实际标签 data.y[data.train_mask],形状为 [num_train_nodes];
  • 使用交叉熵损失函数计算预测结果和实际标签之间的损失。

输出结果如下图所示:

在这里插入图片描述

总结

以上就是对图神经网络的基本介绍以及相关库的简单应用,但实际应用中输入网络的数据集并不一定满足网络输入的要求,还需要进一步学习。

标签:index,入门,GNN,train,神经网络,edge,data,节点
From: https://blog.csdn.net/qq_43798150/article/details/139502463

相关文章

  • ChatGPT Prompt技术全攻略-入门篇:AI提示工程基础
    系列篇章......
  • 【入门教程】5分钟教你快速学会集成Java springboot ~
    介绍ApacheDolphinScheduler是一个分布式易扩展的开源分布式调度系统,支持海量数据处理,具有任务流程调度、任务流程编排、任务监控告警、工作流引擎等功能。本文将介绍如何将ApacheDolphinScheduler集成到JavaSpringboot项目中,以实现更灵活和便捷的调度功能。步骤步骤一:添......
  • Netty 快速入门
    什么是NettyNetty的官网:[https://netty.io/Netty是一个JavaNIO技术的开源异步事件驱动的网络编程框架,用于快速开发可维护的高性能协议服务器和客户端。往通俗了讲,可以将Netty理解为:一个将JavaNIO进行了大量封装,并大大降低JavaNIO使用难度和上手门槛的网络编程框架。Net......
  • FastAPI-3:快速入门
    3快速入门第二章是python基础,故不做介绍。FastAPI是一个现代、快速(高性能)的网络框架,用于使用基于标准Python类型提示的Python3.6+构建API。FastAPI的创建者是SebastiánRamírez。FastAPI由SebastiánRamírez于2018年发布。与大多数PythonWeb框架相比,它在很多方面都更......
  • Android视频开发入门: VideoView、MediaPlayer、 FFmpeg、exoplayer...
    现在,视频功能是越来越普遍的需求。本文将提供一个关于Android视频开发的入门指南,帮助读者快速掌握视频播放、录制和处理等基本功能。1、概述在Android平台上,视频开发主要涉及以下几个方面:视频播放与控制视频录制与处理视频编解码与格式转换视频流媒体与直播接下来,我......
  • 001__C语言程序入门
    一、第一个程序:helloworld配置部署好vsCode之后,就可以直接在上面写代码了,新建一个新的C程序文件,向屏幕输出一串字符“HelloWorld!”下面,从整体上来分析一下这个最简单的C语言程序,将这个最简程序的各个部分剖析清楚,明白我们写下的每一个字符的具体含义。二、C语言的基本结......
  • 探索Adobe XD:高效UI设计软件的中文入门教程
    在这个数字化世界里,创意设计不仅是为了吸引观众的注意,也是用户体验的核心部分。强大的设计工具可以帮助设计师创造出明亮的视觉效果,从而提高用户体验。一、AdobeXD是什么?AdobeXD是一家知名软件公司AdobeSystems用户体验和用户界面设计软件的制作和发布。软件可以帮助设......
  • SQL—数据库查询语言,全面详解演示,入门进阶必会
    文章目录一、基础二、创建表三、修改表四、插入五、更新六、删除七、查询DISTINCTLIMIT八、排序九、过滤十、通配符十一、计算字段十二、函数汇总文本处理日期和时间处理数值处理十三、分组十四、子查询十五、连接内连接自连接自然连接外连接十六、组合查询十七、视图......
  • GitHub 常用操作与常用命令——GitHub入门,看这一文就够了
    文章目录GitHub常用操作in关键词限制搜索范围:stars或fork数量关键词查找:awesome加强搜索:高亮显示某一行的代码:项目内搜索:显示快捷键:Git常用命令初始化命令查看当前git配置信息:设置提交代码时的用户信息在当前目录新建一个Git代码库下载一个项目和它的整个代码版本与......
  • spring入门aop和ioc基于注解
    目录用注解代替xml文件中的部分配置请先观看链接用注解代替xml文件中的部分配置在要注册bean的地方添加注解@Component()不指定名字就是类名的首字母小写@Component("name")bean的名字就是括号中指定的值在注册完以后要开始注册扫描<!--重点是开启注解扫描-->......