首页 > 其他分享 >PyG

PyG

时间:2023-10-16 22:14:38浏览次数:34  
标签:index None Tensor graph edge PyG out

目录

PyG

图的构建

普通图

  • 假设我们要构建一个 graph \(\mathcal{G}=\langle \mathcal{V}, \mathcal{E} \rangle\), 其中 \(|\mathcal{V}| = V, |\mathcal{E}| = E\).

[torch_geometric.data.Data]

class Data:

    def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
                 edge_attr: OptTensor = None, y: OptTensor = None,
                 pos: OptTensor = None, **kwargs):
  • x: torch.Tensor, (V, *).
  • edge_index: torch.Tensor, (2, E)
  • ...
graph = Data()
graph.x = torch.empty((3,))
graph.edge_index = torch.tensor([
    [0, 1],
    [1, 2]
])
>>> graph
Data(x=[3], edge_index=[2, 2])
  • 此图实际上是 \(v_0 \rightarrow v_1 \rightarrow v_2\), 这个有向图, 我们可以通过[to_undirected]快捷地将它转换为无向图:
def to_undirected(
    edge_index: Tensor,
    edge_attr: Optional[Union[Tensor, List[Tensor]]] = None,
    num_nodes: Optional[int] = None,
    reduce: str = "add",
) -> Union[Tensor, Tuple[Tensor, Tensor], Tuple[Tensor, List[Tensor]]]:
>>> graph.edge_index = to_undirected(graph.edge_index)
>>> graph.edge_index
tensor([[0, 1, 1, 2],
        [1, 0, 2, 1]])

异质图

  • 异质图主要围绕着边的类型构建, 主要通过 [HeteroData] 构建, 每个边类型的构建和普通图是完全一致的.

  • 接下来以推荐系统中的二部图为例:

from torch_geometric.data import HeteroData

graph = HeteroData()

# nodes
graph['User'].x = torch.empty((4,))
graph['Item'].x = torch.empty((5,))

# edge type "(User, click, Item)"
graph['User', 'click', 'Item'].edge_index = torch.tensor([
    [0, 1, 2, 3, 3],
    [0, 1, 2, 3, 4]
])

>>> graph
HeteroData(
  User={ x=[4] },
  Item={ x=[5] },
  (User, click, Item)={ edge_index=[2, 5] }
)

>>> graph.num_nodes
9
>>> graph['User'].num_nodes
4
>>> graph[('User', 'click', 'Item')]
{'edge_index': tensor([[0, 1, 2, 3, 3],
        [0, 1, 2, 3, 4]])}
>>> graph['click']
{'edge_index': tensor([[0, 1, 2, 3, 3],
        [0, 1, 2, 3, 4]])}
  • 二部图转为同质图:
graph = graph.coalesce() # 抹去重复的边
graph = graph.to_homogeneous()

>>> graph
Data(edge_index=[2, 5], x=[9], node_type=[9], edge_type=[5])
>>> graph.node_type
tensor([0, 0, 0, 0, 1, 1, 1, 1, 1])
>>> graph.edge_type
tensor([0, 0, 0, 0, 0])
  • 转为无向图:
graph.edge_index, graph.edge_type = to_undirected(graph.edge_index, edge_attr=graph.edge_type)

>>> graph
Data(edge_index=[2, 10], x=[9], node_type=[9], edge_type=[10])

MessagePassing

[torch_geometric.nn.conv.MessagePassing]

  • 一般的 GCN 可以归结为如下形式:

    \[x_i = \phi(x_i, \oplus_{j \in \mathcal{N}(i)} \: \varphi(x_i, x_j, e_{j \rightarrow i})). \]

  • 其中我们需要设定的包括:

    • \(\varphi\), message: 逐边处理的一个函数;
    • \(\oplus_{j \in \mathcal{N}(i)}\), aggr: 聚合操作, 比如常见的, sum, mean, min, max, 也可以是人为定义的.
    • \(\phi\), update: 更新函数, 通常是一些非线性的变换.
