当前位置:网站首页>PyG教程(3):邻居采样
PyG教程(3):邻居采样
2022-06-21 06:24:00 【斯曦巍峨】
一.为什么需要邻居采样?
在GNN领域,大图是非常常见的,但由于GPU显存的限制,大图是无法放到GPU上进行训练的。为此,可以采用邻居采样,这样一来可以将GNN扩展到大图上。在PyG中,邻居采样的方式有很多种,具体详解torch_geometric.loader。本文以GraphSage中的邻居采样为例进行介绍,其在PyG中实现为NeighborLoader。
NeighborSampler也是PyG中关于GraphSage中邻居采样的实现,但已经被弃用,在未来版本中会被删除。
二.NeighborLoader详解
2.1 GraphSage邻居采样原理
假设采样的层数为 K K K,每层采样的邻居数为 S k S_k Sk,GraphSage中邻居采样是这样进行的:
- 步骤一:首先给定要采样邻居的小批量节点集 B \mathcal{B} B;
- 步骤二:对 B \mathcal{B} B的 1 1 1跳(hop)邻居进行采样,然后得到 B 1 \mathcal{B}_1 B1,然后对 B 1 \mathcal{B}_1 B1的 1 1 1跳邻居进行采样(即最初结点集的 2 2 2跳邻居)得到 B 2 \mathcal{B}_2 B2,如此往复进行 K K K次,得到最初小批量节点集相关的一个子图。
下图左是GraphSage中给出的一个2层邻居采样的示例,其中每层采样的邻居数 S k S_k Sk是相等的(图中为3)。

2.2 API介绍
PyG中,GraphSage的邻居采样实现为torch_geometric.loader.NeighborLoader,其初始化函数参数为:
def __init__(
self,
data: Union[Data, HeteroData],
num_neighbors: NumNeighbors,
input_nodes: InputNodes = None,
replace: bool = False,
directed: bool = True,
transform: Callable = None,
neighbor_sampler: Optional[NeighborSampler] = None,
**kwargs,
)
常用参数说明如下:
data:要采样的图对象,可以为异构图HeteroData,也可以为同构图Data;num_neighbors:每个节点每次迭代(每层)采样的最大邻居数,List[int]类型,例如[2,2]表示采样2层,每层中每个节点最多采样2个邻居;input_nodes:从原始图中采样得到的子图中需要包含的原始图中节点索引,即2.1节中最初的 B \mathcal{B} B,torch.Tensor()类型;directed:如果设置为False,将包括所有采样节点之间的所有边;**kwargs:torch.utils.data.DataLoader的额外参数,例如batch_size,shuffle(具体详见该API)。
2.3 采样实践
为了可视化的美观性,本小节采用的图数据是PyG中提供的KarateClub数据集,该数据集描述了一个空手道俱乐部会员的社交关系,节点为34名会员,如果两位会员在俱乐部之外仍保持社交关系,则在对应节点间连边,该数据集的可视化如下所示:

下面是对该数据集的加载、可视化以及邻居采样的源码:
import torch
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.loader import NeighborLoader
def draw(graph):
nids = graph.n_id
graph = to_networkx(graph)
for i, nid in enumerate(nids):
graph.nodes[i]['txt'] = str(nid.item())
node_labels = nx.get_node_attributes(graph, 'txt')
# print(node_labels)
# {0: '14', 1: '32', 2: '33', 3: '18', 4: '30', 5: '28', 6: '20'}
nx.draw_networkx(graph, labels=node_labels, node_color='#00BFFF')
plt.axis("off")
plt.show()
dataset = KarateClub()
g = dataset[0]
# print(g)
# Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])
g.n_id = torch.arange(g.num_nodes)
for s in NeighborLoader(g, num_neighbors=[2, 2], input_nodes=torch.Tensor([14])):
draw(s)
break
在上述源码中,设置的采样层数为2层、每个节点每层采样最多采样2个邻居,采样的初始节点集为{14},其对应的采样结果如下所示:

从上图可以看出,在第一次迭代中,采样了节点{14}的两个1跳邻居{32,33},然后在第二次迭代中对{32,33}分别进行采样得到{2,8]}和{18,30}。
需要注意是通过NeighborLoader返回的子图中,全局节点索引会映射到到与该子图对应的局部索引。因此,若要将当前采样子图中的节点映射会原来图中对应的节点,可以在原始图中创建一个属性来完成两者之间的映射,例如采样实践源码中的:
g.n_id = torch.arange(g.num_nodes)
如此以来,采样后子图中的节点同样包含n_id属性,这样就可以将子图的节点映射回去了,上述示例中对图进行可视化便利用了这一点,其对应的映射为:
{
0: '14', 1: '32', 2: '33', 3: '18', 4: '30', 5: '28', 6: '20'}
结语
PyG中对于邻居采样的实现远远不止上述这一种,具体参见如下官网资料:
边栏推荐
- Which is better for children's consumption type serious diseases at present? Are there any recommended children's products
- How to access MySQL database through JDBC? Hand to hand login interface (illustration + complete code)
- Memorizing Normality to Detect Anomaly: Memory-augmented Deep Autoencoder for Unsupervised Anomaly D
- [data mining] final review Chapter 3
- FPGA - 7系列 FPGA SelectIO -02- 源语简介
- [reproduce ms08-067 via MSF tool intranet]
- [is the network you are familiar with really safe?] Wanziwen
- 数据可视化实战:数据处理
- docker 安装mysql
- FPGA - 7 Series FPGA selectio -04- ideay and ideayctrl of logical resources
猜你喜欢

Memorizing Normality to Detect Anomaly: Memory-augmented Deep Autoencoder for Unsupervised Anomaly D

高考那些事

Pycharm的快捷键Button 4 Click是什么?

Contos7 installing SVN server

认知语言学之框架与脚本

Aurora8B10B IP使用 -05- 收发测试应用示例

5254. dynamic planning of selling wood blocks

FPGA - 7 Series FPGA selectio -02- introduction to source language

DDD Practice Manual (4. aggregate aggregate)

如何通过JDBC访问MySQL数据库?手把手实现登录界面(图解+完整代码)
随机推荐
Port occupancy resolution
机器学习之数据归一化(Feature Scaling)
Module 14 - 15: network application communication test
Latest analysis on operation of refrigeration and air conditioning equipment in 2022 and examination question bank for operation of refrigeration and air conditioning equipment
Aurora8B10B IP使用 -03- IP配置应用指南
【【毕业季·进击的技术er】------老学长心得分享
C语言实现模拟银行存取款管理系统课程设计(纯C语言版)
第13期:Flink零基础学习路线
C语言程序设计——三子棋(学期小作业)
IP - 射频数据转换器 -04- API使用指南 - 系统设置相关函数
MySQL数据库基础:子查询
笔记 How Powerful are Spectral Graph Neural Networks
[JDBC from starting to Real combat] JDBC Basic clearance tutoriel (Summary of the first part)
Cache cache (notes on principles of computer composition)
【利用MSF工具内网复现MS08-067】
Digital signal processing-07-dds IP application example
5254. dynamic planning of selling wood blocks
【笔记自用】myeclipse连接MySQL数据库详细步骤
Regular expression Basics
leetcode 675. Cutting down trees for golf competitions - (day29)