GNN Hands-On Practice (II): Reproducing the Graph Attention Network (GAT)
2022-06-30 10:09:00 【Si Xi is towering】
Reference paper: Graph Attention Networks
1. Preface
GAT (Graph Attention Network) is an important SOTA model among GNNs. It is defined from the spatial-domain perspective and can be explained with the message-passing paradigm. The biggest difference between GAT and GCN is that GAT introduces an attention mechanism into neighborhood aggregation, which computes how important each neighbor is to the node currently being aggregated. This article covers two things: an introduction to the GAT architecture, and a reproduction of the GAT model based on PyG.
2. GAT Architecture
As noted in Section 1, GAT's main contribution is introducing the attention mechanism into graph convolution. The figure below shows the architecture of the model:
[Figure: GAT model architecture]

As the figure shows, during aggregation GAT first computes the importance of each 1-hop neighbor to the current node, namely the attention coefficient $\vec{\alpha}_{ij}$ (written $\alpha_{ij}$ in the formulas below), and then takes a weighted sum of the neighbor features. The message-passing paradigm corresponding to the model has the following mathematical form:
$$
\begin{aligned}
h_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{(l)} W^{(l)} h_j^{(l)} \\
\alpha_{ij}^{(l)} &= \operatorname{softmax}_{j}\left(e_{ij}^{(l)}\right) = \frac{\exp\left(e_{ij}^{(l)}\right)}{\sum_{k \in \mathcal{N}(i)} \exp\left(e_{ik}^{(l)}\right)} \\
e_{ij}^{(l)} &= \mathrm{LeakyReLU}\left(a\left(W^{(l)} h_i^{(l)} \,\|\, W^{(l)} h_j^{(l)}\right)\right)
\end{aligned}
$$
where $h_i^{(l)}$ and $h_j^{(l)}$ are the node features at layer $l$ of the GAT model, $a$ is a single-layer feedforward neural network, $\|$ denotes vector concatenation, $W$ is the weight matrix, and $\mathcal{N}(i)$ is the first-order (1-hop) neighborhood of node $i$.
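To make the three formulas concrete, here is a minimal plain-Python sketch of a single-head update for one node. It is an illustration under stated assumptions, not the model code: the scorer $a$ is modeled as a bias-free dot product with a vector `a`, the neighborhood is assumed to contain the node itself, and the names `gat_aggregate`, `matvec`, and the toy values are hypothetical.

```python
import math


def leaky_relu(x, negative_slope=0.2):
    return x if x > 0 else negative_slope * x


def matvec(W, h):
    """Multiply matrix W (a list of rows) by vector h."""
    return [sum(w * x for w, x in zip(row, h)) for row in W]


def gat_aggregate(i, neighbors, H, W, a):
    """Single-head GAT update for node i, without an output nonlinearity.

    H maps node id -> feature vector, W is the weight matrix, and a is the
    attention vector (assumption: a bias-free linear scorer).
    neighbors is N(i), assumed here to include i itself and to be duplicate-free.
    """
    Wh = {n: matvec(W, H[n]) for n in set(neighbors) | {i}}
    # e_ij = LeakyReLU(a . [Wh_i || Wh_j])
    e = {j: leaky_relu(sum(p * q for p, q in zip(a, Wh[i] + Wh[j])))
         for j in neighbors}
    # alpha_ij = softmax over the neighbors of i (max subtracted for stability)
    e_max = max(e.values())
    exp_e = {j: math.exp(v - e_max) for j, v in e.items()}
    total = sum(exp_e.values())
    alpha = {j: v / total for j, v in exp_e.items()}
    # h_i' = sum_j alpha_ij * W h_j
    h_new = [sum(alpha[j] * Wh[j][d] for j in neighbors) for d in range(len(W))]
    return alpha, h_new


# Toy graph: node 0 with neighborhood {0, 1, 2}, 2-dimensional features
H = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}
W = [[1.0, 0.0], [0.0, 1.0]]   # identity, so Wh_j == h_j
a = [0.5, -0.5, 1.0, 0.25]     # arbitrary illustrative attention vector
alpha, h_new = gat_aggregate(0, [0, 1, 2], H, W, a)
```

With these toy values the raw scores are 1.5, 0.75, and 1.75 for neighbors 0, 1, and 2, so node 2 receives the largest attention weight and node 1 the smallest.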
In addition, the authors use a multi-head attention mechanism, so the aggregation formula above extends to the following form:
$$
h_i^{(l+1)} = \Big\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{k} \mathbf{W}^{k} h_j^{(l)}\right)
$$
where $K$ is the number of attention heads.
Note that if multi-head attention is used in the last layer, the head outputs are averaged instead of concatenated, namely:
$$
h_i^{(l+1)} = \sigma\left(\frac{1}{K} \sum_{k=1}^{K} \sum_{j \in \mathcal{N}(i)} \alpha_{ij}^{k} \mathbf{W}^{k} h_j^{(l)}\right)
$$
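The difference between the two multi-head formulas is mostly a matter of output width: concatenation yields $K \cdot F'$ features per node, averaging yields $F'$. The following plain-Python sketch shows this for one node; the helper names `concat_heads` and `average_heads` are hypothetical, and the activation $\sigma$ defaults to the identity to keep the numbers easy to follow.

```python
def concat_heads(head_outputs, act=lambda v: v):
    """Hidden layers: apply act to each head's output, then concatenate."""
    return [act(v) for head in head_outputs for v in head]


def average_heads(head_outputs, act=lambda v: v):
    """Last layer: average the K heads element-wise, then apply act."""
    num_heads = len(head_outputs)
    return [act(sum(vals) / num_heads) for vals in zip(*head_outputs)]


# K = 3 heads, each producing F' = 2 features for the same node
heads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
hidden = concat_heads(heads)   # length K * F' = 6
final = average_heads(heads)   # length F' = 2
```

Here `hidden` is `[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]` and `final` is `[3.0, 4.0]`. With a single head, as in the second layer of the model reproduced in this article, the two variants coincide.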
3. Reproduction
3.1 Reproducing the GAT Model
This article reproduces the GAT model with PyG. Readers who are not familiar with implementing message-passing neural networks in PyG can refer to the blogger's earlier post, "PyG Tutorial (6): Custom Message-Passing Networks".
The GAT model here consists of two graph attention convolution layers, with ELU as the nonlinear activation between them. The source code of the model is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax, add_remaining_self_loops


class GATConv(MessagePassing):
    """Graph attention layer; out_feats is the total width across all heads."""

    def __init__(self, in_feats, out_feats, alpha, drop_prob, num_heads):
        super().__init__(aggr="add")
        self.drop_prob = drop_prob
        self.num_heads = num_heads
        # Per-head output width
        self.out_feats = out_feats // num_heads
        self.lin = nn.Linear(in_feats, self.out_feats * self.num_heads, bias=False)
        # Attention scorer a; note that it is shared by all heads here
        self.a = nn.Linear(2 * self.out_feats, 1)
        self.leakrelu = nn.LeakyReLU(alpha)

    def forward(self, x, edge_index):
        # Add missing self-loops so every node also attends to itself;
        # pass num_nodes explicitly so isolated nodes are covered too
        edge_index, _ = add_remaining_self_loops(edge_index, num_nodes=x.size(0))
        # Wh
        h = self.lin(x)
        h_prime = self.propagate(edge_index, x=h)
        return h_prime

    def message(self, x_i, x_j, edge_index_i):
        # Split features into heads: [E, num_heads, out_feats]
        x_i = x_i.view(-1, self.num_heads, self.out_feats)
        x_j = x_j.view(-1, self.num_heads, self.out_feats)
        # a(Wh_i || Wh_j) -> [E, num_heads, 1]
        e = self.a(torch.cat([x_i, x_j], dim=-1))
        # LeakyReLU(a(Wh_i || Wh_j))
        e = self.leakrelu(e)
        # softmax(e_{ij}) over the incoming edges of each target node i
        alpha = softmax(e, edge_index_i)
        alpha = F.dropout(alpha, self.drop_prob, self.training)
        # Weight neighbor features, then flatten heads to [E, num_heads * out_feats]
        return (x_j * alpha).view(x_j.size(0), -1)


class GAT(nn.Module):
    """Two GATConv layers with ELU and dropout in between."""

    def __init__(self, in_feats, hidden_feats, y_num,
                 alpha=0.2, drop_prob=0., num_heads=(1, 1)):
        super().__init__()
        self.drop_prob = drop_prob
        self.gatconv1 = GATConv(
            in_feats, hidden_feats, alpha, drop_prob, num_heads[0])
        self.gatconv2 = GATConv(
            hidden_feats, y_num, alpha, drop_prob, num_heads[1])

    def forward(self, x, edge_index):
        x = self.gatconv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, self.drop_prob, self.training)
        out = self.gatconv2(x, edge_index)
        return F.log_softmax(out, dim=1)


if __name__ == "__main__":
    # Smoke test: 4 nodes with 64 features each, 8 directed edges
    conv = GATConv(in_feats=64, out_feats=64, alpha=0.2,
                   num_heads=8, drop_prob=0.2)
    x = torch.rand(4, 64)
    edge_index = torch.tensor(
        [[0, 1, 1, 2, 0, 2, 0, 3], [1, 0, 2, 1, 2, 0, 3, 0]], dtype=torch.long)
    x = conv(x, edge_index)
    print(x.shape)
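In `message`, the call `softmax(e, edge_index_i)` normalizes the scores separately for each target node, so the weights of all edges pointing into the same node sum to 1. The plain-Python sketch below illustrates that grouped normalization for scalar scores; it is only an illustration of the idea, not PyG's implementation, and the name `segment_softmax` is hypothetical.

```python
import math
from collections import defaultdict


def segment_softmax(scores, index):
    """Softmax over groups of entries that share the same index value."""
    # Per-group maximum, subtracted for numerical stability
    group_max = defaultdict(lambda: -math.inf)
    for s, i in zip(scores, index):
        group_max[i] = max(group_max[i], s)
    exp_scores = [math.exp(s - group_max[i]) for s, i in zip(scores, index)]
    # Per-group normalizer
    group_sum = defaultdict(float)
    for v, i in zip(exp_scores, index):
        group_sum[i] += v
    return [v / group_sum[i] for v, i in zip(exp_scores, index)]


# Four edges: the first two point into node 0, the last two into node 1
scores = [0.0, 0.0, 1.0, 2.0]
targets = [0, 0, 1, 1]
alpha = segment_softmax(scores, targets)
```

The two edges into node 0 have equal scores and therefore each get weight 0.5, while the two edges into node 1 share a separate budget of 1.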
3.2 Experiments
3.2.1 Dataset
This article uses the Cora dataset as an example. Cora is a paper citation network containing 2708 papers, each represented by a 1433-dimensional word vector. The citation network contains 5429 edges, each indicating a citation relationship between two papers. The papers in the dataset are divided into 7 categories.
3.2.2 Hyperparameter Configuration
The hyperparameters of this experiment come from the GAT paper. The details are shown in the table below:
| Parameter | Value |
|---|---|
| dropout rate | 0.6 |
| weight_decay | 5e-4 |
| learning rate | 0.01 |
| hidden size | 64 |
| num_heads | [8, 1]: the first convolution layer has 8 attention heads, the second layer has 1 |
| epochs | 300 |
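One way to keep these values in one place is a small frozen dataclass, sketched below. The class name `GATConfig` and its field names are hypothetical and not part of the original project; the values are copied from the table. The sketch also shows how the hidden size relates to the head count, since `GATConv` above divides `out_feats` by `num_heads`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class GATConfig:
    """Hyperparameters from the table above (names are illustrative)."""
    dropout: float = 0.6
    weight_decay: float = 5e-4
    learning_rate: float = 0.01
    hidden_size: int = 64
    num_heads: tuple = (8, 1)
    epochs: int = 300


cfg = GATConfig()
# Width of each attention head in the first layer: 64 // 8 = 8
per_head_hidden = cfg.hidden_size // cfg.num_heads[0]
```

With a hidden size of 64 and 8 heads, each head in the first layer produces 8 features, and their concatenation restores the 64-dimensional hidden representation.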
3.2.3 Experimental Results
During the experiment, the training set is used to update the model parameters, the validation set is used to select the best model, and the best model is finally evaluated on the test set. A screenshot of the result of one run is shown below:
[Figure: screenshot of the run results]
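The train / validate / test procedure described above can be sketched generically as follows. This is a minimal illustration, not the project's training script: `train_step` and `evaluate` are hypothetical callbacks standing in for one optimization pass over the training set and an accuracy computation on a given split, and the toy callbacks at the bottom exist only to make the sketch runnable.

```python
import copy


def train_with_model_selection(state, train_step, evaluate, epochs):
    """Train for `epochs` and keep the state with the best validation score.

    train_step(state) -> updated state (uses the training set only)
    evaluate(state, split) -> accuracy on the "val" or "test" split
    """
    best_val, best_state = float("-inf"), copy.deepcopy(state)
    for _ in range(epochs):
        state = train_step(state)
        val_acc = evaluate(state, "val")
        if val_acc > best_val:
            # Snapshot the model whenever validation accuracy improves
            best_val, best_state = val_acc, copy.deepcopy(state)
    # The test set is touched once, with the selected model only
    return best_state, best_val, evaluate(best_state, "test")


# Toy stand-ins: the "model" is an epoch counter whose accuracy peaks at 3
def toy_train_step(state):
    return state + 1


def toy_evaluate(state, split):
    base = 1.0 if split == "val" else 0.8
    return base - 0.1 * abs(state - 3)


best_state, best_val, test_acc = train_with_model_selection(
    0, toy_train_step, toy_evaluate, epochs=5)
```

In the toy run the validation score peaks at the third epoch, so that snapshot is the one evaluated on the test split even though training continues for two more epochs.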

Judging from the final result, the accuracy is close to what the paper reports for this dataset. Due to time constraints, no detailed hyperparameter tuning or visualization work was done; interested readers can explore these on their own.
4. Conclusion
The complete project is available on GitHub: GAT
That is all for this article. If you found it helpful, please like it or follow the blogger; your support is what motivates the blogger to keep writing. And of course, if you find any problems, corrections are welcome!