class MessagePassing:

    def __init__(
        self,
        aggr: Optional[Union[str, List[str], Aggregation]] = "add",
        *,
        aggr_kwargs: Optional[Dict[str, Any]] = None,
        flow: str = "source_to_target",
        node_dim: int = -2,
        decomposed_layers: int = 1,
        **kwargs,
    ): ...

    def propagate(self, edge_index: Adj, size: Size = None, **kwargs): ...
    def edge_updater(self, edge_index: Adj, **kwargs): ...
    def message(self, x_j: Tensor) -> Tensor: ...
    def aggregate(self, inputs: Tensor, index: Tensor,
                  ptr: Optional[Tensor] = None,
                  dim_size: Optional[int] = None) -> Tensor: ...
    def message_and_aggregate(self, adj_t: SparseTensor) -> Tensor: ...
    def update(self, inputs: Tensor) -> Tensor: ...
    def edge_update(self) -> Tensor: ...
  • aggr: 上面提到的, 支持多种方式;

  • aggr_kwargs: 自定义的一些聚合方式可能需要传一些参数;

  • flow: source_to_target (默认) or target_to_source. 我们知道, edge_index 是 (2, E) 大小的 tensor, 每一列表示一条边 \(e_{ji}\), 如果是前者则这条边方向是 \(e_{j \rightarrow i}\), 否则方向为 \(e_{i \rightarrow j}\).

  • node_dim: The axis along which to propagate. (default: :obj:-2) 这个主要用在 aggregation 的时候. 比如 aggregation 的输入为 (V, D) 大小的 tensor, 默认的 node_dim=-2 就能够保证是将不同的结点的特征聚合起来, 如果 node_dim=-1, 还要起到相同的效果就得输入 (D, V) 格式. 至于为什么不是 dim=0, 大概是因为可能有些时候会遇到是 (B, V, D) 之类的情况. 但是感觉文档没有突出它的重要性啊, 应该是很重要的参数.

  • MessagePassing 和普通的 nn.Module 类似, 主要脚本在 forward 中, 一般我们会在 forward 中调用 propagate 方法来管理卷积过程:

    • edge_index: (2, E);
    • size: 如果为 None, 则表示默认处理的是普通的图, 此时要求 source, target 的结点数目是一致的; 如果显式给定 (M, N), 则表示 source, target 的结点数目分别为 (M, N).
  • propagate 的执行流程如下:

    1. 检查输入, 整理得到合适的输入格式, 记为 coll_dict, 基于此得到适合各函数的输入: msg_kwargs, aggr_kwargs, update_kwargs, coll_dict 中包含如下的特殊的关键字:
          if isinstance(edge_index, Tensor):
              out['adj_t'] = None
              out['edge_index'] = edge_index
              out['edge_index_i'] = edge_index[i]
              out['edge_index_j'] = edge_index[j]
              out['ptr'] = None
          elif isinstance(edge_index, SparseTensor):
              out['adj_t'] = edge_index
              out['edge_index'] = None
              out['edge_index_i'] = edge_index.storage.row()
              out['edge_index_j'] = edge_index.storage.col()
              out['ptr'] = edge_index.storage.rowptr()
              if out.get('edge_weight', None) is None:
                  out['edge_weight'] = edge_index.storage.value()
              if out.get('edge_attr', None) is None:
                  out['edge_attr'] = edge_index.storage.value()
              if out.get('edge_type', None) is None:
                  out['edge_type'] = edge_index.storage.value()
      
          out['index'] = out['edge_index_i']
          out['size'] = size
          out['size_i'] = size[i] if size[i] is not None else size[j]
          out['size_j'] = size[j] if size[j] is not None else size[i]
          out['dim_size'] = out['size_i']
      
    2. 如果 message_and_aggregate 实现了, 则调用它, 否则向下执行;
    3. out = self.message(**msg_kwargs);
    4. out = self.aggregate(out, **aggr_kwargs);
    5. out = self.update(out, **update_kwargs);
    6. 然后输出
  • message 部分默认接受 \(x_j\), 默认情况下,

    x_j = x[edge_index[0]]
    

    这是个 \((E, D)\) 大小的 tensor, 这里假设在边上的操作只和 source 有关. 如果我们要弄一个复杂一点的, 比如:

    \[\varphi(x_i, x_j) = W[x_i\|x_j], \]

    就可以这么定义:

    def message(x_i: torch.Tensor, x_j: torch.Tensor):
        return self.mlp(torch.cat((x_i, x_j), dim=-1))
    
  • aggregate 部分接受 message 的输出 (E, D) 大小的 tensor 和一些其它的可选参数:

    def aggregate(self, inputs: Tensor, index: Tensor,
                  ptr: Optional[Tensor] = None,
                  dim_size: Optional[int] = None) -> Tensor:
        # inputs: (E, D)
        # index: edge_index_i, namely target index
        return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,
                                dim=self.node_dim)
    
  • update 部分默认接受 aggregate 的 (V, D) 的输出:

    def update(self, inputs: Tensor) -> Tensor:
        # inputs: (V, D)
        return inputs
    

    稍微复杂点的, 比如

    \[\phi(x_i, x_i^{aggr}) = \text{ReLU}(x_i + x_i^{aggr}). \]

    def update(self, aggregated: Tensor, x) -> Tensor:
        # aggregated: (V, D)
        # x: (V, D)
        return self.relu(inputs + x)
    

