Example
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader
import torch
data = Planetoid('./dataset', name='Cora')[0]
# Assign each node its global node index:
data.n_id = torch.arange(data.num_nodes)
loader = NeighborLoader(
data,
# Sample 30 neighbors for each node for 2 iterations
num_neighbors=[30] * 2,
# Use a batch size of 128 for sampling training nodes
batch_size=128,
input_nodes=data.train_mask,
)
sampled_data = next(iter(loader))
print(sampled_data.batch_size)
print(sampled_data.n_id) # NeighborLoader返回的子图中的节点index是local的,而非在原始data中的index,因此我们要给data增加一个n_id属保存原始节点id,并进行映射
完整示例
API 介绍
部分用法讲解(代码取自完整示例)
- 加载数据
data
要求是torch_geometric.data.Data or torch_geometric.data.HeteroData类型input_nodes
: 中心节点集合,即一个mini-batch内的节点,如果为None,则代表包含data中的所有节点num_neighbors
: 每轮迭代要采样邻居节点的个数,即第i轮要为每个节点采样num_neighbors[i]
个节点,如果为-1,则代表所有邻居节点都将被包含。
kwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}
train_loader = NeighborLoader(data, input_nodes=data.train_mask,
num_neighbors=[25, 10], shuffle=True, **kwargs)
subgraph_loader = NeighborLoader(copy.copy(data), input_nodes=None,
num_neighbors=[-1], shuffle=False, **kwargs)
- 子图index映射
NeighborLoader返回的子图中的节点index是local的,而非在原始data中的index,因此我们要给data增加一个n_id属保存原始节点id,并进行映射
# Add global node index information.
subgraph_loader.data.num_nodes = data.num_nodes
subgraph_loader.data.n_id = torch.arange(data.num_nodes)
标签:NeighborLoader,torch,num,nodes,data,节点,PYG
From: https://www.cnblogs.com/mercurysun/p/16869877.html