Node Classification with R-GCN in PyG
2022-07-30 04:06:00 【Cyril_KI】
Preface
For the principles behind R-GCN, see: ESWC 2018 | R-GCN: Modeling Relational Data with Graph Convolutional Networks.
Data Processing
Load the data:
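import os.path as osp
import torch
from torch_geometric.datasets import DBLP
# The dataset is downloaded to ../data/DBLP relative to this script on first use.
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'DBLP')
dataset = DBLP(path)
graph = dataset[0]
print(graph)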
The output is as follows:
HeteroData(
author={
x=[4057, 334],
y=[4057],
train_mask=[4057],
val_mask=[4057],
test_mask=[4057]
},
paper={
x=[14328, 4231] },
term={
x=[7723, 50] },
conference={
num_nodes=20 },
(author, to, paper)={
edge_index=[2, 19645] },
(paper, to, author)={
edge_index=[2, 19645] },
(paper, to, term)={
edge_index=[2, 85810] },
(paper, to, conference)={
edge_index=[2, 14328] },
(term, to, paper)={
edge_index=[2, 85810] },
(conference, to, paper)={
edge_index=[2, 14328] }
)
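As shown above, the DBLP dataset contains four node types: author, paper, term, and conference. In total there are 14,328 papers, 4,057 authors, 20 conferences, and 7,723 terms. Authors are divided into four research areas: database, data mining, machine learning, and information retrieval.
Task: classify the author nodes into 4 classes.
Because the conference nodes have no features, features must be assigned to them in advance: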
graph['conference'].x = torch.randn((graph['conference'].num_nodes, 50))
The features of all conference nodes are randomly initialized.
Next, collect some values that are used later:
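# Assumption: `device` is used later in the post but never defined, so pick the GPU if available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
graph = graph.to(device)  # keep features, labels and masks on the same device as the model
num_classes = torch.max(graph['author'].y).item() + 1
train_mask, val_mask, test_mask = graph['author'].train_mask, graph['author'].val_mask, graph['author'].test_mask
y = graph['author'].y
node_types, edge_types = graph.metadata()
num_nodes = graph['author'].x.shape[0]
num_relations = len(edge_types)
init_sizes = [graph[x].x.shape[1] for x in node_types]
homogeneous_graph = graph.to_homogeneous()
in_feats, hidden_feats = 128, 64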
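As a quick sanity check, the values derived above can be reproduced from the printed HeteroData summary alone. The sketch below hard-codes the node and edge types from that output in plain Python; the `feature_dims` dict is a hypothetical stand-in for `graph[x].x.shape[1]`, not a PyG object. It also illustrates why the conference features have to be assigned first: without them there is no feature dimension to read.

```python
# Node types and feature dimensions copied from the printed HeteroData summary.
# `feature_dims` is an illustrative stand-in for graph[node_type].x.shape[1].
node_types = ['author', 'paper', 'term', 'conference']
feature_dims = {'author': 334, 'paper': 4231, 'term': 50}  # conference has no x yet

edge_types = [
    ('author', 'to', 'paper'),
    ('paper', 'to', 'author'),
    ('paper', 'to', 'term'),
    ('paper', 'to', 'conference'),
    ('term', 'to', 'paper'),
    ('conference', 'to', 'paper'),
]

# Before the random initialization, conference has no feature dimension to read.
missing = [t for t in node_types if t not in feature_dims]

# After graph['conference'].x = torch.randn((20, 50)) every type has features.
feature_dims['conference'] = 50

num_relations = len(edge_types)
init_sizes = [feature_dims[t] for t in node_types]

print(missing)        # ['conference']
print(num_relations)  # 6
print(init_sizes)     # [334, 4231, 50, 50]
```

These values match the `num_relations=6` and the four `Linear` input sizes that appear when the model is printed.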
Model Construction
First, import the layer:
from torch_geometric.nn import RGCNConv
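Model parameters:
- in_channels: input channels, i.e. the number of features per node in node classification. In this post it is 128, the dimension to which all node features are projected.
- out_channels: output channels. For node classification, the out_channels of the last RGCNConv layer equals the number of node classes.
- num_relations: the number of relations (edge types).
- num_bases: if basis-decomposition regularization is used, the number of bases.
- num_blocks: if block-diagonal-decomposition regularization is used, the number of blocks.
- aggr: the aggregation scheme; the default is mean.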
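To make the effect of num_bases concrete, the sketch below counts the relation-specific weights of a single layer with and without basis decomposition. It assumes the usual R-GCN formulation, in which each relation weight is a linear combination of num_bases shared basis matrices, and it ignores the self-loop (root) weight and the bias. `relation_weight_params` is a hypothetical helper, not part of PyG.

```python
def relation_weight_params(in_channels, out_channels, num_relations, num_bases=None):
    """Count relation-specific weights of one R-GCN layer (root weight and bias excluded)."""
    if num_bases is None:
        # One full weight matrix per relation.
        return num_relations * in_channels * out_channels
    # num_bases shared basis matrices plus one coefficient per (relation, basis) pair.
    return num_bases * in_channels * out_channels + num_relations * num_bases


full = relation_weight_params(128, 64, num_relations=6)
basis = relation_weight_params(128, 64, num_relations=6, num_bases=30)
print(full)   # 49152
print(basis)  # 245940
```

With only 6 relations, 30 bases give more parameters than one matrix per relation. The decomposition reduces the parameter count roughly when num_bases is smaller than num_relations.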
The model is then built as follows:
import copy
import torch
import torch.nn.functional as F
from torch import nn

class RGCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(RGCN, self).__init__()
        self.conv1 = RGCNConv(in_channels, hidden_channels,
                              num_relations=num_relations, num_bases=30)
        self.conv2 = RGCNConv(hidden_channels, out_channels,
                              num_relations=num_relations, num_bases=30)
        # One linear layer per node type, projecting its raw features to in_channels.
        self.lins = torch.nn.ModuleList()
        for i in range(len(node_types)):
            lin = nn.Linear(init_sizes[i], in_channels)
            self.lins.append(lin)

    def trans_dimensions(self, g):
        # Work on a copy so the original graph keeps its raw features.
        data = copy.deepcopy(g)
        for node_type, lin in zip(node_types, self.lins):
            data[node_type].x = lin(data[node_type].x)
        return data

    def forward(self, data):
        data = self.trans_dimensions(data)
        homogeneous_data = data.to_homogeneous()
        edge_index, edge_type = homogeneous_data.edge_index, homogeneous_data.edge_type
        x = self.conv1(homogeneous_data.x, edge_index, edge_type)
        x = self.conv2(x, edge_index, edge_type)
        # Keep only the author rows (author is the first node type).
        x = x[:num_nodes]
        # Note: nn.CrossEntropyLoss applies log-softmax internally,
        # so it is normally fed raw scores rather than softmax outputs.
        x = F.softmax(x, dim=1)
        return x
Print the model:
model = RGCN(in_feats, hidden_feats, num_classes).to(device)
print(model)
RGCN(
(conv1): RGCNConv(128, 64, num_relations=6)
(conv2): RGCNConv(64, 4, num_relations=6)
(lins): ModuleList(
(0): Linear(in_features=334, out_features=128, bias=True)
(1): Linear(in_features=4231, out_features=128, bias=True)
(2): Linear(in_features=50, out_features=128, bias=True)
(3): Linear(in_features=50, out_features=128, bias=True)
)
)
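1. Forward Propagation
According to the official documentation, RGCNConv takes three inputs: the node features x, the edge indices edge_index, and the edge types edge_type.
Printing the DBLP dataset after the conference features have been initialized gives: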
HeteroData(
author={
x=[4057, 334],
y=[4057],
train_mask=[4057],
val_mask=[4057],
test_mask=[4057]
},
paper={
x=[14328, 4231] },
term={
x=[7723, 50] },
conference={
num_nodes=20,
x=[20, 50]
},
(author, to, paper)={
edge_index=[2, 19645] },
(paper, to, author)={
edge_index=[2, 19645] },
(paper, to, term)={
edge_index=[2, 85810] },
(paper, to, conference)={
edge_index=[2, 14328] },
(term, to, paper)={
edge_index=[2, 85810] },
(conference, to, paper)={
edge_index=[2, 14328] }
)
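The DBLP HeteroData object does not provide these three values directly, so we first convert it to a homogeneous graph:
homogeneous_graph = graph.to_homogeneous()
print(homogeneous_graph)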
Data(node_type=[26128], edge_index=[2, 239566], edge_type=[239566])
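After the conversion we have edge_index and edge_type, but no feature matrix x covering all nodes. This is because, when a heterogeneous graph is converted to a homogeneous one, node features are merged only if every node type has the same feature dimension. We therefore first need to project the features of all node types to a common dimension (128 in this example):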
def trans_dimensions(self, g):
    data = copy.deepcopy(g)
    for node_type, lin in zip(node_types, self.lins):
        data[node_type].x = lin(data[node_type].x)
    return data
After this step every node type in data has 128-dimensional features, and the graph can be converted to a homogeneous one:
data = self.trans_dimensions(data)
homogeneous_data = data.to_homogeneous()
Data(node_type=[26128], x=[26128, 128], edge_index=[2, 239566], edge_type=[239566])
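The same idea can be reproduced with plain PyTorch tensors. The sketch below uses small made-up node counts (the `raw` dict and its sizes are illustrative, not the DBLP data) to show that features with different widths cannot be stacked directly, while per-type linear projections to a shared width can.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy features: 3 node types with different feature widths (illustrative sizes).
raw = {
    'author': torch.randn(5, 334),
    'paper': torch.randn(7, 4231),
    'term': torch.randn(4, 50),
}

# Concatenating along the node dimension fails when the widths differ.
try:
    torch.cat(list(raw.values()), dim=0)
    stacked_directly = True
except RuntimeError:
    stacked_directly = False

# One linear projection per node type, all mapping to the same width.
shared_dim = 128
lins = nn.ModuleDict({name: nn.Linear(x.shape[1], shared_dim) for name, x in raw.items()})
projected = [lins[name](x) for name, x in raw.items()]

# Now all node types can be stacked into a single feature matrix.
x_all = torch.cat(projected, dim=0)
print(stacked_directly)  # False
print(x_all.shape)       # torch.Size([16, 128])
```

Because the projections are `nn.Linear` modules, they are trained jointly with the rest of the network rather than fixed in advance.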
Now homogeneous_data can be fed into RGCNConv:
x = self.conv1(homogeneous_data.x, edge_index, edge_type)
x = self.conv2(x, edge_index, edge_type)
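The output x contains the representations of all nodes. We only need the first 4057 rows, which correspond to the author nodes (author is the first node type, so its nodes come first):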
x = x[:num_nodes]
x = F.softmax(x, dim=1)
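Slicing the first 4057 rows relies on the author nodes being placed first in the homogeneous graph. The sketch below assumes that node types are stacked in the order returned by graph.metadata() (author, paper, term, conference) and computes the row range of each type from the node counts printed earlier. `node_ranges` is a hypothetical helper.

```python
from itertools import accumulate

def node_ranges(num_nodes_per_type):
    """Return {node_type: (start, end)} row ranges, assuming types are stacked in dict order."""
    names = list(num_nodes_per_type)
    ends = list(accumulate(num_nodes_per_type[name] for name in names))
    starts = [0] + ends[:-1]
    return {name: (start, end) for name, start, end in zip(names, starts, ends)}


counts = {'author': 4057, 'paper': 14328, 'term': 7723, 'conference': 20}
ranges = node_ranges(counts)
print(ranges['author'])      # (0, 4057)
print(ranges['conference'])  # (26108, 26128)
```

If author were not the first node type, x[:num_nodes] would select the wrong rows. In that case the node_type vector of the homogeneous graph could be used to build a mask instead.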
2. Backpropagation
During training, we first compute the output with a forward pass:
f = model(graph)
Here f holds the 4 class probabilities of each author node. In actual training only the loss on the training set is needed, so the loss is computed as:
loss = loss_function(f[train_mask], y[train_mask])
The gradients are then computed and the parameters are updated by backpropagation.
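One detail is worth checking here: torch.nn.CrossEntropyLoss applies log-softmax to its input, while the forward pass above already returns softmax probabilities. The pure-Python sketch below (the helper names are illustrative) reproduces the cross-entropy computation for a single node to show the consequence. When the input is restricted to probabilities in [0, 1], the loss for 4 classes cannot fall below about 0.744, even for a perfect prediction.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(scores, target):
    """Cross-entropy that applies softmax to `scores` first, as CrossEntropyLoss does."""
    return -math.log(softmax(scores)[target])

# A very confident, correct prediction expressed as raw scores (logits).
logits = [20.0, 0.0, 0.0, 0.0]
loss_from_logits = cross_entropy(logits, target=0)

# The same prediction after an extra softmax: values are squeezed into [0, 1].
probs = softmax(logits)
loss_from_probs = cross_entropy(probs, target=0)

print(round(loss_from_logits, 4))  # 0.0
print(round(loss_from_probs, 4))   # 0.7437
```

The model can still be trained this way, as the reported accuracy shows. Returning the raw conv2 output and applying softmax only when probabilities are needed is the more common pattern.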
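3. Training
The training function tracks the epoch with the best validation accuracy and returns the test accuracy at that epoch: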
def train():
    model = RGCN(in_feats, hidden_feats, num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01,
                                 weight_decay=1e-4)
    loss_function = torch.nn.CrossEntropyLoss().to(device)
    min_epochs = 5
    best_val_acc = 0
    final_best_acc = 0
    for epoch in range(100):
        # test() switches the model to eval mode, so re-enable train mode every epoch.
        model.train()
        f = model(graph)
        loss = loss_function(f[train_mask], y[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # validation
        val_acc, val_loss = test(model, val_mask)
        test_acc, test_loss = test(model, test_mask)
        if epoch + 1 > min_epochs and val_acc > best_val_acc:
            best_val_acc = val_acc
            final_best_acc = test_acc
        print('Epoch{:3d} train_loss {:.5f} val_acc {:.3f} test_acc {:.3f}'.
              format(epoch, loss.item(), val_acc, test_acc))
    return final_best_acc
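The model-selection rule inside the loop can be isolated from the training code. The sketch below applies the same rule to a made-up history of (val_acc, test_acc) pairs. The numbers and the `select_test_acc` helper are purely illustrative.

```python
def select_test_acc(history, min_epochs=5):
    """Return the test accuracy at the epoch with the best validation accuracy.

    history: list of (val_acc, test_acc) tuples, one per epoch.
    The first `min_epochs` epochs are skipped as warm-up.
    """
    best_val_acc = 0.0
    final_best_acc = 0.0
    for epoch, (val_acc, test_acc) in enumerate(history):
        if epoch + 1 > min_epochs and val_acc > best_val_acc:
            best_val_acc = val_acc
            final_best_acc = test_acc
    return final_best_acc


# Made-up accuracies: epoch 0 has the highest validation score but lies inside the warm-up.
history = [(0.99, 0.50), (0.60, 0.58), (0.70, 0.69), (0.75, 0.74),
           (0.80, 0.79), (0.85, 0.84), (0.90, 0.95), (0.88, 0.97)]
print(select_test_acc(history))  # 0.95
```

The last epoch has a higher test accuracy (0.97), but it is not selected because its validation accuracy is lower. The test split is only reported, never used to choose the epoch.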
4. Testing
@torch.no_grad()
def test(model, mask):
    model.eval()
    out = model(graph)
    loss_function = torch.nn.CrossEntropyLoss().to(device)
    loss = loss_function(out[mask], y[mask])
    _, pred = out.max(dim=1)
    correct = int(pred[mask].eq(y[mask]).sum().item())
    # Normalize by the size of the evaluated split, not always by the test split.
    acc = correct / int(mask.sum())
    return acc, loss.item()
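For reference, a mask-based accuracy can be sketched in plain Python. It assumes that the accuracy of a split should be normalized by the number of nodes in that split, which matters whenever the validation and test splits differ in size. `masked_accuracy` is a hypothetical helper.

```python
def masked_accuracy(pred, target, mask):
    """Fraction of correct predictions among the nodes selected by `mask`."""
    selected = [(p, t) for p, t, m in zip(pred, target, mask) if m]
    if not selected:
        raise ValueError("mask selects no nodes")
    correct = sum(1 for p, t in selected if p == t)
    return correct / len(selected)


pred   = [0, 1, 2, 3, 0, 1]
target = [0, 1, 2, 0, 0, 2]
val_mask  = [True, True, False, False, False, False]
test_mask = [False, False, True, True, True, True]

print(masked_accuracy(pred, target, val_mask))   # 1.0
print(masked_accuracy(pred, target, test_mask))  # 0.5
```

Dividing the two correct validation predictions by the test-split size (4) would report 0.5 instead of 1.0.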
Experimental Results
Using the DBLP dataset and training for 100 epochs, the classification accuracy is 93.77%:
RGCN Accuracy: 0.9376727049431992
Full Code
Code: GNNs-for-Node-Classification. Original work takes effort, so please give the repository a follow and a star when you download it. Thanks!