标签:index,None,Tensor,graph,edge,PyG,out
From: https://www.cnblogs.com/MTandHJ/p/17768490.html

相关文章

  • Python游戏开发:Pygame库入门
    Pygame是一个开源的Python库,用于开发2D游戏。它提供了许多功能,如游戏开发、音频处理和事件处理。安装Pygame库您可以通过以下命令在终端中安装Pygame库:pipinstallpygame创建游戏窗口要创建一个游戏窗口,您可以使用以下代码:importpygamepygame.init()#设置窗口尺寸window_......
  • python报错:pyglet.canvas.xlib.NoSuchDisplayException: Cannot connect to "None"
    运行python代码报错:       问题发现:问题其实十分的狗血,这个代码是在服务器上运行的,运行之前其实并没有看具体的代码情况,gitclone下载下来就直接运行了,原来这个代码需要进行图片绘制,说直白些就是需要显示屏,于是解决方法也十分简单,就是换个带桌面的电脑或者使用......
  • pyg学习
    官网:https://pytorch-geometric.readthedocs.io/en/latest/目录geometric图数据MessagePassingtorch_scatter一些代码解释geometric图数据pyg里一个图是torch_geometric.data.Data的instanceData试图模仿常规Python字典的行为。参数:x(torch.Tensor,optional)–图......
  • D. Andrey and Escape from Capygrad
    D.AndreyandEscapefromCapygradAnincidentoccurredinCapygrad,thecapitalofTyagoland,whereallthecapybarasinthecitywentcrazyandstartedthrowingmandarins.Andreywasforcedtoescapefromthecityasfaraspossible,usingportals.Tyag......
  • 在 Python 中使用 Pygal 绘制世界地图
    在Python的Pygal库的帮助下,我们可以在Python中创建令人惊叹的世界地图,因为它提供了不同的功能来创建和自定义图形。本文探讨了绘制世界地图、自定义地图样式、添加数据以突出显示国家或地区以及将地图呈现为SVG文件的分步过程。无论您是想可视化地理数据、展示国际统计数据......
  • Python pygame实现中国象棋单机版源码
    今天给大家带来的是关于Python实战的相关知识,文章围绕着用Pythonpygame实现中国象棋单机游戏版展开,文中有非常详细的代码示例,需要的朋友可以参考下#-*-coding:utf-8-*-"""CreatedonSunJun1315:41:562021@author:Administrator"""importpygamefrompygame.local......
  • pyGTK实战(1)
    目录简介hello,world事件驱动简介PyGTK是一套用Python和C语言编写的GTK+GUI库的包装器。它是GNOME项目的一部分。它为用Python构建桌面应用程序提供了全面的工具。PyGObject是一个Python包,它为基于GObject的库(如GTK、GStreamer、WebKitGTK、GLib、GIO等)提供绑定。它支持Linux、......
  • 基于Python+tkinter+pygame的音乐播放器完整源码
    importosimporttkinterimporttkinter.filedialogimportrandomimporttimeimportthreadingimportpygamefolder=''defplay():#folder用来表示存放MP3音乐文件的文件夹globalfoldermusics=[folder+'\\'+musicfo......
  • Spyglass的CDC检查
    接着前面Lint检查之后需要对RTL进行CDC检查,以下是简单的步骤。1.在完成lint检查后,也就是确保没有语法错误之后,点击GoalSetup,然后勾选主窗口下cdc_setup_check,然后点击RunGoal(s) ,当运行完成,会自动弹出AnalyzeResult窗口。2.得到分析结果后,cdc/cdc_setup_check......
  • Spyglass的Lint检查的步骤
    SpyGalss是Synopsys(新思科技)推出的一款静态Signoff平台,目前业界唯一可靠的RTLSignoff解决方案,可以帮助客户在设计早期发现潜在问题,保证产品质量,极大的减少设计风险,降低设计成本。笔者在转行做IC前没用过该软件,后面是入行后老员工指导需要用该软件进行跨时钟域检查,他说该软......