Explanation of message passing in DGL
2022-07-05 10:46:00 【Icy Hunter】
Preface
Once you have learned message passing in DGL, you can basically understand and write the code of various graph neural networks.
The message-passing paradigm
Message passing is a general framework and programming paradigm for implementing GNNs. It summarizes the implementation of a variety of GNN models from the perspective of aggregation and update.
So when writing the message-passing part in DGL, we need three functions: the message function, the aggregate (reduce) function, and the update function.
In a nutshell:
The message function reads the features of edges and nodes and packs them into messages.
The aggregate function computes over those features, for example summing them, or deriving attention weights from them.
The update function updates the node features; the result passed on from the aggregate function can, for example, be fed through an activation function before the final node features are written back.
Custom message functions in DGL
In DGL, a message function takes a single argument, edges, which is an EdgeBatch instance generated internally by DGL during message passing to represent a batch of edges. edges has three member attributes, src, dst and data, for accessing the features of the source nodes, the destination nodes and the edges respectively.
To use one, define a function that takes an edges parameter and index the desired features through those three attributes.
Example:
def message_func(edges):
    print("-" * 20)
    print("edges.data[x]", edges.data["x"])  # the edges' features
    print("edges.src[x]", edges.src["x"])    # the features of the edges' source nodes
    print("edges.dst[x]", edges.dst["x"])    # the features of the edges' destination nodes
    # Return the message fields to be passed on
    return {'e_data': edges.data['x'], "e_src": edges.src["x"], "e_dst": edges.dst["x"]}
DGL Custom aggregate functions in
Aggregate functions Take a parameter nodes, This is a NodeBatch Example , During message delivery , It has been DGL Generated internally to represent a batch of nodes . nodes Member properties of mailbox Can be used to access messages received by nodes . Some of the most common aggregation operations include sum、max、min etc. .
Usage is to define a function , Then you need to pass in a nodes Parameters , This parameter can pass mailbox Index message function return The characteristics of coming .
Example:
def reduce_func(nodes):
    print("+" * 20)
    # Sum each node's incoming edge features, stored as data_sum
    data_sum = th.sum(nodes.mailbox["e_data"], dim=1)
    # Sum the source-node features of each node's incoming edges, stored as src_sum
    src_sum = th.sum(nodes.mailbox["e_src"], dim=1)
    # Sum the destination-node features of each node's incoming edges, stored as dst_sum
    dst_sum = th.sum(nodes.mailbox["e_dst"], dim=1)
    print("nodes_e_data", data_sum)
    print("nodes_e_src", src_sum)
    print("nodes_e_dst", dst_sum)
    return {"data_sum": data_sum, "src_sum": src_sum, "dst_sum": dst_sum}
Custom update functions in DGL
An update function also takes a nodes argument. It operates on the result of the aggregate function, typically combining it with the node's features in the last step of message passing, and its output becomes the new node features.
Example:
def apply_node_func(nodes):
    # Overwrite x with data_sum
    return {'x': nodes.data["data_sum"]}
Then calling

g.update_all(message_func, reduce_func, apply_node_func)

completes one round of message passing.
A worked example
Creating the graph
First, let's create a graph.
In the figure (not reproduced here), the black labels are the node features and the red ones are the edge features.
The corresponding creation code is as follows:
import dgl
import dgl.function as fn
import torch
import torch as th

# Build a graph
g = dgl.graph(([0, 1, 1, 1, 2, 3, 2, 4, 3, 4, 4], [1, 0, 3, 2, 1, 1, 4, 2, 4, 3, 4]))
# Every node's feature is [1, 1]
g.ndata['x'] = torch.ones(5, 2)
# Every edge's feature is [1, 1]
g.edata['x'] = torch.ones(11, 2)
# Node 4's feature is [0.2, 0.5]
g.ndata['x'][4] = torch.tensor([0.2, 0.5])
# Edge 5's feature is [0.1, 0.1]
g.edata['x'][5] = torch.tensor([0.1, 0.1])
# Message aggregation and update with built-ins (for comparison)
# g.update_all(fn.copy_u(u='x', out='m'), fn.sum(msg='m', out='h'))
print(g.ndata['x'])
print(g.edata["x"])
Message passing
Now let's try to understand message passing:
def message_func(edges):
    print("-" * 20)
    print("edges.data[x]", edges.data["x"])  # the edges' features
    print("edges.src[x]", edges.src["x"])    # the features of the edges' source nodes
    print("edges.dst[x]", edges.dst["x"])    # the features of the edges' destination nodes
    # Return the message fields to be passed on
    return {'e_data': edges.data['x'], "e_src": edges.src["x"], "e_dst": edges.dst["x"]}

def reduce_func(nodes):
    print("+" * 20)
    # Sum each node's incoming edge features, stored as data_sum
    data_sum = th.sum(nodes.mailbox["e_data"], dim=1)
    # Sum the source-node features of each node's incoming edges, stored as src_sum
    src_sum = th.sum(nodes.mailbox["e_src"], dim=1)
    # Sum the destination-node features of each node's incoming edges, stored as dst_sum
    dst_sum = th.sum(nodes.mailbox["e_dst"], dim=1)
    print("nodes_e_data", data_sum)
    print("nodes_e_src", src_sum)
    print("nodes_e_dst", dst_sum)
    return {"data_sum": data_sum, "src_sum": src_sum, "dst_sum": dst_sum}

def apply_node_func(nodes):
    # Overwrite x with data_sum
    return {'x': nodes.data["data_sum"]}

g.update_all(message_func, reduce_func, apply_node_func)
print(g.ndata["x"])
Let's look at the x features after they have been updated with the edge features.
g.ndata["x"] output:
tensor([[1.0000, 1.0000],
[2.1000, 2.1000],
[2.0000, 2.0000],
[2.0000, 2.0000],
[3.0000, 3.0000]])
This shows that the incoming edge features of each node are summed and aggregated into its x feature, which is easy to understand.
The other two returned fields work analogously: src_sum would update each node with the sum of the source-node features of its incoming edges, and dst_sum with the sum of the destination-node features of those edges.
That may sound convoluted, but running the code and comparing the results against the graph makes it clear; those results are not reproduced here.
The update function is included here to demonstrate the steps; in practice, an update function is usually not set in update_all.
graph.apply_edges
In DGL you can also run edge-wise computation on its own via apply_edges(), without involving message passing. apply_edges() takes a message function as its argument, and by default it updates all edges.
import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
g.ndata['h'] = torch.ones(5, 2)
# Three equivalent ways
# def add(edges):
#     return {"x": edges.src['h'] + edges.dst['h']}
# g.apply_edges(add)
# g.apply_edges(lambda edges: {'x': edges.src['h'] + edges.dst['h']})  # equivalent to the above
g.apply_edges(fn.u_add_v('h', 'h', 'x'))  # using the built-in function, the preferred way
print(g.edata['x'])
When implementing graph attention, for example, you can first compute the attention weight of each edge; direct edge-wise computation is all that is needed for that.
References
https://docs.dgl.ai/guide_cn/message.html
https://docs.dgl.ai/guide_cn/message-api.html