Explanation of message passing in DGL
2022-07-05 10:46:00 【Icy Hunter】
Preface
Once you understand message passing in DGL, it becomes much easier to read and write code for all kinds of graph neural networks.
Message passing paradigm
Message passing is a general framework and programming paradigm for implementing GNNs. It summarizes many GNN implementations from the perspective of aggregation and update.
So when coding the message passing part in DGL, we need three functions: a message function, an aggregate function, and an update function.
In a nutshell:
The message function reads the features of edges and of their endpoint nodes and returns the messages to be sent.
The aggregate function computes over the messages each node receives, for example summing features or calculating attention weights from them.
The update function updates the node features: the result passed on from the aggregate function can go through an activation function and so on, and the output becomes the node's final features.
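The three roles can be sketched without any framework. The following is a minimal pure-Python illustration, not DGL code: the toy graph, the scalar features, the ReLU update, and the choice to leave nodes without incoming messages unchanged are all assumptions made for this sketch.

```python
# Framework-free sketch of one round of message passing (not DGL code).
edges = [(0, 1), (2, 1), (1, 2)]        # (source, target) pairs of a hypothetical toy graph
node_feat = {0: 1.0, 1: 2.0, 2: 3.0}    # one scalar feature per node


def message(src_feat, dst_feat):
    """Message function: decide what an edge sends to its target node."""
    return src_feat


def aggregate(messages):
    """Aggregate function: combine all messages a node received."""
    return sum(messages)


def update(old_feat, aggregated):
    """Update function: produce the new node feature (ReLU as the activation)."""
    return max(0.0, aggregated)


def message_passing_round(edges, node_feat):
    # Step 1: compute one message per edge and put it in the target's mailbox.
    mailbox = {v: [] for v in node_feat}
    for src, dst in edges:
        mailbox[dst].append(message(node_feat[src], node_feat[dst]))
    # Steps 2 and 3: aggregate each mailbox, then update the node.
    new_feat = {}
    for v, msgs in mailbox.items():
        if msgs:
            new_feat[v] = update(node_feat[v], aggregate(msgs))
        else:
            new_feat[v] = node_feat[v]  # assumption of this sketch: no messages, no change
    return new_feat


result = message_passing_round(edges, node_feat)
print(result)
```

Node 1 receives messages from nodes 0 and 2, so its new feature is the sum of their features; node 0 receives nothing and keeps its old value.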
Custom message functions in DGL
In DGL, a message function takes a single argument, edges. This is an EdgeBatch instance that DGL generates internally during message passing to represent a batch of edges. edges has three member attributes, src, dst and data, which give access to the features of the source nodes, the target nodes and the edges respectively.
To use it, define a function that takes an edges parameter and index the features you need through src, dst and data.
Example:
def message_func(edges):
    print("-" * 20)
    print("edges.data[x]", edges.data["x"])  # features of the edges
    print("edges.src[x]", edges.src["x"])    # features of each edge's source node
    print("edges.dst[x]", edges.dst["x"])    # features of each edge's target node
    # Return the messages to deliver; each key becomes a field in the mailbox
    return {"e_data": edges.data["x"], "e_src": edges.src["x"], "e_dst": edges.dst["x"]}
Custom aggregate functions in DGL
An aggregate function takes a single argument, nodes. This is a NodeBatch instance that DGL generates internally during message passing to represent a batch of nodes. Its mailbox member attribute gives access to the messages the nodes have received. The most common aggregation operations include sum, max and min.
To use it, define a function that takes a nodes parameter; through mailbox it can index the features returned by the message function.
Example:
def reduce_func(nodes):
    print("+" * 20)
    # Each mailbox field has shape (nodes in batch, messages per node, feature size),
    # so dim=1 sums over the messages each node received.
    # Sum the edge features of each node's incoming edges
    data_sum = th.sum(nodes.mailbox["e_data"], dim=1)
    # Sum the source-node features of each node's incoming edges
    src_sum = th.sum(nodes.mailbox["e_src"], dim=1)
    # Sum the target-node features of each node's incoming edges
    dst_sum = th.sum(nodes.mailbox["e_dst"], dim=1)
    print("nodes_e_data", data_sum)
    print("nodes_e_src", src_sum)
    print("nodes_e_dst", dst_sum)
    # The returned fields are stored as node features under these names
    return {"data_sum": data_sum, "src_sum": src_sum, "dst_sum": dst_sum}
Custom update functions in DGL
The update function also takes the argument nodes. It operates on the result of the aggregate function, usually combining it with the node's own features in the last step of message passing, and saves the output as the node's new features.
Example:
def apply_node_func(nodes):
    # Overwrite x with data_sum
    return {"x": nodes.data["data_sum"]}
Then call
g.update_all(message_func, reduce_func, apply_node_func)
to complete the message passing operation.
Worked example
Creating the graph
First, let's create a graph:
In the figure, node features are shown in black and edge features in red.
The corresponding code is as follows:
import dgl
import dgl.function as fn
import torch
import torch as th
# Build the graph: edge i goes from the i-th entry of the first list to the i-th entry of the second
g = dgl.graph(([0, 1, 1, 1, 2, 3, 2, 4, 3, 4, 4], [1, 0, 3, 2, 1, 1, 4, 2, 4, 3, 4]))
# Every node has feature [1, 1]
g.ndata['x'] = torch.ones(5, 2)
# Every edge has feature [1, 1]
g.edata['x'] = torch.ones(11, 2)
# Node 4 has feature [0.2, 0.5]
g.ndata['x'][4] = torch.tensor([0.2, 0.5])
# Edge 5 has feature [0.1, 0.1]
g.edata['x'][5] = torch.tensor([0.1, 0.1])
# Message passing with built-in functions (left commented out)
# g.update_all(fn.copy_u(u='x', out='m'), fn.sum(msg='m', out='h'))
print(g.ndata['x'])
print(g.edata["x"])
Message passing
Now let's run message passing on this graph to see how it works:
def message_func(edges):
    print("-" * 20)
    print("edges.data[x]", edges.data["x"])  # features of the edges
    print("edges.src[x]", edges.src["x"])    # features of each edge's source node
    print("edges.dst[x]", edges.dst["x"])    # features of each edge's target node
    # Return the messages to deliver
    return {"e_data": edges.data["x"], "e_src": edges.src["x"], "e_dst": edges.dst["x"]}

def reduce_func(nodes):
    print("+" * 20)
    # Sum the edge features of each node's incoming edges
    data_sum = th.sum(nodes.mailbox["e_data"], dim=1)
    # Sum the source-node features of each node's incoming edges
    src_sum = th.sum(nodes.mailbox["e_src"], dim=1)
    # Sum the target-node features of each node's incoming edges
    dst_sum = th.sum(nodes.mailbox["e_dst"], dim=1)
    print("nodes_e_data", data_sum)
    print("nodes_e_src", src_sum)
    print("nodes_e_dst", dst_sum)
    return {"data_sum": data_sum, "src_sum": src_sum, "dst_sum": dst_sum}

def apply_node_func(nodes):
    # Overwrite x with data_sum
    return {"x": nodes.data["data_sum"]}

g.update_all(message_func, reduce_func, apply_node_func)
print(g.ndata["x"])
Let's look at the x features after they have been updated with the edge features.
Output of g.ndata["x"]:
tensor([[1.0000, 1.0000],
[2.1000, 2.1000],
[2.0000, 2.0000],
[2.0000, 2.0000],
[3.0000, 3.0000]])
This shows that, for each node, the features of its incoming edges are summed and written to the x feature, which is easy to understand.
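The tensor above can be checked by hand. The following pure-Python sketch (no DGL or PyTorch needed) repeats the same sum over incoming edges, using the edge list and edge features from the graph created earlier; the variable names are chosen only for this illustration.

```python
# Edge list copied from the graph built above: edge i goes src[i] -> dst[i].
src = [0, 1, 1, 1, 2, 3, 2, 4, 3, 4, 4]
dst = [1, 0, 3, 2, 1, 1, 4, 2, 4, 3, 4]

# Every edge feature is [1, 1] except edge 5, which is [0.1, 0.1].
edge_feat = [[1.0, 1.0] for _ in src]
edge_feat[5] = [0.1, 0.1]

# Add each edge's feature to the running total of the node it points to.
num_nodes = 5
data_sum = [[0.0, 0.0] for _ in range(num_nodes)]
for eid, target in enumerate(dst):
    for k in range(2):
        data_sum[target][k] += edge_feat[eid][k]

for node, feat in enumerate(data_sum):
    print(node, [round(x, 4) for x in feat])
```

Node 1 has incoming edges 0, 4 and 5, so its sum is 1 + 1 + 0.1 = 2.1, while node 4 has three incoming edges with feature 1 each (including its self-loop), which gives 3.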
The other two work the same way. src_sum sums, for each node, the features of the source nodes of its incoming edges.
dst_sum sums, for each node, the features of the target nodes of its incoming edges, that is, the node's own features once per incoming edge.
This may sound a little convoluted, but it becomes clear once you look at the code's output alongside the graph. The run results are not shown here.
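Since those results are not printed in the article, here is a small pure-Python sketch that computes them by hand from the graph defined earlier. It assumes the node features as they were before the update (all [1, 1], node 4 set to [0.2, 0.5]); the helper sum_over_incoming is a name invented for this illustration.

```python
# Edge list and node features copied from the graph built above.
src = [0, 1, 1, 1, 2, 3, 2, 4, 3, 4, 4]
dst = [1, 0, 3, 2, 1, 1, 4, 2, 4, 3, 4]
node_feat = [[1.0, 1.0] for _ in range(5)]
node_feat[4] = [0.2, 0.5]


def sum_over_incoming(pick):
    """For every node v, sum pick(u, v) over all incoming edges u -> v."""
    totals = [[0.0, 0.0] for _ in node_feat]
    for u, v in zip(src, dst):
        feat = pick(u, v)
        for k in range(2):
            totals[v][k] += feat[k]
    return totals


# src_sum: features of the source nodes of each node's incoming edges.
src_sum = sum_over_incoming(lambda u, v: node_feat[u])
# dst_sum: the node's own features, added once per incoming edge.
dst_sum = sum_over_incoming(lambda u, v: node_feat[v])

for node in range(5):
    print(node,
          [round(x, 4) for x in src_sum[node]],
          [round(x, 4) for x in dst_sum[node]])
```

For example, node 2 receives edges from nodes 1 and 4, so its src_sum is [1, 1] + [0.2, 0.5] = [1.2, 1.5], while its dst_sum is twice its own feature, [2, 2].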
The update function is passed here only to demonstrate the steps. In practice, the update function is generally no longer set in update_all.
graph.apply_edges
In DGL, you can also run edge-wise computation on its own, without any message passing, by calling apply_edges(). The argument to apply_edges() is a message function, and by default this interface updates all edges.
import dgl
import dgl.function as fn  # needed for fn.u_add_v below
import torch
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
g.ndata['h'] = torch.ones(5, 2)
# Three equivalent ways
# 1. A named function:
# def add(edges):
#     return {"x": edges.src['h'] + edges.dst['h']}
# g.apply_edges(add)
# 2. A lambda:
# g.apply_edges(lambda edges: {'x': edges.src['h'] + edges.dst['h']})
# 3. A built-in function, which is the best choice:
g.apply_edges(fn.u_add_v('h', 'h', 'x'))
print(g.edata['x'])
When computing a graph attention mechanism, you can calculate the attention weight of each edge first. In that case, direct edge-wise computation is all you need.
References
https://docs.dgl.ai/guide_cn/message.html
https://docs.dgl.ai/guide_cn/message-api.html