torch_geometric message passing network
2022-06-12 13:06:00 【Dongxuan】

1. Node update formula review
In the previous section we simply called the built-in GCN layer. Sometimes, however, we need a GCN-style convolution tailored to our own needs, and then we have to write a custom layer.
This section therefore looks at how a GCN layer propagates neighborhood information internally, and how that process can be customized.
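A common general form of the node update is

$$\mathbf{x}_i^{(k)} = \gamma^{(k)}\left(\mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)}, \mathbf{e}_{j,i}\right)\right)$$

where k is the layer index, i is the index of the current node and j the index of a neighbor node. γ is the update function, φ combines neighbor and edge information into a message, and □ is a differentiable, permutation-invariant function (such as sum, mean or max) that aggregates the neighbor messages into a single vector.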
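To make the formula concrete, here is a minimal sketch of one layer in plain PyTorch with hand-picked functions. These are assumptions for illustration, not PyG defaults: φ scales x_j by a scalar edge weight e_{j,i}, □ is a sum, and γ adds the aggregated message to the node's own features. The names `phi`, `gamma` and `message_passing_step` are hypothetical; `edge_index` stores an edge j → i as the column (j, i), the same convention the GCN code below uses.

```python
import torch

def phi(x_i, x_j, e_ji):
    # Message from j to i; this particular choice ignores x_i and
    # scales the neighbor's features by the scalar edge weight.
    return e_ji.view(-1, 1) * x_j

def gamma(x_i, m_i):
    # Update: combine the node's previous state with its aggregated message.
    return x_i + m_i

def message_passing_step(x, edge_index, edge_attr):
    src, dst = edge_index                      # column (j, i) is an edge j -> i
    messages = phi(x[dst], x[src], edge_attr)  # one message per edge, [E, F]
    # Sum aggregation: add every message into the row of its target node i.
    aggregated = torch.zeros_like(x).index_add_(0, dst, messages)
    return gamma(x, aggregated)

x = torch.tensor([[1.0], [2.0], [3.0]])
edge_index = torch.tensor([[0, 2],
                           [1, 1]])            # edges 0 -> 1 and 2 -> 1
edge_attr = torch.tensor([0.5, 2.0])
out = message_passing_step(x, edge_index, edge_attr)
# node 1: 2 + 0.5 * 1 + 2.0 * 3 = 8.5; nodes 0 and 2 receive no messages
```

Swapping in a different φ, γ or aggregation changes the layer type while the gather, aggregate, update skeleton stays the same, which is exactly the part the base class below factors out.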
2. The MessagePassing base class
The MessagePassing base class implements the formula above. We define the message function φ in message(), the node update γ in update(), and select the aggregation □ (for example "add", "mean" or "max") through the aggr argument.

MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)
Here aggr selects the aggregation □ in the formula above, flow sets the direction in which messages travel along the edges, and node_dim is the axis along which messages are aggregated. That is the node axis; since the last axis holds the node features, the default is -2.
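The effect of the aggr choice can be reproduced in plain PyTorch with `Tensor.scatter_reduce` (assumes PyTorch 1.12 or later). The sketch below is not PyG's internal code; the helper `aggregate` is hypothetical. It reduces per-edge messages along the node axis (dimension -2 of an [N, F] tensor). Note that PyTorch names the max reduction "amax" where PyG uses "max".

```python
import torch

# Messages already computed for 4 edges with 2 features each: shape [E, F].
messages = torch.tensor([[1.0, 4.0],
                         [3.0, 2.0],
                         [5.0, 0.0],
                         [7.0, 1.0]])
# Target node of each edge: nodes 0 and 1 get two messages each, node 2 none.
dst = torch.tensor([0, 0, 1, 1])
num_nodes = 3
node_dim = -2  # aggregate along the node axis of an [N, F] tensor

def aggregate(messages, dst, num_nodes, reduce):
    out = torch.zeros(num_nodes, messages.size(-1))
    dim = node_dim % out.dim()  # -2 on a 2-D tensor is dimension 0
    index = dst.view(-1, 1).expand_as(messages)
    # include_self=False: rows that receive messages ignore the initial zeros;
    # rows that receive none (node 2) keep their initial value of 0.
    return out.scatter_reduce(dim, index, messages, reduce=reduce, include_self=False)

summed = aggregate(messages, dst, num_nodes, "sum")   # like aggr="add"
mean = aggregate(messages, dst, num_nodes, "mean")    # like aggr="mean"
maxed = aggregate(messages, dst, num_nodes, "amax")   # like aggr="max"
```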
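The documentation then describes message() and update() in more detail. How arguments reach them is easiest to see from the example below: any keyword argument passed to propagate() can be requested in message(), and a name with an _i or _j suffix (such as x_j) receives the features of the target node i or the source node j of each edge.
Let us verify this by re-implementing two popular GNN variants, the GCN layer from Kipf and Welling and the EdgeConv layer from Wang et al.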
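The _i/_j mapping can be mimicked with plain indexing. The helper `lift` below is hypothetical; it only reproduces the convention PyG uses: with flow="source_to_target", j is read from row 0 of edge_index and i from row 1, and flow="target_to_source" swaps the two rows.

```python
import torch

def lift(x, edge_index, flow="source_to_target"):
    # Hypothetical helper: gather per-edge node features the way message()
    # sees them as x_i (receiving node) and x_j (sending node).
    if flow == "source_to_target":
        j, i = 0, 1   # row 0 = source j, row 1 = target i
    elif flow == "target_to_source":
        j, i = 1, 0
    else:
        raise ValueError(f"unknown flow: {flow}")
    x_j = x.index_select(0, edge_index[j])   # [E, F]
    x_i = x.index_select(0, edge_index[i])   # [E, F]
    return x_i, x_j

x = torch.tensor([[10.0], [20.0], [30.0]])
edge_index = torch.tensor([[0, 1],
                           [2, 2]])          # edges 0 -> 2 and 1 -> 2
x_i, x_j = lift(x, edge_index)
# x_j holds the senders (nodes 0 and 1), x_i the receiver (node 2) twice
```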
3. Implementing a GCN layer

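The update formula to implement is

$$\mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right) + \mathbf{b}$$

The sum runs over the neighbors of node i together with i itself ($\mathcal{N}(i) \cup \{i\}$), which amounts to giving every node a self-loop.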
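Adding self-loops to an edge list just appends one edge (n, n) per node. The sketch below, with the hypothetical name `add_self_loops_manual`, mirrors what `torch_geometric.utils.add_self_loops` does for an `edge_index` without edge weights: existing edges are kept and N loop edges are appended at the end.

```python
import torch

def add_self_loops_manual(edge_index, num_nodes):
    # One loop (n, n) per node, appended after the existing edges.
    loop = torch.arange(num_nodes, dtype=edge_index.dtype, device=edge_index.device)
    loop_index = torch.stack([loop, loop], dim=0)       # [2, N]
    return torch.cat([edge_index, loop_index], dim=1)   # [2, E + N]

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])   # undirected path 0 - 1 - 2
with_loops = add_self_loops_manual(edge_index, num_nodes=3)
```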
The formula can be computed in six steps:
1. Add self-loops to the adjacency matrix (first add a self-loop edge for every node).
2. Linearly transform the node feature matrix (the multiplication by W, implemented as a linear layer).
3. Compute normalization coefficients (from the node degrees).
4. Normalize node features in φ (the normalization term, i.e. the denominator).
5. Sum up neighboring node features ("add" aggregation), which includes each node's own features through its self-loop.
6. Apply a final bias vector b.
Steps 4 and 5 are handled by the MessagePassing base class introduced above; steps 1 to 3 are only a preprocessing stage.
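Steps 1 and 3 can be checked by hand on a tiny graph. The sketch below uses plain PyTorch; the `index_add_` call counts incoming edges per node, which is what the layer's `degree(col, ...)` call computes. On the path 0 - 1 - 2 with self-loops, nodes 0 and 2 have degree 2 and node 1 has degree 3, so an edge between nodes 0 and 1 gets 1/√6 and the self-loop of node 1 gets 1/3.

```python
import torch

# Path graph 0 - 1 - 2 (both directions) plus one self-loop per node (step 1).
edge_index = torch.tensor([[0, 1, 1, 2, 0, 1, 2],
                           [1, 0, 2, 1, 0, 1, 2]])
num_nodes = 3
row, col = edge_index            # row = source j, col = target i

# Step 3: in-degree of every node, counted over the target row.
deg = torch.zeros(num_nodes).index_add_(0, col, torch.ones(col.size(0)))
deg_inv_sqrt = deg.pow(-0.5)
deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0   # guard for nodes with degree 0
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]      # one coefficient per edge, [E]
```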

```python
import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  # "Add" aggregation (Step 5).
        # No bias in the linear layer: the bias is added once at the very end (Step 6).
        self.lin = Linear(in_channels, out_channels, bias=False)
        # Final bias, one entry per output feature.
        self.bias = Parameter(torch.empty(out_channels))
        self.reset_parameters()  # initialize the learnable weights and bias

    def reset_parameters(self):
        self.lin.reset_parameters()
        self.bias.data.zero_()

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]: node features of the whole (possibly batched)
        # graph, where N is the number of nodes and in_channels the input feature size.
        # edge_index has shape [2, E]: the graph connectivity in COO format.

        # Step 1: Add self-loops to the adjacency matrix.
        # Appends one edge (n, n) per node, so edge_index becomes [2, E + N].
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        # row holds source nodes j (arrow tails), col holds target nodes i (arrow heads).
        row, col = edge_index
        # Degree is counted over targets, i.e. the in-degree for a directed graph.
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # A node with no incoming edges would give 1/0 = inf; use 0 instead.
        # (With self-loops added every degree is at least 1, so this is a safeguard.)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        # One normalization coefficient per edge: 1 / (sqrt(deg(i)) * sqrt(deg(j))).
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        # Shapes: edge_index [2, E + N], x [N, out_channels], norm [E + N].
        out = self.propagate(edge_index, x=x, norm=norm)

        # Step 6: Apply a final bias vector.
        out += self.bias
        return out

    def message(self, x_j, norm):
        # x_j has shape [E + N, out_channels]: features of the source node of each edge.
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j
```

We then call propagate(), which internally calls message(), aggregate() and update().
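propagate() calls these three hooks automatically. Since this GCNConv overrides only message(), aggregation uses the "add" scheme passed to super().__init__(), and the inherited update() returns the aggregated features unchanged.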
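To see the order of calls, here is a hypothetical, heavily simplified stand-in for MessagePassing in plain PyTorch. The class names `TinyMessagePassing` and `TinyGCNPropagation` are made up for this sketch, and the real base class supports far more (arbitrary keyword arguments, both flows, several aggregations). Its propagate() lifts source-node features to the edges, calls message(), sums the messages per target node in aggregate(), and passes the result through update().

```python
import torch

class TinyMessagePassing:
    # Hypothetical sketch: flow="source_to_target", "add" aggregation,
    # and fixed keyword arguments x and norm only.
    def propagate(self, edge_index, x, norm):
        src, dst = edge_index
        x_j = x.index_select(0, src)                     # lift: one row per edge
        msg = self.message(x_j, norm)                    # [E, F]
        aggr_out = self.aggregate(msg, dst, x.size(0))   # [N, F]
        return self.update(aggr_out)

    def message(self, x_j, norm):
        return x_j                                       # default: pass through

    def aggregate(self, msg, index, num_nodes):
        out = torch.zeros(num_nodes, msg.size(-1), dtype=msg.dtype)
        return out.index_add_(0, index, msg)             # "add" aggregation

    def update(self, aggr_out):
        return aggr_out                                  # default: identity


class TinyGCNPropagation(TinyMessagePassing):
    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j                    # Step 4 of the GCN layer


# Path 0 - 1 - 2 with self-loops and its GCN normalization coefficients.
edge_index = torch.tensor([[0, 1, 1, 2, 0, 1, 2],
                           [1, 0, 2, 1, 0, 1, 2]])
norm = torch.tensor([6 ** -0.5] * 4 + [0.5, 1 / 3, 0.5])
x = torch.tensor([[1.0], [2.0], [3.0]])
out = TinyGCNPropagation().propagate(edge_index, x=x, norm=norm)
```

Only message() is overridden in the subclass, just as in GCNConv above; the aggregation and update steps come from the base class.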
The layer is then used like any other PyTorch module, here mapping 16 input features to 32 output features:

```python
conv = GCNConv(16, 32)
x = conv(x, edge_index)
```
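One way to sanity-check the layer is to compare it against the dense matrix form of the same formula, $\hat{D}^{-1/2}(A + I)\hat{D}^{-1/2} X W + b$, where $\hat{D}$ holds the degrees of $A + I$. The sketch below computes both the dense form and the edge-wise (message passing) form in plain PyTorch on a tiny random graph, so W and b are stand-ins for `self.lin` and `self.bias` (an assumption; no PyG needed).

```python
import torch

torch.manual_seed(0)
num_nodes, in_channels, out_channels = 3, 4, 2
x = torch.randn(num_nodes, in_channels)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])    # undirected path 0 - 1 - 2
W = torch.randn(in_channels, out_channels)    # stand-in for self.lin
b = torch.randn(out_channels)                 # stand-in for self.bias

# Dense form, with A[i, j] = 1 for an edge j -> i.
A = torch.zeros(num_nodes, num_nodes)
A[edge_index[1], edge_index[0]] = 1.0
A_hat = A + torch.eye(num_nodes)
d_inv_sqrt = A_hat.sum(dim=1).pow(-0.5)      # in-degree including the self-loop
norm_adj = d_inv_sqrt.view(-1, 1) * A_hat * d_inv_sqrt.view(1, -1)
dense_out = norm_adj @ (x @ W) + b

# Edge-wise form: self-loops, per-edge norm, scatter-add into target nodes.
loop = torch.arange(num_nodes)
ei = torch.cat([edge_index, torch.stack([loop, loop])], dim=1)
row, col = ei
deg = torch.zeros(num_nodes).index_add_(0, col, torch.ones(col.numel()))
norm = deg.pow(-0.5)[row] * deg.pow(-0.5)[col]
h = x @ W
sparse_out = torch.zeros(num_nodes, out_channels).index_add_(
    0, col, norm.view(-1, 1) * h[row]) + b
```

With PyG installed, the same dense expression could be compared against `conv(x, edge_index)` by using `W = conv.lin.weight.t()` and `b = conv.bias`, since `torch.nn.Linear` stores its weight as [out_channels, in_channels].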