当前位置:网站首页>DGL中的消息传递相关内容的讲解
DGL中的消息传递相关内容的讲解
2022-07-05 10:30:00 【Icy Hunter】
前言
学会DGL中的消息传递,基本就能够比较好的来理解编写各种图神经网络的代码了吧。
消息传递范式
消息传递是实现GNN的一种通用框架和编程范式。它从聚合与更新的角度归纳总结了多种GNN模型的实现。
因此在DGL代码编写消息传递部分时,我们需要三个函数,分别是消息函数、聚合函数、更新函数。
简单来说就是:
消息函数用来取边和节点的特征。
聚合函数用来计算边和节点的特征,例如特征求和,根据特征求个注意力权重等等。
更新函数用来更新节点的特征,对聚合函数传来的特征可以过个激活函数等,最后得到最终的节点特征即可更新。
DGL中的自定义消息函数
在DGL中,消息函数 接受一个参数 edges,这是一个 EdgeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批边。 edges 有 src、 dst 和 data 共3个成员属性, 分别用于访问源节点、目标节点和边的特征。
用法就是定义一个函数,然后需要传入一个edges参数,这个参数有src、 dst 和 data 共3个成员属性,能够索引对应的特征
例:
def message_func(edges):
print("-"*20)
print("edges.data[x]", edges.data["x"]) # 获得边的特征
print("edges.src[x]", edges.src["x"]) # 获得边的源节点的特征
print("edges.dst[x]", edges.dst["x"]) # 获得边的目标节点的特征
# 返回得到需要传递的消息特征
return {
'e_data': edges.src['h'], "e_src": edges.src["x"], "e_dst": edges.dst["x"]}
DGL中的自定义聚合函数
聚合函数 接受一个参数 nodes,这是一个 NodeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批节点。 nodes 的成员属性 mailbox 可以用来访问节点收到的消息。 一些最常见的聚合操作包括 sum、max、min 等。
用法就是定义一个函数,然后需要传入一个nodes参数,这个参数能够通过mailbox索引消息函数return来的特征。
例:
def reduce_func(nodes):
print("+"*20)
# 获取每个节点的边特征的和并储存在节点的e_data中
data_sum = th.sum(nodes.mailbox["e_data"], dim=1)
# 获取每条边的源节点特征并求和储存在节点的src_data中
src_sum = th.sum(nodes.mailbox["e_src"], dim=1)
# 获取每条边的目标节点特征并求和储存在节点的src_data中
dst_sum = th.sum(nodes.mailbox["e_dst"], dim=1)
print("nodes_e_data", data_sum)
print("nodes_e_src", src_sum)
print("nodes_e_dst", dst_sum)
return {
"data_sum":data_sum, "src_sum":src_sum, "dst_sum":dst_sum}
DGL中的自定义更新函数
更新函数 同样接受参数 nodes。此函数对聚合函数的聚合结果进行操作, 通常在消息传递的最后一步将其与节点的特征相结合,并将输出作为节点的新特征。
例:
def apply_node_func(nodes):
# 将x用data_sum更新
return {
'x': nodes.data["data_sum"]}
最后加上
g.update_all(message_func, reduce_func, apply_node_func)
即可完成消息传递的操作。
实例分析
创建图
首先我们先创建那么一张图:
其中黑色的为节点的特征,红色的为边的特征
对应创建代码如下:
import dgl
import dgl.function as fn
import torch
import torch as th
# 构建图
g = dgl.graph(([0, 1, 1, 1, 2, 3, 2, 4, 3, 4, 4], [1, 0, 3, 2, 1, 1, 4, 2, 4, 3, 4]))
# 每个节点的特征都为[1, 1]
g.ndata['x'] = torch.ones(5, 2)
# 每边节点的特征都为[1, 1]
g.edata['x'] = torch.ones(11, 2)
# 节点4的特征为[0.2, 0.5]
g.ndata['x'][4] = torch.tensor([0.2, 0.5])
# 边5的特征为[0.1, 0.1]
g.edata['x'][5] = torch.tensor([0.1, 0.1])
# 消息汇聚更新
# g.update_all(fn.copy_u(u='x', out='m'), fn.sum(msg='m', out='h'))
print(g.ndata['x'])
print(g.edata["x"])
消息传递
然后我们来试着理解一下消息传递:
def message_func(edges):
print("-"*20)
print("edges.data[x]", edges.data["x"]) # 获得边的特征
print("edges.src[x]", edges.src["x"]) # 获得边的源节点的特征
print("edges.dst[x]", edges.dst["x"]) # 获得边的目标节点的特征
# 返回得到需要传递的消息特征
return {
'e_data': edges.data['x'], "e_src": edges.src["x"], "e_dst": edges.dst["x"]}
def reduce_func(nodes):
print("+"*20)
# 获取每个节点的边特征的和并储存在节点的e_data中
data_sum = th.sum(nodes.mailbox["e_data"], dim=1)
# 获取每条边的源节点特征并求和储存在节点的src_data中
src_sum = th.sum(nodes.mailbox["e_src"], dim=1)
# 获取每条边的目标节点特征并求和储存在节点的src_data中
dst_sum = th.sum(nodes.mailbox["e_dst"], dim=1)
print("nodes_e_data", data_sum)
print("nodes_e_src", src_sum)
print("nodes_e_dst", dst_sum)
return {
"data_sum":data_sum, "src_sum":src_sum, "dst_sum":dst_sum}
def apply_node_func(nodes):
# 将x用data_sum更新
return {
'x': nodes.data["data_sum"]}
g.update_all(message_func, reduce_func, apply_node_func)
print(g.ndata["x"])
我们就看一下用边特征更新后的x特征的输出好了。
g.ndata[“x”]的输出:
tensor([[1.0000, 1.0000],
[2.1000, 2.1000],
[2.0000, 2.0000],
[2.0000, 2.0000],
[3.0000, 3.0000]])
说明每个节点的入度的边的特征都求和之后汇聚到x特征上了,还是非常好理解的。
另外两个大概是求源节点相同的边的目标节点的特征的和来更新节点特征
以及求目标节点相同的边的源节点的特征的和来更新节点特征
可能说起来有点绕,但是看看代码运行的结果再结合图应该就懂了,这里就不放运行结果了。
这里是为了演示步骤,一般不再update_all中自己设置更新函数的。
graph.apply_edges
在DGL中,也可以在不涉及消息传递的情况下,通过 apply_edges() 单独调用逐边计算。 apply_edges() 的参数是一个消息函数。并且在默认情况下,这个接口将更新所有的边。
import dgl
import torch
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
g.ndata['h'] = torch.ones(5, 2)
# 三种方式
# def add(edges):
# return{"x": edges.src['h'] + edges.dst['h']}
# g.apply_edges(add)
# g.apply_edges(lambda edges: {'x' : edges.src['h'] + edges.dst['h']}) # 二者等价
g.apply_edges(fn.u_add_v('h', 'h', 'x')) # 使用内置函数,是最好的
print(g.edata['x'])
算图的注意力机制的时候,可以先计算出每个边的注意力权重,此时直接边计算即可。
参考
https://docs.dgl.ai/guide_cn/message.html
https://docs.dgl.ai/guide_cn/message-api.html
边栏推荐
- 微信核酸检测预约小程序系统毕业设计毕设(6)开题答辩PPT
- 双向RNN与堆叠的双向RNN
- [paper reading] ckan: collaborative knowledge aware autonomous network for adviser systems
- How does redis implement multiple zones?
- 正则表达式
- 2022年危险化学品经营单位主要负责人特种作业证考试题库及答案
- [可能没有默认的字体]Warning: imagettfbbox() [function.imagettfbbox]: Invalid font filename……
- What is the most suitable book for programmers to engage in open source?
- Web Components
- AtCoder Beginner Contest 258「ABCDEFG」
猜你喜欢
【黑马早报】罗永浩回应调侃东方甄选;董卿丈夫密春雷被执行超7亿;吉利正式收购魅族;华为发布问界M7;豆瓣为周杰伦专辑提前开分道歉...
WorkManager学习一
Learning Note 6 - satellite positioning technology (Part 1)
"Everyday Mathematics" serial 58: February 27
Learning notes 5 - high precision map solution
重磅:国产IDE发布,由阿里研发,完全开源!
Who is the "conscience" domestic brand?
2022鹏城杯web
2022年流动式起重机司机考试题库及模拟考试
[dark horse morning post] Luo Yonghao responded to ridicule Oriental selection; Dong Qing's husband Mi Chunlei was executed for more than 700million; Geely officially acquired Meizu; Huawei releases M
随机推荐
Who is the "conscience" domestic brand?
双向RNN与堆叠的双向RNN
2022年危险化学品经营单位主要负责人特种作业证考试题库及答案
微信核酸检测预约小程序系统毕业设计毕设(7)中期检查报告
SLAM 01.人类识别环境&路径的模型建立
Solution of ellipsis when pytorch outputs tensor (output tensor completely)
2021年山东省赛题库题目抓包
2022年化工自动化控制仪表考试试题及在线模拟考试
The most complete is an I2C summary
App各大应用商店/应用市场网址汇总
数据库中的范式:第一范式,第二范式,第三范式
2022年危险化学品生产单位安全生产管理人员特种作业证考试题库模拟考试平台操作
What are the top ten securities companies? Is it safe to open an account online?
Apple 5g chip research and development failure? It's too early to get rid of Qualcomm
5g NR system architecture
Flink CDC cannot monitor MySQL logs. Have you ever encountered this problem?
A usage example that can be compatible with various database transactions
dsPIC33EP 时钟初始化程序
Learning II of workmanager
Secteur non technique, comment participer à devops?