Explanation of message passing in DGL
2022-07-05 10:46:00 【Icy Hunter】
Preface
Once you understand message passing in DGL, you will find it much easier to read and write the code of all kinds of graph neural networks.
The message-passing paradigm
Message passing is a general framework and programming paradigm for implementing GNNs. It describes a wide variety of GNN models in terms of two operations: aggregation and update.
Accordingly, when writing the message-passing part of a model in DGL, we need three functions: a message function, an aggregate (reduce) function and an update function.
In a nutshell:
The message function reads the features of the edges and of their endpoint nodes, and produces the messages to be sent along each edge.
The aggregate function combines the messages each node receives, for example by summing them, or by computing attention weights from them and taking a weighted sum.
The update function computes the node's new features from the aggregated result, for example by passing it through an activation function.
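For reference, the DGL user guide (see References) formalizes this paradigm as follows, for an edge e from node u to node v, with node features x and edge feature w_e. Here φ is the message function, ρ the aggregate function and ψ the update function:

$$m_e^{(t+1)} = \phi\left(x_v^{(t)}, x_u^{(t)}, w_e^{(t)}\right), \quad (u, v, e) \in \mathcal{E}$$

$$x_v^{(t+1)} = \psi\left(x_v^{(t)}, \rho\left(\left\{ m_e^{(t+1)} : (u, v, e) \in \mathcal{E} \right\}\right)\right)$$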
Custom message functions in DGL
In DGL, a message function takes a single argument, edges, which is an EdgeBatch instance generated internally by DGL during message passing to represent a batch of edges. edges has three member attributes, src, dst and data, which give access to the features of the source nodes, the destination nodes and the edges themselves, respectively.
In practice, you define a function that takes edges as its argument, indexes the features it needs through these three attributes, and returns a dictionary that maps message names to tensors.
Example:
def message_func(edges):
    print("-" * 20)
    print("edges.data[x]", edges.data["x"])  # features of the edges
    print("edges.src[x]", edges.src["x"])    # features of each edge's source node
    print("edges.dst[x]", edges.dst["x"])    # features of each edge's destination node
    # Return the messages to send along the edges (one tensor per key)
    return {"e_data": edges.data["x"], "e_src": edges.src["x"], "e_dst": edges.dst["x"]}
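To see what the three views contain, here is a dependency-free sketch that mimics them with plain Python lists. It is an illustration under simplifying assumptions, not DGL's EdgeBatch: row i of each view corresponds to edge i, and src/dst are gathered from the node features using each edge's endpoint IDs. The toy graph, the feature values and the gather helper are hypothetical.

```python
# Hypothetical toy graph: edge i goes from src_ids[i] to dst_ids[i].
src_ids = [0, 1, 2]
dst_ids = [1, 2, 0]
node_x = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]  # one feature row per node
edge_x = [[0.5, 0.5], [0.1, 0.2], [0.3, 0.3]]  # one feature row per edge


def gather(node_feats, ids):
    """Return the feature row of node ids[i] as row i."""
    return [node_feats[i] for i in ids]


# Plain-list stand-ins for the three EdgeBatch views (not DGL objects).
edges_src = gather(node_x, src_ids)   # plays the role of edges.src["x"]
edges_dst = gather(node_x, dst_ids)   # plays the role of edges.dst["x"]
edges_data = edge_x                   # plays the role of edges.data["x"]

# A message "function" then just picks what to send along each edge.
messages = {"e_data": edges_data, "e_src": edges_src, "e_dst": edges_dst}
print(messages["e_src"])  # [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
```

In DGL itself these views are tensors with one row per edge, which is why the message function can return them directly.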
Custom aggregate functions in DGL
An aggregate function takes a single argument, nodes, which is a NodeBatch instance generated internally by DGL during message passing to represent a batch of nodes. Its mailbox member gives access to the messages the nodes have received. The most common aggregation operations include sum, max and min.
In practice, you define a function that takes nodes as its argument and reads the messages returned by the message function from nodes.mailbox, using the same keys. Each mailbox entry has shape (number of nodes in the batch, number of messages per node, feature dimensions), which is why the example below sums over dim=1.
Example:
def reduce_func(nodes):
    print("+" * 20)
    # Sum of the features of each node's incoming edges
    data_sum = th.sum(nodes.mailbox["e_data"], dim=1)
    # Sum of the source-node features over each node's incoming edges
    src_sum = th.sum(nodes.mailbox["e_src"], dim=1)
    # Sum of the destination-node features over each node's incoming edges
    dst_sum = th.sum(nodes.mailbox["e_dst"], dim=1)
    print("nodes_e_data", data_sum)
    print("nodes_e_src", src_sum)
    print("nodes_e_dst", dst_sum)
    # The returned tensors become node features data_sum, src_sum and dst_sum
    return {"data_sum": data_sum, "src_sum": src_sum, "dst_sum": dst_sum}
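For user-defined reduce functions, DGL makes every mailbox row the same length by grouping nodes with the same in-degree and calling the reduce function once per group (degree bucketing), so the "+" separator in reduce_func should be printed once per distinct in-degree. The plain-Python sketch below mimics that grouping; build_mailboxes and sum_messages are hypothetical names, and the logic is a simplification of what DGL does internally.

```python
from collections import defaultdict


def build_mailboxes(dst_ids, messages):
    """Group per-edge messages by destination node, then bucket nodes by in-degree.

    Returns {in_degree: {node_id: [message, ...]}}; within one bucket every node
    has the same number of messages, like a (num_nodes, degree, feat) block.
    """
    inbox = defaultdict(list)
    for eid, v in enumerate(dst_ids):
        inbox[v].append(messages[eid])
    buckets = defaultdict(dict)
    for v, msgs in inbox.items():
        buckets[len(msgs)][v] = msgs
    return dict(buckets)


def sum_messages(msgs):
    """Element-wise sum of one node's messages (the role of th.sum(..., dim=1))."""
    return [sum(vals) for vals in zip(*msgs)]


# Hypothetical toy input: 4 edges, each message is a 2-d feature row.
dst_ids = [1, 0, 1, 2]
messages = [[1, 1], [2, 2], [3, 3], [4, 4]]

buckets = build_mailboxes(dst_ids, messages)
reduced = {}
for _, group in sorted(buckets.items()):  # one "reduce call" per bucket
    for v, msgs in group.items():
        reduced[v] = sum_messages(msgs)
print(sorted(buckets))  # [1, 2]
print(reduced)          # {0: [2, 2], 2: [4, 4], 1: [4, 4]}
```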
Custom update functions in DGL
An update function also takes nodes as its argument. It operates on the result of the aggregate function, typically combining it with the node's own features in the last step of message passing, and its output becomes the node's new features.
Example:
def apply_node_func(nodes):
    # Overwrite feature x with the aggregated edge features data_sum
    return {"x": nodes.data["data_sum"]}
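The example above simply copies the aggregated result into x. As an illustration of "combining the aggregated result with the node's own features", here is a hypothetical update rule in plain Python: a weighted sum of the node's own feature and its aggregated messages, followed by a ReLU. The scalar weights w_self and w_neigh are made-up stand-ins for learnable parameters; this rule is not taken from the article or from DGL.

```python
def relu(vec):
    """Element-wise ReLU."""
    return [max(0.0, a) for a in vec]


def apply_node(x_self, agg, w_self=1.0, w_neigh=1.0):
    """Hypothetical update: ReLU(w_self * own feature + w_neigh * aggregated messages)."""
    return relu([w_self * a + w_neigh * b for a, b in zip(x_self, agg)])


print(apply_node([1.0, -2.0], [0.5, 0.5]))  # [1.5, 0.0]
```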
Then call
g.update_all(message_func, reduce_func, apply_node_func)
to run one full round of message passing.
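The order in which update_all uses the three functions can be sketched in plain Python: compute one message per edge, collect messages into each destination node's mailbox, reduce, then apply the update. toy_update_all is a hypothetical mimic under simplifying assumptions (features are Python dicts, nodes without incoming messages are left unchanged), not DGL code; DGL's handling of zero-in-degree nodes may differ.

```python
from collections import defaultdict


def toy_update_all(src, dst, ndata, edata, message_fn, reduce_fn, apply_fn):
    """Hypothetical plain-Python mimic of one update_all round (not DGL code).

    ndata / edata are lists of per-node / per-edge feature dicts.
    """
    # 1. Message phase: one message dict per edge.
    msgs = [message_fn(ndata[u], ndata[v], edata[e])
            for e, (u, v) in enumerate(zip(src, dst))]
    # 2. Group the messages by destination node (the "mailbox").
    mailbox = defaultdict(list)
    for e, v in enumerate(dst):
        mailbox[v].append(msgs[e])
    # 3. Reduce, then 4. apply, only for nodes that received messages (a simplification).
    new_ndata = [dict(d) for d in ndata]
    for v, inbox in mailbox.items():
        new_ndata[v].update(reduce_fn(inbox))
        new_ndata[v].update(apply_fn(new_ndata[v]))
    return new_ndata


# Hypothetical toy graph: edges 0 -> 2, 1 -> 2, 2 -> 0.
src, dst = [0, 1, 2], [2, 2, 0]
ndata = [{"x": 1.0}, {"x": 2.0}, {"x": 3.0}]
edata = [{"w": 0.5}, {"w": 1.0}, {"w": 2.0}]
out = toy_update_all(
    src, dst, ndata, edata,
    message_fn=lambda s, d, e: {"m": s["x"] * e["w"]},   # source feature * edge weight
    reduce_fn=lambda inbox: {"agg": sum(m["m"] for m in inbox)},
    apply_fn=lambda n: {"x": n["agg"]},
)
print([n["x"] for n in out])  # [6.0, 2.0, 2.5]
```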
A worked example
Creating the graph
First, let's create a graph:
[Figure: the example graph; node features are shown in black, edge features in red.]
The code that creates it is as follows:
import dgl
import dgl.function as fn
import torch
import torch as th  # alias used by reduce_func below

# Build a directed graph with 5 nodes and 11 edges (edge i goes from src[i] to dst[i])
g = dgl.graph(([0, 1, 1, 1, 2, 3, 2, 4, 3, 4, 4], [1, 0, 3, 2, 1, 1, 4, 2, 4, 3, 4]))
# Every node starts with feature [1, 1] ...
node_x = torch.ones(5, 2)
# ... except node 4, whose feature is [0.2, 0.5]
node_x[4] = torch.tensor([0.2, 0.5])
g.ndata['x'] = node_x
# Every edge starts with feature [1, 1] ...
edge_x = torch.ones(11, 2)
# ... except edge 5 (from node 3 to node 1), whose feature is [0.1, 0.1]
edge_x[5] = torch.tensor([0.1, 0.1])
g.edata['x'] = edge_x
# Built-in alternative for a simple aggregation (not run here):
# g.update_all(fn.copy_u(u='x', out='m'), fn.sum(msg='m', out='h'))
print(g.ndata['x'])
print(g.edata["x"])
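Since the figure is not reproduced here, the following plain-Python snippet (no DGL needed) lists each edge ID with its endpoints and counts each node's incoming edges, which makes the results below easier to follow. It reuses the src/dst lists passed to dgl.graph above.

```python
src = [0, 1, 1, 1, 2, 3, 2, 4, 3, 4, 4]
dst = [1, 0, 3, 2, 1, 1, 4, 2, 4, 3, 4]

# Edge IDs follow list order, so edge i goes from src[i] to dst[i].
for eid, (u, v) in enumerate(zip(src, dst)):
    print(f"edge {eid}: {u} -> {v}")

# In-degree = number of messages each node receives in update_all.
in_degree = [dst.count(v) for v in range(5)]
print(in_degree)  # [1, 3, 2, 2, 3]
```

Edge 5 runs from node 3 to node 1, and edge 10 is a self-loop on node 4, so node 4 also receives a message from itself.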
Running message passing
Now let's walk through message passing on this graph:
def message_func(edges):
    print("-" * 20)
    print("edges.data[x]", edges.data["x"])  # features of the edges
    print("edges.src[x]", edges.src["x"])    # features of each edge's source node
    print("edges.dst[x]", edges.dst["x"])    # features of each edge's destination node
    # Return the messages to send along the edges
    return {"e_data": edges.data["x"], "e_src": edges.src["x"], "e_dst": edges.dst["x"]}

def reduce_func(nodes):
    print("+" * 20)
    # Sum of the features of each node's incoming edges
    data_sum = th.sum(nodes.mailbox["e_data"], dim=1)
    # Sum of the source-node features over each node's incoming edges
    src_sum = th.sum(nodes.mailbox["e_src"], dim=1)
    # Sum of the destination-node features over each node's incoming edges
    dst_sum = th.sum(nodes.mailbox["e_dst"], dim=1)
    print("nodes_e_data", data_sum)
    print("nodes_e_src", src_sum)
    print("nodes_e_dst", dst_sum)
    return {"data_sum": data_sum, "src_sum": src_sum, "dst_sum": dst_sum}

def apply_node_func(nodes):
    # Overwrite feature x with the aggregated edge features data_sum
    return {"x": nodes.data["data_sum"]}

g.update_all(message_func, reduce_func, apply_node_func)
print(g.ndata["x"])
Let's look at node feature x, which the update function has overwritten with the summed edge features.
Output of g.ndata["x"]:
tensor([[1.0000, 1.0000],
[2.1000, 2.1000],
[2.0000, 2.0000],
[2.0000, 2.0000],
[3.0000, 3.0000]])
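This shows that, for each node, the features of its incoming edges have been summed and written into x, which is easy to check by hand: node 1 has three incoming edges (from nodes 0, 2 and 3), and edge 5 (from node 3 to node 1) carries [0.1, 0.1], so node 1 ends up with 1 + 1 + 0.1 = 2.1.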
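The same sums can be checked without DGL. This plain-Python sketch adds each edge's feature to its destination node, which is what the sum over the mailbox in reduce_func does here; the values mirror the graph built above.

```python
src = [0, 1, 1, 1, 2, 3, 2, 4, 3, 4, 4]
dst = [1, 0, 3, 2, 1, 1, 4, 2, 4, 3, 4]
edge_x = [[1.0, 1.0] for _ in src]
edge_x[5] = [0.1, 0.1]  # edge 5: node 3 -> node 1

# Add each edge's feature to its destination node.
data_sum = [[0.0, 0.0] for _ in range(5)]
for eid, v in enumerate(dst):
    for k in range(2):
        data_sum[v][k] += edge_x[eid][k]

for row in data_sum:
    print([round(a, 4) for a in row])
```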
The other two results are stored as node features rather than written into x. For each node, src_sum is the sum of the source-node features over its incoming edges, and dst_sum is the sum of the destination-node features over those same edges; since the destination is the node itself, dst_sum equals the node's in-degree times its own feature.
This may sound convoluted, but it becomes clear once you compare the printed values with the figure; the full run output is not reproduced here.
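To make those two quantities concrete, this plain-Python sketch computes them for the example graph from the node features before the update (node 4 is [0.2, 0.5], all others [1, 1]). It reimplements the sums by hand rather than calling DGL.

```python
src = [0, 1, 1, 1, 2, 3, 2, 4, 3, 4, 4]
dst = [1, 0, 3, 2, 1, 1, 4, 2, 4, 3, 4]
node_x = [[1.0, 1.0] for _ in range(5)]
node_x[4] = [0.2, 0.5]

src_sum = [[0.0, 0.0] for _ in range(5)]
dst_sum = [[0.0, 0.0] for _ in range(5)]
for u, v in zip(src, dst):
    for k in range(2):
        src_sum[v][k] += node_x[u][k]  # feature of the sending node
        dst_sum[v][k] += node_x[v][k]  # feature of the receiving node itself

print([[round(a, 4) for a in row] for row in src_sum])
print([[round(a, 4) for a in row] for row in dst_sum])
```

For node 4, dst_sum is 3 x [0.2, 0.5] = [0.6, 1.5], because it has three incoming edges (including its self-loop), while its src_sum is [1, 1] + [1, 1] + [0.2, 0.5] = [2.2, 2.5].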
The update function is passed to update_all here only to illustrate each step; in practice it is usually not set in update_all, and the update is applied to the aggregated features afterwards.
graph.apply_edges
In DGL you can also compute on edges without performing message passing, by calling apply_edges() on its own. Its argument is a message function, and by default it updates all edges.
import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
g.ndata['h'] = torch.ones(5, 2)
# Three equivalent ways to compute x = h(src) + h(dst) on every edge:
# 1) a named user-defined function
# def add(edges):
#     return {"x": edges.src['h'] + edges.dst['h']}
# g.apply_edges(add)
# 2) the same function as a lambda
# g.apply_edges(lambda edges: {'x': edges.src['h'] + edges.dst['h']})
# 3) a built-in function: preferred, since DGL can optimize built-ins
g.apply_edges(fn.u_add_v('h', 'h', 'x'))
print(g.edata['x'])  # each of the 4 edges gets [2., 2.]
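Conceptually, fn.u_add_v('h', 'h', 'x') computes, for every edge, the source-node feature plus the destination-node feature. Here is a plain-Python sketch of that per-edge computation for the same 4-edge graph; u_add_v below is a hypothetical helper, not DGL's kernel.

```python
def u_add_v(src, dst, h):
    """For each edge (u, v), return h[u] + h[v] element-wise."""
    return [[a + b for a, b in zip(h[u], h[v])] for u, v in zip(src, dst)]


src, dst = [0, 1, 2, 3], [1, 2, 3, 4]
h = [[1.0, 1.0] for _ in range(5)]
print(u_add_v(src, dst, h))  # [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0], [2.0, 2.0]]
```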
When implementing graph attention, for example, you can first compute an attention weight for each edge; that step only involves per-edge computation, so apply_edges is enough.
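Once per-edge attention scores exist, they are typically normalized over each destination node's incoming edges (a softmax per node). DGL provides an edge_softmax function for this (for example in dgl.nn.functional); the plain-Python sketch below only illustrates the computation and assumes the raw scores are already given. The scores and graph are hypothetical.

```python
import math
from collections import defaultdict


def edge_softmax(dst, scores):
    """Softmax of per-edge scores, normalized over each destination node's in-edges."""
    # Subtract each node's maximum score for numerical stability.
    node_max = {}
    for v, s in zip(dst, scores):
        node_max[v] = max(node_max.get(v, s), s)
    exps = [math.exp(s - node_max[v]) for v, s in zip(dst, scores)]
    denom = defaultdict(float)
    for v, e in zip(dst, exps):
        denom[v] += e
    return [e / denom[v] for v, e in zip(dst, exps)]


# Hypothetical scores for 3 edges: edges 0 and 1 point to node 2, edge 2 to node 0.
dst = [2, 2, 0]
scores = [0.0, math.log(3.0), 5.0]
print(edge_softmax(dst, scores))  # approximately [0.25, 0.75, 1.0]
```

The weights on each node's incoming edges sum to 1, so they can be used directly in a weighted-sum aggregation.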
References
https://docs.dgl.ai/guide_cn/message.html
https://docs.dgl.ai/guide_cn/message-api.html