GNN Hands-On Practice (II): Reproducing the Graph Attention Network (GAT)
2022-06-30 10:09:00 【Si Xi is towering】
Reference paper: Graph Attention Networks
1. Preface
GAT (Graph Attention Network) is an important SOTA model among GNNs. It is defined from the spatial-domain perspective and can be described with the message-passing paradigm. The biggest difference between GAT and GCN is that GAT introduces an attention mechanism into neighborhood aggregation, computing the importance of each neighbor to the node currently being aggregated. This article covers an introduction to the GAT architecture and a reproduction of the GAT model based on PyG.
2. GAT Architecture
As described in Section 1, GAT's main contribution is introducing the attention mechanism into graph convolution. The architecture of the model is shown below:

As the figure shows, during aggregation GAT computes the importance of each 1-hop neighbor $j$ to the current node $i$, i.e. the attention coefficient $\alpha_{ij}$, and then takes a weighted sum of the neighbors' features. The corresponding message-passing formulation of the model is:
$$
\begin{aligned}
h_i^{(l+1)} &= \sum_{j\in \mathcal{N}(i)} \alpha_{ij}^{l} W^{(l)} h_j^{(l)} \\
\alpha_{ij}^{l} &= \operatorname{softmax}_{j}\left(e_{ij}^{l}\right) = \frac{\exp\left(e_{ij}^{l}\right)}{\sum_{k \in \mathcal{N}(i)} \exp\left(e_{ik}^{l}\right)} \\
e_{ij}^{l} &= \mathrm{LeakyReLU}\left(a\left(W h_{i}^{(l)} \,\Vert\, W h_{j}^{(l)}\right)\right)
\end{aligned}
$$
where $h_i^{(l)}$ and $h_j^{(l)}$ are the node features at layer $l$ of the GAT model, $a$ is a single-layer feed-forward neural network, $\Vert$ denotes vector concatenation, $W$ is the weight matrix, and $\mathcal{N}(i)$ is the 1-hop neighborhood of node $i$.
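To make these three equations concrete, the snippet below walks through a single-head attention computation for one node and three neighbors. This is a minimal illustrative sketch: the dimensions are made up and $W$ and $a$ are randomly initialized, so the numbers themselves are meaningless.

```python
import torch
import torch.nn as nn

F_in, F_out = 8, 4                        # hypothetical feature sizes
W = nn.Linear(F_in, F_out, bias=False)    # shared weight matrix W
a = nn.Linear(2 * F_out, 1, bias=False)   # single-layer network a
leaky_relu = nn.LeakyReLU(0.2)

h_i = torch.rand(1, F_in)                 # features of the center node i
h_j = torch.rand(3, F_in)                 # features of its 3 neighbors j
Wh_i, Wh_j = W(h_i), W(h_j)

# e_ij = LeakyReLU(a(Wh_i || Wh_j)) for every neighbor j
e = leaky_relu(a(torch.cat([Wh_i.expand(3, -1), Wh_j], dim=-1)))
# alpha_ij = softmax_j(e_ij), normalized over the neighborhood of i
alpha = torch.softmax(e, dim=0)
# h_i^{(l+1)} = sum_j alpha_ij * W h_j (nonlinearity omitted)
h_i_next = (alpha * Wh_j).sum(dim=0)
print(alpha.squeeze(-1), h_i_next.shape)  # 3 weights summing to 1; shape [4]
```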
In addition, the authors adopt a multi-head attention (Multi-Head Attention) mechanism, which extends the aggregation formula above to the following form:
$$
h_{i}^{(l+1)}=\big\Vert_{k=1}^{K}\, \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{k} W^{k} h_{j}^{(l)}\right)
$$
where $K$ is the number of attention heads.
Note that when multi-head attention is used in the last layer, averaging is used instead of concatenation (a small sketch contrasting the two follows the formula below):
$$
h_{i}^{(l+1)}=\sigma\left(\frac{1}{K} \sum_{k=1}^{K} \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{k} W^{k} h_{j}^{(l)}\right)
$$
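The difference between the two combination modes boils down to a reshape versus a mean over the head dimension. The sketch below contrasts them for a single node, assuming hypothetical per-head outputs and omitting the nonlinearity $\sigma$:

```python
import torch

K, F_out = 8, 4                      # hypothetical head count and per-head size
head_outputs = torch.rand(K, F_out)  # per-head aggregation results for one node

h_hidden = head_outputs.reshape(-1)  # hidden layers: concatenate -> K * F_out dims
h_final = head_outputs.mean(dim=0)   # final layer: average -> F_out dims
print(h_hidden.shape, h_final.shape) # torch.Size([32]) torch.Size([4])
```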
3. Reproduction
3.1 Reproducing the GAT Model
This article reproduces the GAT model with PyG. Readers who are not familiar with how to implement a message-passing neural network in PyG can refer to the blogger's earlier post 《PyG Tutorial (6): Custom Message-Passing Networks》.
The GAT model stacks two graph attention convolution layers (GATConv), with an ELU nonlinearity between the two convolutions. The model source code is as follows:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax, add_remaining_self_loops


class GATConv(MessagePassing):
    def __init__(self, in_feats, out_feats, alpha, drop_prob, num_heads):
        super().__init__(aggr="add")
        self.drop_prob = drop_prob
        self.num_heads = num_heads
        # Split the output dimension evenly across the attention heads.
        self.out_feats = out_feats // num_heads
        # Shared linear transform W (all heads applied jointly).
        self.lin = nn.Linear(in_feats, self.out_feats * self.num_heads,
                             bias=False)
        # Single-layer feed-forward network a that scores each edge.
        self.a = nn.Linear(2 * self.out_feats, 1)
        self.leakyrelu = nn.LeakyReLU(alpha)

    def forward(self, x, edge_index):
        # Each node also attends to itself, so add self-loops first.
        edge_index, _ = add_remaining_self_loops(edge_index)
        # Wh
        h = self.lin(x)
        h_prime = self.propagate(edge_index, x=h)
        return h_prime

    def message(self, x_i, x_j, edge_index_i):
        # Reshape to [num_edges, num_heads, out_feats] for per-head scoring.
        x_i = x_i.view(-1, self.num_heads, self.out_feats)
        x_j = x_j.view(-1, self.num_heads, self.out_feats)
        # e_ij = LeakyReLU(a(Wh_i || Wh_j))
        e = self.leakyrelu(self.a(torch.cat([x_i, x_j], dim=-1)))
        # alpha_ij = softmax_j(e_ij), normalized over each target node's edges
        alpha = softmax(e, edge_index_i)
        alpha = F.dropout(alpha, self.drop_prob, self.training)
        # Weight each neighbor message and re-flatten the heads (concatenation).
        return (x_j * alpha).view(x_j.size(0), -1)


class GAT(nn.Module):
    """Two GATConv layers with an ELU nonlinearity in between."""

    def __init__(self, in_feats, hidden_feats, y_num,
                 alpha=0.2, drop_prob=0., num_heads=[1, 1]):
        super().__init__()
        self.drop_prob = drop_prob
        self.gatconv1 = GATConv(
            in_feats, hidden_feats, alpha, drop_prob, num_heads[0])
        self.gatconv2 = GATConv(
            hidden_feats, y_num, alpha, drop_prob, num_heads[1])

    def forward(self, x, edge_index):
        x = self.gatconv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, self.drop_prob, self.training)
        out = self.gatconv2(x, edge_index)
        return F.log_softmax(out, dim=1)


if __name__ == "__main__":
    conv = GATConv(in_feats=64, out_feats=64, alpha=0.2,
                   num_heads=8, drop_prob=0.2)
    x = torch.rand(4, 64)
    edge_index = torch.tensor(
        [[0, 1, 1, 2, 0, 2, 0, 3], [1, 0, 2, 1, 2, 0, 3, 0]], dtype=torch.long)
    x = conv(x, edge_index)
    print(x.shape)
```
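A note on the quick check under `if __name__ == "__main__":`: since `GATConv` splits `out_feats` evenly across the heads and concatenates them again in `message`, `out_feats` must be divisible by `num_heads`. With the settings above (64 output features, 8 heads, so 8 features per head), the printed shape is `torch.Size([4, 64])`.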
3.2 Experiments
3.2.1 Dataset
This article uses the Cora dataset as an example. Cora is a citation network of 2708 papers, each represented by a 1433-dimensional word vector. The network contains 5429 edges indicating citation relationships between papers, and the papers are divided into 7 classes.
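For reference, a minimal way to load Cora is through PyG's built-in `Planetoid` dataset (a sketch; the root path is an arbitrary local cache directory, and the edge count PyG reports differs from the paper's 5429 because `edge_index` stores each undirected edge in both directions):

```python
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root="data/Cora", name="Cora")
data = dataset[0]
print(dataset.num_features, dataset.num_classes)  # 1433 7
print(data.num_nodes)                             # 2708
print(data.num_edges)  # directed edge count; each undirected edge appears twice
```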
3.2.2 Hyperparameter Configuration
The hyperparameters for this experiment are taken from the GAT paper, as shown in the table below:
| Parameter | Value |
|---|---|
| dropout rate | 0.6 |
| weight decay | 5e-4 |
| learning rate | 0.01 |
| hidden size | 64 |
| num_heads | [8, 1], i.e. 8 attention heads in the first convolution layer and 1 in the second |
| epochs | 300 |
3.2.3 Experimental Results
During the experiment, the training set is used to update the model parameters, the validation set is used to select the best model, and the selected model is finally evaluated on the test set; a minimal sketch of this loop is given below. A screenshot of one run follows the sketch:
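This is a sketch consistent with the procedure just described, not the exact project code: it assumes the `GAT` class defined in Section 3.1, the standard Planetoid split of Cora, and the hyperparameters from the table in Section 3.2.2.

```python
import copy
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid

dataset = Planetoid(root="data/Cora", name="Cora")
data = dataset[0]

# Hyperparameters from the table in Section 3.2.2
model = GAT(dataset.num_features, hidden_feats=64, y_num=dataset.num_classes,
            drop_prob=0.6, num_heads=[8, 1])
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

@torch.no_grad()
def accuracy(mask):
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=1)
    return (pred[mask] == data.y[mask]).float().mean().item()

best_val, best_state = 0.0, None
for epoch in range(300):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)  # log-probabilities
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    # Keep the checkpoint that does best on the validation set
    val_acc = accuracy(data.val_mask)
    if val_acc > best_val:
        best_val, best_state = val_acc, copy.deepcopy(model.state_dict())

model.load_state_dict(best_state)
print(f"val acc: {best_val:.4f}, test acc: {accuracy(data.test_mask):.4f}")
```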

The final result is close to the accuracy reported for this dataset in the paper. Due to time constraints, no detailed hyperparameter tuning or visualization work was done; interested readers can explore these on their own.
4. Conclusion
The complete project is available on GitHub: GAT
That's all for this article. If you found it helpful, please like the post or follow the blogger; your support is what keeps the blogger creating. If anything is wrong, criticism and corrections are welcome!