当前位置:网站首页>PyG教程(7):剖析邻域聚合
PyG教程(7):剖析邻域聚合
2022-06-22 05:56:00 【斯曦巍峨】
一.前言
上篇文章《PyG教程(6):自定义消息传递网络》主要介绍了消息传递GNN的大致框架。本文主要聚焦于消息传播中的邻域聚合,本文将介绍PyG是如何将节点的邻居的消息聚合到节点本身的。
二.PyG中的邻域聚合
PyG中邻域聚合是通过aggregate(inputs, index)函数来完成的,该函数的第一个参数inputs为消息构建函数message()构建的消息,该函数还存在一个参数index,这个参数对于消息聚合是十分关键的,它指示了inputs中每条消息属于哪个节点的邻域。下图便很好的解释了PyG中的消息聚合:

上述栗子中展示的是包含4个顶点、8条边的graph,其中input为在8条边上传播的消息、index为各条边上消息的归属,即目标节点的索引。通过index,可以将属于同一个节点邻域的消息聚合到一起,常见的聚合包括sum、mean、mean、mul和min等。
在PyG中通过scatter函数来实现上述过程,查看MessagePassing的源码,可以看到其aggregate函数的定义如下:
def aggregate(self, inputs: Tensor, index: Tensor,
ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None) -> Tensor:
r""" 注释太长略 """
if ptr is not None:
ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
return segment_csr(inputs, ptr, reduce=self.aggr)
else:
return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
reduce=self.aggr)
aggregate函数中scatter函数源码为:
def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
reduce: str = "sum") -> torch.Tensor:
r""" 注释太长略 """
if reduce == 'sum' or reduce == 'add':
return scatter_sum(src, index, dim, out, dim_size)
if reduce == 'mul':
return scatter_mul(src, index, dim, out, dim_size)
elif reduce == 'mean':
return scatter_mean(src, index, dim, out, dim_size)
elif reduce == 'min':
return scatter_min(src, index, dim, out, dim_size)[0]
elif reduce == 'max':
return scatter_max(src, index, dim, out, dim_size)[0]
else:
raise ValueError
其中便包含了前面提到的5种聚合方式。对于这些聚合方式,只需要在继承MessagePassing类时,通过super().__init__来向该类传递参数aggr参数的值即可。
三.torch_scatter模块
若用户需要自定义消息聚合,则在重写的aggregate()函数中,同样可以使用MessagePassing中的scatter函数,只需要导入torch_scatter模块即可。
在torch_scatter模块中也实现了scatter函数,其声明如下:
scatter: (src: Tensor, index: Tensor, dim: int = -1, out: Tensor | None = None, dim_size: int | None = None, reduce: str = "sum") -> Tensor
常用参数说明:
| 参数 | 说明 |
|---|---|
src | 每条边上的源节点生成的消息 |
index | 指示每条边上消息需要聚合到哪个节点上 |
dim | 指示沿着那个维度(轴)应用index进行聚合 |
reduce | 聚合操作,包括sum、mul、 mean、 min 和 max |
注意,torch_scatter也为上述的几种聚合单独提供了API:
torch_scatter.scatter_add()
torch_scatter.scatter_max()
torch_scatter.scatter_mean()
torch_scatter.scatter_min()
torch_scatter.scatter_mul()
为了方便理解,下面给出一个栗子,假设存在一个包含3个顶点、6条边的图:

假设0、1、2三个顶点生成的消息分别为1、2、3,则图中6条边的消息inputs和相应的index构造如下:
inputs = torch.tensor([[1], [1], [2], [2], [3], [3]])
index = torch.tensor([1, 2, 0, 2, 0, 1])
应用torch_scatter.scatter()函数的结果如下:
out = torch_scatter.scatter(src=inputs, index=index, dim=0, reduce="sum")
print(out)
""" tensor([[5], [4], [3]]) """
可以看到节点0接受来自节点1,2的消息得到2+3=5,节点1接受来自节点0,2的消息得到1+3=4,而节点2接受来自节点0,1的消息得到1+2=3。
四.结语
参考资料:
通过本文可以加深对PyG中消息聚合过程的理解,这将有助于更好的自定义GNN模型。以上便是本文的全部内容,若有任何错误,请批评指正。
边栏推荐
- 关于MNIST线性模型矩阵顺序问题
- 单细胞论文记录(part12)--Unsupervised Spatial Embedded Deep Representation of Spatial Transcriptomics
- 经验模式分解(EMD)和希尔伯特-黄变换(HHT)
- 单精度,双精度和精度(转载)
- Shengxin visualization (Part4) -- correlation diagram
- Using SystemVerilog to describe a state machine
- MFC tab control add Icon
- TCP connection details
- 组合逻辑块的测试平台
- C#中的泛型
猜你喜欢

Shengxin visualization (Part1) -- histogram

What about computer jam?

Linear regression: least squares, Tellson estimation, RANSAC

idea插件Easy Code的简单使用

性能优化 之 3D资产优化及顶点数据管理

MFC tab control add Icon

Serial port (RS - 232)

Array and foreach traversal in C #

Single cell paper record (Part14) -- costa: unsupervised revolutionary neural network learning for St analysis

格雷码与二进制的转换
随机推荐
DOS bat syntax record I
Unity app improves device availability
TiDB 社区线下交流会,天津 & 石家庄的小伙伴看过来~
Improve your game‘s performance
单细胞论文记录(part12)--Unsupervised Spatial Embedded Deep Representation of Spatial Transcriptomics
Vulkan 预旋转处理设备方向
关于MNIST线性模型矩阵顺序问题
PIR控制器调节器并网逆变器电流谐波抑制策略
idea插件EasyCode的使用
触 发 器
Vulkan pre rotation processing equipment direction
Flink核心功能和原理
Array and foreach traversal in C #
【Rust笔记】01-基本类型
富设备平台突破:基于RK3568的DAYU200进入OpenHarmony 3.1 Release主干
Bat 常用批处理脚本记录
Keil调试时设置断点的高级用法
401-字符串(344. 反转字符串、541. 反转字符串II、题目:剑指Offer 05.替换空格、151. 颠倒字符串中的单词)
W800芯片平台进入OpenHarmony主干
Single cell paper record (Part14) -- costa: unsupervised revolutionary neural network learning for St analysis