官网:https://pytorch-geometric.readthedocs.io/en/latest/
geometric
图数据
pyg里一个图是torch_geometric.data.Data
的instance
Data 试图模仿常规 Python 字典的行为。
参数:
- x (torch.Tensor, optional) – 图里节点的feature,大小是[num_nodes, num_node_features]
- edge_index (LongTensor, optional) – COO格式地去描述图的连接性,[2, num_edges]。也就是记录每条边的头尾实体。edge index在{0,...,num_nodes-1}的范围
- edge_attr (torch.Tensor, optional) – 图里边的feature,大小是[num_edges, num_edge_features]
- y (torch.Tensor, optional) – 具有任意形状的图级graph-level或节点级node-level真实标签。图级[1,],node级[num_nodes,]。【不懂啥意思】
- pos (torch.Tensor, optional) – 节点的位置矩阵,大小为[num_nodes, num_dimensions]【不懂啥意思】
- **kwargs (optional) – 其他自行添加的属性
假设目前构造了一个图Data,命名为graph
可能常用的函数:
- to_dict() - 返回图里各个参数的值的dict,key是参数名,比如graph.to_dict()['edge_attr']
- update(data: Union[Data, Dict[str, Any]]) - 根据其他的data来更新当前data【不知道是怎么Union的】
- subgraph(subset: Tensor) - 返回子图 subset表示了node indices,可以是LongTensor or BoolTensor表示留存的nodes【但是好像只有新版本有这个函数,python3.7torch_geometric2.0.1没有】
- edge_subgraph(subset: Tensor) - 返回子图,但是目前会保留所有的nodes(即使是isolated)【新版本】
- to_heterogeneous(node_type: Optional[Tensor] = None, edge_type: Optional[Tensor] = None, node_type_names: Optional[List[str]] = None, edge_type_names: Optional[List[Tuple[str, str, str]]] = None)
转换为异质图【新版本,没细看】 - apply(func: Callable, *args: str) - *args是图里需要修改的参数,func是对参数进行修改的函数,需要返回修改后的参数值。比如:
def test_func(tensor):
shape = tensor.shape
tensor = torch.zeros(shape)
return tensor
graph.apply(test_func, "x", "edge_attr")
这个apply的作用就是对data里的x属性和edge_attr属性做test_func操作(令tensor全为0)
同理有apply_函数,作用是不需要func返回修改后的值,直接就能修改了。
- clone() - copy.deepcopy当前graph
- coalesce() - 删除重复出现的边
图数据的设备切换:
- graph.cpu() # 官网做法
- graph.cuda('cuda:0') # 官网做法
- graph.to(device='cuda:0')
数据的detach(不求梯度)【注意freeze是求梯度但是不更新】举例:https://blog.csdn.net/weixin_44562957/article/details/120950157
- detach(*args: str)- detach全部或者只detach args中的参数
detach_类似
freeze是如下:
for param in B.parameters():
param.requires_grad = False
合法性检查:
- data.validate(raise_on_error=True) 但是只在新版上可以用
有其他检查函数:
data.has_isolated_nodes()
data.has_self_loops()
data.is_directed()
...
Message Passing
propagate的参数:edge_index, size=None, 多传的参数
propagate也可以重新,参考源码。
propagate的执行顺序:
1.out = self.message(args)
2.out = self.aggregate(out, args)
3.out = self.update(out, args)
注意,message传参的时候需要定义参数名字。但是实际上aggregate、update里的参数可以不仅是message里的参数
查看message_passing.py源码后可以发现,每次执行函数之前都会通过inspector.py的distribute函数来寻找参数。
根据源码可知,如果edge_index是SparseTensor类型,就会把message和aggregate结合为一个函数message_and_aggregate
自定义message,输入是(propagate多传的参数),每个node处理每个邻居、邻边信息。如下,norm是自定义加的参数,可以继续加
def message(self, x_j, norm):
# x_j has shape [E, out_channels] 代表当前nodes接收到的邻居/邻边信息
return norm.view(-1, 1) * x_j # 返回当前node聚合的结果
自定义aggregate,输入是(message的输出, propagate多传的参数),每个node聚合message得到的每个邻居邻边的信息。
比如可以直接在MP的参数里传aggr='add'
,也可以直接自定义,一般情况下等价自定义如下:
def aggregate(self, inputs: Tensor, index: Tensor, dim_size: Optional[int] = None) -> Tensor:
return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
自定义update,输入是(aggregate的输出,propagate多传的参数),每个node根据聚合的信息来更新表示。一般来说直接输出aggregate的输出,不需要重写update。
def update(self, inputs:Tensor):
return inputs
torch_scatter
https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html
把src的数据,根据index的dim axis分类,执行reduce操作(sum/mul/mean/min/max),输出到out张量里(或者scatter函数直接返回)
from torch_scatter import scatter
src = torch.randn(10, 6, 64)
index = torch.tensor([0, 1, 0, 1, 2, 1])
# Broadcasting in the first and last dim.
out = scatter(src, index, dim=1, reduce="sum")
print(out.size())
输出:
torch.Size([10, 3, 64])
一些代码解释
inspector.py
标签:node,Tensor,torch,学习,edge,geometric,data,pyg From: https://www.cnblogs.com/ReflexFox/p/17708758.html