PyG: Building the Heterogeneous Graph Attention Network (HAN) for DBLP Node Classification
2022-07-28 19:14:00 【Cyril_KI】
Preface
The principle of HAN is described in: WWW 2019 | HAN: Heterogeneous Graph Attention Network.
Data processing
Import the data:
import os
import torch
from torch_geometric.datasets import DBLP

path = os.path.join(os.path.abspath(os.path.dirname(os.getcwd())), 'data', 'DBLP')
dataset = DBLP(path)
graph = dataset[0]
print(graph)
The output is as follows:
HeteroData(
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057]
  },
  paper={ x=[14328, 4231] },
  term={ x=[7723, 50] },
  conference={ num_nodes=20 },
  (author, to, paper)={ edge_index=[2, 19645] },
  (paper, to, author)={ edge_index=[2, 19645] },
  (paper, to, term)={ edge_index=[2, 85810] },
  (paper, to, conference)={ edge_index=[2, 14328] },
  (term, to, paper)={ edge_index=[2, 85810] },
  (conference, to, paper)={ edge_index=[2, 14328] }
)
As the output shows, the DBLP dataset contains four node types: author, paper, term, and conference. It includes 14,328 papers, 4,057 authors, 20 conferences, and 7,723 terms. Each author is labeled with one of four research areas: database, data mining, machine learning, and information retrieval.
Task: classify the author nodes into these 4 classes.
Because the conference nodes have no features, we need to assign them preset features:
graph['conference'].x = torch.ones((graph['conference'].num_nodes, 1))
That is, the feature of every conference node is initialized to [1].
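A quick check confirms the shape:
print(graph['conference'].x.shape)  # torch.Size([20, 1])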
Next, extract some frequently used data:
num_classes = torch.max(graph['author'].y).item() + 1
train_mask, val_mask, test_mask = graph['author'].train_mask, graph['author'].val_mask, graph['author'].test_mask
y = graph['author'].y
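A brief sanity check of these values (the printed shapes follow directly from the HeteroData summary above):
print(num_classes)                # 4
print(y.shape, train_mask.shape)  # torch.Size([4057]) torch.Size([4057])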
Model construction
First, import the layer:
from torch_geometric.nn import HANConv
The parameters of HANConv:
- in_channels: the input feature dimension, i.e. the number of features per node. Because different node types have different feature dimensions, it is usually set to -1 so that PyG infers it automatically.
- out_channels: the output feature dimension. For node classification, the out_channels of the last HANConv layer equals the number of node classes.
- heads: the number of heads in the multi-head attention mechanism. Note that, unlike GATConv, HANConv reshapes (flattens) the multi-head result into a single out_channels-dimensional vector instead of concatenating the heads, so the output dimension stays out_channels regardless of heads (see the short check after this list).
- negative_slope: the negative slope of the LeakyReLU activation.
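As a check of the last point, a minimal sketch (assuming the graph loaded in the Data processing step, on CPU):
# HANConv's output dimension is out_channels, independent of heads
conv = HANConv(-1, 64, graph.metadata(), heads=8)
out = conv(graph.x_dict, graph.edge_index_dict)
print(out['author'].shape)  # torch.Size([4057, 64]), not [4057, 64 * 8]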
The model is then built as follows:
import torch.nn as nn
import torch.nn.functional as F

class HAN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(HAN, self).__init__()
        self.conv1 = HANConv(in_channels, hidden_channels, graph.metadata(), heads=8)
        self.conv2 = HANConv(hidden_channels, out_channels, graph.metadata(), heads=4)

    def forward(self, data):
        x_dict, edge_index_dict = data.x_dict, data.edge_index_dict
        x = self.conv1(x_dict, edge_index_dict)
        x = self.conv2(x, edge_index_dict)
        # return per-class probabilities for the author nodes only
        x = F.softmax(x['author'], dim=1)
        return x
Instantiate the model and print it:
model = HAN(-1, 64, num_classes).to(device)
HAN(
(conv1): HANConv(64, heads=8)
(conv2): HANConv(4, heads=4)
)
1. Forward propagation
Checking the official documentation for HANConv's input and output requirements, we can see that HANConv takes the node feature dictionary x_dict and the edge index dictionary edge_index_dict as input. Hence, in the forward pass:
x_dict, edge_index_dict = data.x_dict, data.edge_index_dict
x = self.conv1(x_dict, edge_index_dict)
At this point we can print x['author'] and its size:
tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0969, 0.0601, 0.0000, ..., 0.0000, 0.0000, 0.0251],
[0.0000, 0.0000, 0.0000, ..., 0.1288, 0.0000, 0.0602],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0096, 0.0000, 0.0240],
[0.0000, 0.0000, 0.0000, ..., 0.0096, 0.0000, 0.0240],
[0.0801, 0.0558, 0.0837, ..., 0.0277, 0.0347, 0.0000]],
device='cuda:0', grad_fn=<SumBackward1>)
torch.Size([4057, 64])
x now has 4,057 rows in total; each row is the state vector of one author node after the first convolution.
By the same reasoning, since:
x = self.conv2(x, edge_index_dict)
the size of x['author'] after the second convolution should be:
torch.Size([4057, 4])
That is, each author node now has a 4-dimensional state vector.
Since this is a 4-class classification task, we finally apply a softmax:
x = F.softmax(x, dim=1)
dim=1 means the softmax is applied across each row, so every row sums to 1 and represents the probability of the node belonging to each class. Printing x at this point:
tensor([[0.2591, 0.2539, 0.2435, 0.2435],
[0.3747, 0.2067, 0.2029, 0.2157],
[0.2986, 0.2338, 0.2338, 0.2338],
...,
[0.2740, 0.2453, 0.2403, 0.2403],
[0.2740, 0.2453, 0.2403, 0.2403],
[0.3414, 0.2195, 0.2195, 0.2195]], device='cuda:0',
grad_fn=<SoftmaxBackward0>)
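Given these per-row probabilities, the predicted class of each author node is simply the column index of the largest entry, the same argmax used later in the test function:
_, pred = x.max(dim=1)  # pred[i] is the predicted class (0-3) of author node i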
2. Back propagation
During training, we first compute the output with a forward pass:
f = model(graph)
f holds the 4 class probabilities of every author node. In practice, we only compute the loss over the training set, so the loss function is written as:
loss = loss_function(f[train_mask], y[train_mask])
Then compute the gradients and update the parameters by backpropagation.
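Concretely, one update step uses the standard PyTorch pattern, the same three calls that appear in the training loop below:
optimizer.zero_grad()  # clear gradients from the previous step
loss.backward()        # backpropagate to compute new gradients
optimizer.step()       # update the model parameters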
3. Training
During training, we return the test accuracy of the model that performs best on the validation set:
def train():
    model = HAN(-1, 64, num_classes).to(device)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
    loss_function = torch.nn.CrossEntropyLoss().to(device)
    min_epochs = 5
    best_val_acc = 0
    final_best_acc = 0
    for epoch in range(100):
        model.train()  # test() switches the model to eval mode, so switch back each epoch
        f = model(graph)
        loss = loss_function(f[train_mask], y[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # validation
        val_acc, val_loss = test(model, val_mask)
        test_acc, test_loss = test(model, test_mask)
        if epoch + 1 > min_epochs and val_acc > best_val_acc:
            best_val_acc = val_acc
            final_best_acc = test_acc
        print('Epoch {:3d} train_loss {:.5f} val_acc {:.3f} test_acc {:.3f}'
              .format(epoch, loss.item(), val_acc, test_acc))
    return final_best_acc
4. Testing
def test(model, mask):
    model.eval()
    with torch.no_grad():
        out = model(graph)
        loss_function = torch.nn.CrossEntropyLoss().to(device)
        loss = loss_function(out[mask], y[mask])
    _, pred = out.max(dim=1)
    correct = int(pred[mask].eq(y[mask]).sum().item())
    acc = correct / int(mask.sum())  # accuracy on the nodes selected by this mask
    return acc, loss.item()
Experimental results
On the DBLP dataset, after 100 epochs of training, the classification accuracy is 78.54%:
HAN Accuracy: 0.7853853239177156
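For completeness, a minimal driver that produces the line above. This is a sketch: device is never defined in the original snippets, so it is set here, and the graph (together with the labels and masks extracted from it) is moved to that device before training:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
graph = graph.to(device)
train_mask, val_mask, test_mask = graph['author'].train_mask, graph['author'].val_mask, graph['author'].test_mask
y = graph['author'].y

final_best_acc = train()
print('HAN Accuracy:', final_best_acc)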