首页 > 其他分享 >PyG教程:MessagePassing基类

PyG教程:MessagePassing基类

时间:2024-11-29 20:30:00浏览次数:9  
标签:index 函数 torch MessagePassing edge 基类 data 节点 PyG

PyG教程:MessagePassing基类

一、引言

PyG框架中提供了一个消息传递基类torch_geometric.nn.MessagePassing,它实现了消息传递的自动处理,继承该类可以简单方便的构建自己的消息传播GNN。

二、如何自定义消息传递网络

要自定义GNN模型,首先需要继承MessagePassing类,然后重写如下方法:

  • message(...):构建要传递的消息;
  • aggregate(...):将从源节点传递过来的消息聚合到目标结点;
  • update(...):更新节点的消息。

上述方法并不是一定都要自定义,若MessagePassing类默认实现满足你的需求,则可以不重写。

1.构造函数

继承MessagePassing类后,在构造函数中可以通过super().__init__方法来向基类MessagePassing传递参数,来指定消息传递过程中的一些行为。MessagePassing类的初始化函数如下:
在这里插入图片描述
参数说明:

参数名参数说明
aggr消息传递中的消息聚合方式,常用的包括summeanminmaxmul等等。default: sum
flow消息传播的方向,其中source_to_targe表示从源节点到目标节点、target_to_source表示从目标节点到源节点。default:source_to_target
node_dim传播的维度,default:-2
decomposed_layers这个参数没用过,我也还不知道,后面会更新。

2.propagate函数

在具体介绍消息传递的三个相关函数之前,首先先介绍propagate函数,该函数是消息传递的启动函数,调用该函数后依次会执行messageaggregateudpate函数来完成消息的传递聚合更新。该函数的声明如下:
在这里插入图片描述
参数说明:

参数名参数说明
edge_index边索引
size这个参数目前我理解的不是很透彻,后面透彻了补一下
**kwargs构建、聚合和更新消息所需的额外数据,都可以传入propagate函数,这些参数可以在消息传递过程中的三个函数中接收。

该函数一般会传入edge_index和特征x

3.message函数

message函数是用来构建节点的消息的。传递给propagate函数的tensor可以映射到中心(target)节点邻居(source)节点上,只需要在相应变量名后加上_ior_j即可,通常称_i为中心(target)节点,称_j为邻居(source)节点。

source节点和target节点的关系:
在这里插入图片描述
message实现源码:
在这里插入图片描述

从源码的默认实现可以看出,message传递的消息就是邻居节点自身的特征向量。

示例:

def forward(self, data):
	out = self.propagate(edge_index, x=x)
	pass

def message(self, x_i, x_j, edge_index_i, edge_index_j):
	pass

该例子中利用propagate函数传递了两个参数edge_indexx,则message函数可以根据propagate函数中的两个参数构造自己的参数,上述message函数中的构造参数为:

  • x_i:中心节点(target)的特征向量组成的矩阵,注意该矩阵与图节点的矩阵x是不同的;
  • x_j:邻居节点(source)的特征向量组成的矩阵;
  • edge_index_i:中心节点的索引;
  • edge_index_j:邻居节点的索引。

注意,若flow='source_to_target',则消息将由邻居节点传向中心节点,若flow='target_to_source'则消息将从中心节点传向邻居节点,默认为第一种情况

4.aggregate函数

消息聚合函数aggregate用来聚合来自邻居的消息,常用的包括summeanmaxmin等,可以通过super().__init__()中的参数aggr来设定。该函数的第一个参数为message函数的返回值。

5.update函数

update函数用来更新节点的消息,aggregate函数的返回值作为该函数的第一个参数。

默认实现:
在这里插入图片描述

从默认实现可以看出update函数没有进行任何的操作,只是将raggregate函数的返回值返回了而已。

实际写代码的过程中,我们也不会去重写这个方法,而是,在forward函数中调用完propagate(…)函数后编写代码,代替update函数的功能。

三、代码实战

假设我们设计一个GNN模型,其中消息传递过程用公式表示如下:
X i ( k ) = X i ( k − 1 ) + ∑ j ∈ N ( i ) X j ( k − 1 ) (1) X_i^{(k)} = X_i^{(k-1)} + \sum _{j\in {\mathcal {N(i)}}} X_j^{(k-1) }\tag {1} Xi(k)​=Xi(k−1)​+j∈N(i)∑​Xj(k−1)​(1)

  • message生成的消息就是中心节点的邻居节点的特征向量。
  • aggregaet聚合消息的方式是sum,即把所有邻居节点的特征向量加起来。
  • update更新中心节点的方式是:将聚合得到的消息和中心节点自身的特征向量相加。

1.图数据定义

我们有如下数据:

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1],
						   [1, 0]], dtype=torch.long)
