当前位置:网站首页>DGL中的消息传递相关内容的讲解
DGL中的消息传递相关内容的讲解
2022-07-05 10:30:00 【Icy Hunter】
前言
学会DGL中的消息传递,基本就能够比较好的来理解编写各种图神经网络的代码了吧。
消息传递范式
消息传递是实现GNN的一种通用框架和编程范式。它从聚合与更新的角度归纳总结了多种GNN模型的实现。
因此在DGL代码编写消息传递部分时,我们需要三个函数,分别是消息函数、聚合函数、更新函数。
简单来说就是:
消息函数用来取边和节点的特征。
聚合函数用来计算边和节点的特征,例如特征求和,根据特征求个注意力权重等等。
更新函数用来更新节点的特征,对聚合函数传来的特征可以过个激活函数等,最后得到最终的节点特征即可更新。
DGL中的自定义消息函数
在DGL中,消息函数 接受一个参数 edges,这是一个 EdgeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批边。 edges 有 src、 dst 和 data 共3个成员属性, 分别用于访问源节点、目标节点和边的特征。
用法就是定义一个函数,然后需要传入一个edges参数,这个参数有src、 dst 和 data 共3个成员属性,能够索引对应的特征
例:
def message_func(edges):
print("-"*20)
print("edges.data[x]", edges.data["x"]) # 获得边的特征
print("edges.src[x]", edges.src["x"]) # 获得边的源节点的特征
print("edges.dst[x]", edges.dst["x"]) # 获得边的目标节点的特征
# 返回得到需要传递的消息特征
return {
'e_data': edges.src['h'], "e_src": edges.src["x"], "e_dst": edges.dst["x"]}
DGL中的自定义聚合函数
聚合函数 接受一个参数 nodes,这是一个 NodeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批节点。 nodes 的成员属性 mailbox 可以用来访问节点收到的消息。 一些最常见的聚合操作包括 sum、max、min 等。
用法就是定义一个函数,然后需要传入一个nodes参数,这个参数能够通过mailbox索引消息函数return来的特征。
例:
def reduce_func(nodes):
print("+"*20)
# 获取每个节点的边特征的和并储存在节点的e_data中
data_sum = th.sum(nodes.mailbox["e_data"], dim=1)
# 获取每条边的源节点特征并求和储存在节点的src_data中
src_sum = th.sum(nodes.mailbox["e_src"], dim=1)
# 获取每条边的目标节点特征并求和储存在节点的src_data中
dst_sum = th.sum(nodes.mailbox["e_dst"], dim=1)
print("nodes_e_data", data_sum)
print("nodes_e_src", src_sum)
print("nodes_e_dst", dst_sum)
return {
"data_sum":data_sum, "src_sum":src_sum, "dst_sum":dst_sum}
DGL中的自定义更新函数
更新函数 同样接受参数 nodes。此函数对聚合函数的聚合结果进行操作, 通常在消息传递的最后一步将其与节点的特征相结合,并将输出作为节点的新特征。
例:
def apply_node_func(nodes):
# 将x用data_sum更新
return {
'x': nodes.data["data_sum"]}
最后加上
g.update_all(message_func, reduce_func, apply_node_func)
即可完成消息传递的操作。
实例分析
创建图
首先我们先创建那么一张图:
其中黑色的为节点的特征,红色的为边的特征
对应创建代码如下:
import dgl
import dgl.function as fn
import torch
import torch as th
# 构建图
g = dgl.graph(([0, 1, 1, 1, 2, 3, 2, 4, 3, 4, 4], [1, 0, 3, 2, 1, 1, 4, 2, 4, 3, 4]))
# 每个节点的特征都为[1, 1]
g.ndata['x'] = torch.ones(5, 2)
# 每边节点的特征都为[1, 1]
g.edata['x'] = torch.ones(11, 2)
# 节点4的特征为[0.2, 0.5]
g.ndata['x'][4] = torch.tensor([0.2, 0.5])
# 边5的特征为[0.1, 0.1]
g.edata['x'][5] = torch.tensor([0.1, 0.1])
# 消息汇聚更新
# g.update_all(fn.copy_u(u='x', out='m'), fn.sum(msg='m', out='h'))
print(g.ndata['x'])
print(g.edata["x"])
消息传递
然后我们来试着理解一下消息传递:
def message_func(edges):
print("-"*20)
print("edges.data[x]", edges.data["x"]) # 获得边的特征
print("edges.src[x]", edges.src["x"]) # 获得边的源节点的特征
print("edges.dst[x]", edges.dst["x"]) # 获得边的目标节点的特征
# 返回得到需要传递的消息特征
return {
'e_data': edges.data['x'], "e_src": edges.src["x"], "e_dst": edges.dst["x"]}
def reduce_func(nodes):
print("+"*20)
# 获取每个节点的边特征的和并储存在节点的e_data中
data_sum = th.sum(nodes.mailbox["e_data"], dim=1)
# 获取每条边的源节点特征并求和储存在节点的src_data中
src_sum = th.sum(nodes.mailbox["e_src"], dim=1)
# 获取每条边的目标节点特征并求和储存在节点的src_data中
dst_sum = th.sum(nodes.mailbox["e_dst"], dim=1)
print("nodes_e_data", data_sum)
print("nodes_e_src", src_sum)
print("nodes_e_dst", dst_sum)
return {
"data_sum":data_sum, "src_sum":src_sum, "dst_sum":dst_sum}
def apply_node_func(nodes):
# 将x用data_sum更新
return {
'x': nodes.data["data_sum"]}
g.update_all(message_func, reduce_func, apply_node_func)
print(g.ndata["x"])
我们就看一下用边特征更新后的x特征的输出好了。
g.ndata[“x”]的输出:
tensor([[1.0000, 1.0000],
[2.1000, 2.1000],
[2.0000, 2.0000],
[2.0000, 2.0000],
[3.0000, 3.0000]])
说明每个节点的入度的边的特征都求和之后汇聚到x特征上了,还是非常好理解的。
另外两个大概是求源节点相同的边的目标节点的特征的和来更新节点特征
以及求目标节点相同的边的源节点的特征的和来更新节点特征
可能说起来有点绕,但是看看代码运行的结果再结合图应该就懂了,这里就不放运行结果了。
这里是为了演示步骤,一般不再update_all中自己设置更新函数的。
graph.apply_edges
在DGL中,也可以在不涉及消息传递的情况下,通过 apply_edges() 单独调用逐边计算。 apply_edges() 的参数是一个消息函数。并且在默认情况下,这个接口将更新所有的边。
import dgl
import torch
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
g.ndata['h'] = torch.ones(5, 2)
# 三种方式
# def add(edges):
# return{"x": edges.src['h'] + edges.dst['h']}
# g.apply_edges(add)
# g.apply_edges(lambda edges: {'x' : edges.src['h'] + edges.dst['h']}) # 二者等价
g.apply_edges(fn.u_add_v('h', 'h', 'x')) # 使用内置函数,是最好的
print(g.edata['x'])
算图的注意力机制的时候,可以先计算出每个边的注意力权重,此时直接边计算即可。
参考
https://docs.dgl.ai/guide_cn/message.html
https://docs.dgl.ai/guide_cn/message-api.html
边栏推荐
- “军备竞赛”时期的对比学习
- BOM//
- 【观察】跨境电商“独立站”模式崛起,如何抓住下一个红利爆发时代?
- Nuxt//
- 微信核酸检测预约小程序系统毕业设计毕设(6)开题答辩PPT
- NAS与SAN
- How to write high-quality code?
- C language QQ chat room small project [complete source code]
- 基于昇腾AI丨爱笔智能推出银行网点数字化解决方案,实现从总部到网点的信息数字化全覆盖
- [可能没有默认的字体]Warning: imagettfbbox() [function.imagettfbbox]: Invalid font filename……
猜你喜欢
【DNS】“Can‘t resolve host“ as non-root user, but works fine as root
[dark horse morning post] Luo Yonghao responded to ridicule Oriental selection; Dong Qing's husband Mi Chunlei was executed for more than 700million; Geely officially acquired Meizu; Huawei releases M
双向RNN与堆叠的双向RNN
AtCoder Beginner Contest 254「E bfs」「F st表维护差分数组gcd」
风控模型启用前的最后一道工序,80%的童鞋在这都踩坑
Redis如何实现多可用区?
【js学习笔记五十四】BFC方式
Apple 5g chip research and development failure? It's too early to get rid of Qualcomm
How can non-technical departments participate in Devops?
IDEA新建sprintboot项目
随机推荐
SAP ui5 objectpagelayout control usage sharing
QT implements JSON parsing
flink cdc不能监听mysql日志,大家遇到过这个问题吧?
2022年危险化学品生产单位安全生产管理人员特种作业证考试题库模拟考试平台操作
Have the bosses ever encountered such problems in the implementation of flinksql by Flink CDC mongdb?
Go-3-第一个Go程序
重磅:国产IDE发布,由阿里研发,完全开源!
A large number of virtual anchors in station B were collectively forced to refund: revenue evaporated, but they still owe station B; Jobs was posthumously awarded the U.S. presidential medal of freedo
手机厂商“互卷”之年:“机海战术”失灵,“慢节奏”打法崛起
Sqlserver regularly backup database and regularly kill database deadlock solution
beego跨域问题解决方案-亲试成功
AD20 制作 Logo
Universal double button or single button pop-up
flex4 和 flex3 combox 下拉框长度的解决办法
5G NR系统架构
Pseudo class elements -- before and after
基于昇腾AI丨爱笔智能推出银行网点数字化解决方案,实现从总部到网点的信息数字化全覆盖
Web Components
"Everyday Mathematics" serial 58: February 27
pytorch输出tensor张量时有省略号的解决方案(将tensor完整输出)