Building a Heterogeneous Graph Attention Network (HAN) with PyG for Node Classification on DBLP
2022-07-28 17:08:00 [Cyril_KI]
Preface
For the principles behind HAN, see: WWW 2019 | HAN: Heterogeneous Graph Attention Network.
Data Processing
Load the dataset:
import os
import torch
from torch_geometric.datasets import DBLP

path = os.path.join(os.path.dirname(os.getcwd()), 'data', 'DBLP')
dataset = DBLP(path)
graph = dataset[0]
print(graph)
The output is as follows:
HeteroData(
author={
x=[4057, 334],
y=[4057],
train_mask=[4057],
val_mask=[4057],
test_mask=[4057]
},
paper={
x=[14328, 4231] },
term={
x=[7723, 50] },
conference={
num_nodes=20 },
(author, to, paper)={
edge_index=[2, 19645] },
(paper, to, author)={
edge_index=[2, 19645] },
(paper, to, term)={
edge_index=[2, 85810] },
(paper, to, conference)={
edge_index=[2, 14328] },
(term, to, paper)={
edge_index=[2, 85810] },
(conference, to, paper)={
edge_index=[2, 14328] }
)
As the output shows, the DBLP dataset contains four node types: author, paper, term, and conference. It has 14328 papers, 4057 authors, 20 conferences, and 7723 terms. Authors fall into four research areas: database, data mining, machine learning, and information retrieval.
Task: classify the author nodes into these 4 classes.
Since conference nodes have no features, we must assign them features beforehand:
graph['conference'].x = torch.ones((graph['conference'].num_nodes, 1))
Every conference node's feature is initialized to [1].
Next, extract some useful quantities:
num_classes = torch.max(graph['author'].y).item() + 1
train_mask, val_mask, test_mask = graph['author'].train_mask, graph['author'].val_mask, graph['author'].test_mask
y = graph['author'].y
Model Construction
First, import the layer:
from torch_geometric.nn import HANConv
Model parameters:
- in_channels: the input size, i.e. the number of features per node; for heterogeneous graphs this is usually set to -1 so it is inferred from the data for each node type.
- out_channels: the output size; for node classification, the last HANConv layer's output size is the number of classes.
- heads: the number of attention heads. Note that HANConv differs from GATConv here: instead of concatenating the heads into a heads * out_channels vector, HANConv gives each head out_channels // heads features, so the final output size is still out_channels (out_channels must therefore be divisible by heads).
- negative_slope: the slope parameter of LeakyReLU.
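The head/dimension relationship described above can be sketched in plain Python (the numbers match this post's first layer; this is an illustrative sketch of the dimension arithmetic, not PyG's actual implementation):

```python
def han_head_split(out_channels, heads):
    """Mimic how HANConv divides out_channels across attention heads:
    each head produces out_channels // heads features, and concatenating
    all heads restores exactly out_channels (unlike GATConv with
    concat=True, whose output grows to heads * out_channels)."""
    assert out_channels % heads == 0, "out_channels must be divisible by heads"
    per_head = out_channels // heads
    # Simulate one node: each head emits per_head features, then concatenate.
    head_outputs = [[0.0] * per_head for _ in range(heads)]
    concatenated = [v for head in head_outputs for v in head]
    return per_head, len(concatenated)

# First layer of this post: hidden_channels=64, heads=8 -> 8 features per head.
per_head, total = han_head_split(64, 8)
print(per_head, total)  # -> 8 64
```

This is why the constructor below can pass hidden_channels=64 with heads=8 and still hand a 64-dimensional vector to the next layer.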
The model is then built as follows:
class HAN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(HAN, self).__init__()
        # Each head produces out_channels // heads features,
        # concatenated back to out_channels.
        self.conv1 = HANConv(in_channels, hidden_channels, graph.metadata(), heads=8)
        self.conv2 = HANConv(hidden_channels, out_channels, graph.metadata(), heads=4)

    def forward(self, data):
        x_dict, edge_index_dict = data.x_dict, data.edge_index_dict
        x = self.conv1(x_dict, edge_index_dict)
        x = self.conv2(x, edge_index_dict)
        # Note: CrossEntropyLoss applies log-softmax internally and expects
        # raw logits, so this explicit softmax is strictly redundant for
        # training; it is kept here so the forward pass directly returns
        # class probabilities.
        x = F.softmax(x['author'], dim=1)
        return x
Print the model:
model = HAN(-1, 64, num_classes).to(device)
HAN(
(conv1): HANConv(64, heads=8)
(conv2): HANConv(4, heads=4)
)
1. Forward Pass
According to the official documentation, HANConv takes a node-feature dictionary x_dict and an edge-index dictionary edge_index_dict as input.
Therefore:
x_dict, edge_index_dict = data.x_dict, data.edge_index_dict
x = self.conv1(x_dict, edge_index_dict)
Let's print x['author'] and its size at this point:
tensor([[0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[0.0969, 0.0601, 0.0000, ..., 0.0000, 0.0000, 0.0251],
[0.0000, 0.0000, 0.0000, ..., 0.1288, 0.0000, 0.0602],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0096, 0.0000, 0.0240],
[0.0000, 0.0000, 0.0000, ..., 0.0096, 0.0000, 0.0240],
[0.0801, 0.0558, 0.0837, ..., 0.0277, 0.0347, 0.0000]],
device='cuda:0', grad_fn=<SumBackward1>)
torch.Size([4057, 64])
x['author'] now has 4057 rows; each row is the state vector of one author node after the first convolution layer.
Likewise, since:
x = self.conv2(x, edge_index_dict)
the x['author'] obtained after the second convolution has size:
torch.Size([4057, 4])
i.e. each author node now has a 4-dimensional state vector.
Since we need 4-way classification, a softmax is applied at the end:
x = F.softmax(x, dim=1)
dim=1 means the softmax runs over each row, so every row sums to 1 and gives the probability of the node belonging to each class. Printing x at this point:
tensor([[0.2591, 0.2539, 0.2435, 0.2435],
[0.3747, 0.2067, 0.2029, 0.2157],
[0.2986, 0.2338, 0.2338, 0.2338],
...,
[0.2740, 0.2453, 0.2403, 0.2403],
[0.2740, 0.2453, 0.2403, 0.2403],
[0.3414, 0.2195, 0.2195, 0.2195]], device='cuda:0',
grad_fn=<SoftmaxBackward0>)
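What F.softmax(x, dim=1) computes per row can be illustrated with a plain-Python sketch (made-up numbers, not the actual model output):

```python
import math

def softmax_rows(matrix):
    """Row-wise softmax: exponentiate each entry (shifted by the row max
    for numerical stability) and normalize so every row sums to 1."""
    result = []
    for row in matrix:
        m = max(row)
        exps = [math.exp(v - m) for v in row]
        s = sum(exps)
        result.append([e / s for e in exps])
    return result

probs = softmax_rows([[2.0, 1.0, 0.1, 0.1]])
print([round(p, 4) for p in probs[0]])
print(round(sum(probs[0]), 6))  # -> 1.0, i.e. each row sums to 1
```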
2. Backward Pass
During training, we first compute the output with a forward pass:
f = model(graph)
f holds the 4 class probabilities for every author node, but we only need the loss on the training set, so the loss is computed as:
loss = loss_function(f[train_mask], y[train_mask])
Then compute the gradients and update the parameters via backpropagation.
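The effect of indexing with train_mask can be sketched in plain Python (toy probabilities and labels, not real model output): only rows where the mask is True contribute to the loss. A negative log-likelihood is used here purely as an illustration of the masking.

```python
import math

def masked_nll(probs, labels, mask):
    """Average negative log-likelihood over masked rows only, mirroring
    loss_function(f[train_mask], y[train_mask]): rows where mask is
    False are excluded from the loss entirely."""
    selected = [(p, y) for p, y, m in zip(probs, labels, mask) if m]
    return sum(-math.log(p[y]) for p, y in selected) / len(selected)

probs = [[0.7, 0.1, 0.1, 0.1],      # training node, label 0
         [0.25, 0.25, 0.25, 0.25],  # non-training node: ignored
         [0.1, 0.6, 0.2, 0.1]]      # training node, label 1
labels = [0, 3, 1]
mask = [True, False, True]
print(round(masked_nll(probs, labels, mask), 4))
```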
3. Training
Training keeps the test accuracy from the epoch with the best validation accuracy:
def train():
    model = HAN(-1, 64, num_classes).to(device)
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
    loss_function = torch.nn.CrossEntropyLoss().to(device)
    min_epochs = 5
    best_val_acc = 0
    final_best_acc = 0
    model.train()
    for epoch in range(100):
        f = model(graph)
        loss = loss_function(f[train_mask], y[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # validation
        val_acc, val_loss = test(model, val_mask)
        test_acc, test_loss = test(model, test_mask)
        # test() switches the model to eval mode, so switch back
        model.train()
        if epoch + 1 > min_epochs and val_acc > best_val_acc:
            best_val_acc = val_acc
            final_best_acc = test_acc
        print('Epoch {:3d} train_loss {:.5f} val_acc {:.3f} test_acc {:.3f}'
              .format(epoch, loss.item(), val_acc, test_acc))
    return final_best_acc
4. Testing
def test(model, mask):
    model.eval()
    with torch.no_grad():
        out = model(graph)
        loss_function = torch.nn.CrossEntropyLoss().to(device)
        loss = loss_function(out[mask], y[mask])
        _, pred = out.max(dim=1)
        correct = int(pred[mask].eq(y[mask]).sum().item())
        acc = correct / int(mask.sum())  # use the given mask, not test_mask
    return acc, loss.item()
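The accuracy computation at the end of test() can be sketched in plain Python (toy scores and labels, no torch required):

```python
def masked_accuracy(pred_scores, labels, mask):
    """Mirror pred = out.max(dim=1) followed by pred[mask].eq(y[mask]):
    take the argmax class per row, then count matches among masked rows."""
    preds = [row.index(max(row)) for row in pred_scores]
    selected = [(p, y) for p, y, m in zip(preds, labels, mask) if m]
    correct = sum(1 for p, y in selected if p == y)
    return correct / len(selected)

scores = [[0.7, 0.1, 0.1, 0.1],
          [0.1, 0.2, 0.6, 0.1],
          [0.3, 0.4, 0.2, 0.1],
          [0.25, 0.25, 0.25, 0.25]]
labels = [0, 2, 0, 3]
mask = [True, True, True, False]  # evaluate on the first three nodes only
print(masked_accuracy(scores, labels, mask))  # -> 0.6666666666666666
```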
Experimental Results
On the DBLP dataset, after 100 epochs of training, the classification accuracy is 78.54%:
HAN Accuracy: 0.7853853239177156