x = torch.tensor([[-1, 1], [0, 1], [1, 1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.contiguous())

在这里插入图片描述

2.实现GNN的消息传递过程

class MyConv(MessagePassing):
	def __init__(self):
		super().__init__(aggr='sum')

	def forward(self, data):
		out = self.propagate(data.edge_index, x=data.x)
		# out = out + x 
		return out

	def message(self, x_i, x_j, edge_index_i, edge_index_j):
		# 生成的消息就是邻居节点的特征向量,直接使用 x_j 访问获取就行
		return x_j

	def aggregate(self, message, edge_index_i):
		# 这里只是写的样例,实际上一般不会重写这个方法,直接使用默认的就好了,只需要自己选择一下聚合的方式即可
		return super().aggregate(message, edge_index_i, dim_size=len(x))

	def update(self, aggregate, x):
		# 一般也不会重写这个方法的,update阶段可以在forward函数中调用完propagate(...)函数后编写代码。
		return x + aggregate

3.完整代码

import torch
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing


class MyConv(MessagePassing):
	def __init__(self):
		super().__init__(aggr='sum')

	def forward(self, data):
		out = self.propagate(data.edge_index, x=data.x)
		out = out + data.x
		return out

	def message(self, x_i, x_j, edge_index_i, edge_index_j):
		# 生成的消息就是邻居节点的特征向量,直接使用 x_j 访问获取就行
		return x_j

	# def aggregate(self, message, edge_index_i):
	# 	return super().aggregate(message, edge_index_i, dim_size=len(x))

	# def update(self, aggregate, x):
	# 	return x + aggregate


if __name__ == '__main__':
	edge_index = torch.tensor([[0, 1],
							   [1, 0]], dtype=torch.long)
	x = torch.tensor([[-1, 1], [0, 1], [1, 1]], dtype=torch.float)
	data = Data(x=x, edge_index=edge_index.contiguous())

	myConv = MyConv()
	print(myConv(data))

4.完整代码的精简版本

import torch
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops


class MyConv(MessagePassing):
	def __init__(self):
		super().__init__(aggr='sum')

	def forward(self, data):
		edge_index, _ = add_self_loops(data.edge_index, num_nodes=len(data.x))
		out = self.propagate(edge_index, x=data.x)
		return out

if __name__ == '__main__':
	edge_index = torch.tensor([[0, 1],
							   [1, 0]], dtype=torch.long)
	x = torch.tensor([[-1, 1], [0, 1], [1, 1]], dtype=torch.float)
	data = Data(x=x, edge_index=edge_index.contiguous())

	myConv = MyConv()
	print(myConv(data))

思考:大家可以根据上面讲解的细节,理解一下这个精简版本的代码的实现逻辑和过程。

四、总结

1.MessagePassing各个函数的执行顺序

在这里插入图片描述

2.参考资料

标签:index,函数,torch,MessagePassing,edge,基类,data,节点,PyG
From: https://blog.csdn.net/weixin_65032328/article/details/144136272

相关文章

  • 针对Qwen-Agent框架的Function Call及ReAct的源码阅读与解析:Agent基类篇
    文章目录Agent继承链Agent类总体架构初始化方法`__init__`方法:`_init_tool`方法:对话生成方法`_call_llm`方法:工具调用方法`_call_tool`方法:`_detect_tool`方法:整体执行方法`run`方法:`_run`方法:`run_nonstream`方法......
  • C++ 多继承基类析构虚函数
    Demo:classAnimal{public:Animal(){cout<<"animal..."<<endl;}virtual~Animal(){cout<<"~animal..."<<endl;}virtualvoidShowAnimal()=0;};classCa......
  • Spyglass:更改默认编辑器
    相关阅读Spyglasshttps://blog.csdn.net/weixin_45791458/category_12828934.html?spm=1001.2014.3001.5482    Spyglass默认使用的是Vim(SmallVersion)作为其文本编辑器,如果希望使用其他文本编辑器(比如gedit、nano、VSCode、SublimeText),需要进行一些设置。 ......
  • Python中的pygame骨骼设想
    序骨骼一般来说都是在3D建模中,动画之类里面的比较常见,pygame里面的话,我是没咋听说过用到骨骼这样的东西,所以我这里也只是一个设想。一、核心思想考虑有些同学可能不是很清楚骨骼,那我就以我个人的理解方法来说一下。1.理解这个骨骼,如其名,就看你自身好了,通俗点就是理解为:你......
  • 使用Pygal库创建可缩放的矢量图表:从基础到高级自定义详解
    在数据可视化的世界中,创建可缩放的矢量图表是至关重要的,因为它们可以无损地在各种设备和分辨率下进行展示。Python中有许多强大的库可供选择,其中Pygal是一个出色的选择,它提供了创建各种类型的交互式矢量图表的功能。什么是Pygal?Pygal是一个Python库,专门用于创建可缩放的矢量图表。......
  • Unity中常用的三种单例模式基类
    提到框架,大家可能会觉得很复杂,我也是有同感但好在,也不是所有的框架都需要我们掌握的但这期提到单例模式却是大家一定要掌握!!!因为它真的非常重要!!!这里我整理了三种开发中常用的单例模式基类,希望大家看完后能多敲几遍,把它变成自己的知识(如果想直接拿到代码的话可以在我上传的......
  • 【可变参模板】基类参数包的展开
    一、基类参数包的展开1.1基类参数包的展开C++C++C++是一个支持多继承的语言,因此继承的类也可以是一个基类的......
  • 一个C++的 线程基类
      #include<iostream>#include<thread>#include<mutex>#include<condition_variable>#include<atomic>classThreadBase{public:ThreadBase():thread_(nullptr),stopFlag_(false){}virtual~ThreadBase(){......
  • (3-5)绘制散点图和折线图:Flask+pygal+SQLite实现数据分析
    3.5 Flask+pygal+SQLite实现数据分析在本节的内容中,将使用Flask+pygal+SQLite3实现数据分析功能。将需要分析的数据保存在SQLite3数据库中,然后在FlaskWeb网页中使用库pygal绘制出对应的统计图。3.5.1 创建数据库首先使用PyCharm创建一个FlaskWeb项目,然后通过文件model......
  • Unity实战案例 2D小游戏HappyGlass(模拟水珠)
    本案例素材和教程都来自Siki学院,十分感谢教程中的老师本文仅作学习笔记分享交流,不作任何商业用途预制体  在这个小案例中,水可以做成圆形但是带碰撞体,碰撞体比图形小一圈,顺便加上Trailrenderer组件 材质将碰撞材质的friction为0,bonciness可以按照需要修改脚本 ......