PyG builds a heterogeneous graph attention network (HAN) for DBLP node classification
2022-07-28 19:14:00 【Cyril_KI】
Preface
The principle of HAN is described in: WWW 2019 | HAN: Heterogeneous Graph Attention Network.
Data processing
Import the data:
import os

import torch
from torch_geometric.datasets import DBLP

path = os.path.join(os.path.dirname(os.getcwd()), 'data', 'DBLP')
dataset = DBLP(path)
graph = dataset[0]
print(graph)
The output is as follows:
HeteroData(
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  paper={ x=[14328, 4231] },
  term={ x=[7723, 50] },
  conference={ num_nodes=20 },
  (author, to, paper)={ edge_index=[2, 19645] },
  (paper, to, author)={ edge_index=[2, 19645] },
  (paper, to, term)={ edge_index=[2, 85810] },
  (paper, to, conference)={ edge_index=[2, 14328] },
  (term, to, paper)={ edge_index=[2, 85810] },
  (conference, to, paper)={ edge_index=[2, 14328] }
)
As can be seen, the DBLP dataset contains four types of nodes: author, paper, term and conference. It includes 14328 papers, 4057 authors, 20 conferences and 7723 terms. Each author belongs to one of four research areas: database, data mining, machine learning and information retrieval.
Task: classify the author nodes into 4 classes.
Since the conference nodes have no features, we need to assign them features manually:
graph['conference'].x = torch.ones((graph['conference'].num_nodes, 1))
This initializes the feature of every conference node to [1].
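A quick sanity check that every node type now has a feature matrix (an illustrative snippet, assuming the assignment above has been run):
# every node type should now expose a feature matrix x
for node_type in graph.node_types:
    print(node_type, graph[node_type].x.shape)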
Extract a few useful quantities:
num_classes = torch.max(graph['author'].y).item() + 1
train_mask, val_mask, test_mask = graph['author'].train_mask, graph['author'].val_mask, graph['author'].test_mask
y = graph['author'].y
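As a quick illustrative check, num_classes here evaluates to 4, matching the four research areas:
print(num_classes)  # 4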
Model construction
First, import the required layer:
from torch_geometric.nn import HANConv
The model parameters:
- in_channels: the input channel size, i.e. the number of features per node; it is usually set to -1 so that PyG infers it lazily from the data.
- out_channels: the output channel size; for node classification, the output channel of the last HANConv layer equals the number of node classes.
- heads: the number of heads in the multi-head attention mechanism. Note that, unlike GATConv, HANConv merges the multi-head results internally rather than concatenating them into a larger vector, so its output dimension is always out_channels and does not grow with the number of heads.
- negative_slope: the negative slope of LeakyReLU.
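Besides these parameters, HANConv also needs the graph's metadata, a (node_types, edge_types) tuple telling it which node and edge types exist; the graph.metadata() call used in the model below supplies it and can be inspected directly:
print(graph.metadata())
# (['author', 'paper', 'term', 'conference'], [('author', 'to', 'paper'), ...])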
The model is then built as follows:
import torch.nn as nn
import torch.nn.functional as F

class HAN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(HAN, self).__init__()
        self.conv1 = HANConv(in_channels, hidden_channels, graph.metadata(), heads=8)
        self.conv2 = HANConv(hidden_channels, out_channels, graph.metadata(), heads=4)

    def forward(self, data):
        x_dict, edge_index_dict = data.x_dict, data.edge_index_dict
        x = self.conv1(x_dict, edge_index_dict)
        x = self.conv2(x, edge_index_dict)
        x = F.softmax(x['author'], dim=1)
        return x
Instantiate and print the model:
model = HAN(-1, 64, num_classes).to(device)
print(model)
HAN(
(conv1): HANConv(64, heads=8)
(conv2): HANConv(4, heads=4)
)
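Here device is assumed to have been defined beforehand; a minimal sketch:
import torch

# pick the GPU if one is available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
graph = graph.to(device)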
1. Forward propagation
Checking the official documentation for HANConv's input and output requirements, we can see that HANConv takes the node feature dictionary x_dict and the edge index dictionary edge_index_dict as input.
So we have:
x_dict, edge_index_dict = data.x_dict, data.edge_index_dict
x = self.conv1(x_dict, edge_index_dict)
At this point, let us print x['author'] and its size:
tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0969, 0.0601, 0.0000, ..., 0.0000, 0.0000, 0.0251],
[0.0000, 0.0000, 0.0000, ..., 0.1288, 0.0000, 0.0602],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0096, 0.0000, 0.0240],
[0.0000, 0.0000, 0.0000, ..., 0.0096, 0.0000, 0.0240],
[0.0801, 0.0558, 0.0837, ..., 0.0277, 0.0347, 0.0000]],
device='cuda:0', grad_fn=<SumBackward1>)
torch.Size([4057, 64])
x['author'] has 4057 rows; each row is the state vector of an author node after the first convolution layer.
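To see what the first layer produces for the other node types as well, we can iterate over the returned dictionary (an illustrative snippet; per the PyG docs, HANConv may return None for a node type that receives no messages):
for node_type, h in x.items():
    print(node_type, None if h is None else h.shape)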
Similarly, since:
x = self.conv2(x, edge_index_dict)
the size of x['author'] after the second convolution should be:
torch.Size([4057, 4])
That is, each author node now has a 4-dimensional state vector.
Since this is a 4-class classification task, we finally apply a softmax:
x = F.softmax(x, dim=1)
dim=1 means the softmax is applied row-wise, so each row sums to 1 and gives the probability of the node belonging to each class. Printing x at this point:
tensor([[0.2591, 0.2539, 0.2435, 0.2435],
[0.3747, 0.2067, 0.2029, 0.2157],
[0.2986, 0.2338, 0.2338, 0.2338],
...,
[0.2740, 0.2453, 0.2403, 0.2403],
[0.2740, 0.2453, 0.2403, 0.2403],
[0.3414, 0.2195, 0.2195, 0.2195]], device='cuda:0',
grad_fn=<SoftmaxBackward0>)
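As a quick illustrative check, every row should now sum to 1:
print(x.sum(dim=1))  # all entries approximately 1.0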
2. Back propagation
During training, we first compute the output with a forward pass:
f = model(graph)
f holds the 4 class probabilities of every author node. In practice, we only need the loss on the training set, so the loss is computed as:
loss = loss_function(f[train_mask], y[train_mask])
Then we compute the gradients and update the parameters via backpropagation.
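In code, this is the standard PyTorch update step (a minimal sketch; loss_function and optimizer are defined in the training function below):
optimizer.zero_grad()  # clear old gradients
loss.backward()        # backpropagate through the HAN layers
optimizer.step()       # update the parameters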
3. Training
During training, we record the test accuracy at the epoch with the best validation accuracy and return it:
def train():
    model = HAN(-1, 64, num_classes).to(device)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
    loss_function = torch.nn.CrossEntropyLoss().to(device)
    min_epochs = 5
    best_val_acc = 0
    final_best_acc = 0
    for epoch in range(100):
        model.train()  # test() switches to eval mode, so switch back every epoch
        f = model(graph)
        loss = loss_function(f[train_mask], y[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # validation
        val_acc, val_loss = test(model, val_mask)
        test_acc, test_loss = test(model, test_mask)
        if epoch + 1 > min_epochs and val_acc > best_val_acc:
            best_val_acc = val_acc
            final_best_acc = test_acc
        print('Epoch {:3d} train_loss {:.5f} val_acc {:.3f} test_acc {:.3f}'
              .format(epoch, loss.item(), val_acc, test_acc))
    return final_best_acc
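Training can then be launched like this (a minimal sketch, once the test function below has been defined):
final_best_acc = train()
print('HAN Accuracy:', final_best_acc)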
4. Testing
def test(model, mask):
    model.eval()
    with torch.no_grad():
        out = model(graph)
        loss_function = torch.nn.CrossEntropyLoss().to(device)
        loss = loss_function(out[mask], y[mask])
        _, pred = out.max(dim=1)
        correct = int(pred[mask].eq(y[mask]).sum().item())
        acc = correct / int(mask.sum())  # use the given mask, not test_mask
    return acc, loss.item()
Experimental results
On the DBLP network, after training for 100 epochs, the classification accuracy reaches 78.54%:
HAN Accuracy: 0.7853853239177156