Building R-GCN with PyG for Node Classification
2022-07-30 04:10:00 【Cyril_KI】
Preface
For the principles behind R-GCN, see: ESWC 2018 | R-GCN: Relational data modeling based on graph convolutional networks.
Data Processing
Import the data:
import os.path as osp
import torch
from torch_geometric.datasets import DBLP

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'DBLP')
dataset = DBLP(path)
graph = dataset[0]
print(graph)
The output is as follows:
HeteroData(
author={
x=[4057, 334],
y=[4057],
train_mask=[4057],
val_mask=[4057],
test_mask=[4057]
},
paper={
x=[14328, 4231] },
term={
x=[7723, 50] },
conference={
num_nodes=20 },
(author, to, paper)={
edge_index=[2, 19645] },
(paper, to, author)={
edge_index=[2, 19645] },
(paper, to, term)={
edge_index=[2, 85810] },
(paper, to, conference)={
edge_index=[2, 14328] },
(term, to, paper)={
edge_index=[2, 85810] },
(conference, to, paper)={
edge_index=[2, 14328] }
)
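As can be seen, the DBLP dataset contains four node types: author, paper, term, and conference. It contains 14,328 papers, 4,057 authors, 20 conferences, and 7,723 terms. The authors are divided into four research areas: database, data mining, machine learning, and information retrieval.
Task: classify the author nodes into 4 classes.
Since conference nodes have no features, we need to assign features to them in advance: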
graph['conference'].x = torch.randn((graph['conference'].num_nodes, 50))
The features of all conference nodes are therefore randomly initialized (50-dimensional, drawn from a standard normal distribution).
Collect some useful quantities:
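device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
graph = graph.to(device)  # keep the data on the same device as the model
num_classes = torch.max(graph['author'].y).item() + 1  # 4 classes
train_mask, val_mask, test_mask = graph['author'].train_mask, graph['author'].val_mask, graph['author'].test_mask
y = graph['author'].y
node_types, edge_types = graph.metadata()
num_nodes = graph['author'].x.shape[0]  # number of author nodes (4057)
num_relations = len(edge_types)  # 6 edge types
init_sizes = [graph[x].x.shape[1] for x in node_types]  # raw feature size per node type
homogeneous_graph = graph.to_homogeneous()
in_feats, hidden_feats = 128, 64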
Model Construction
First, import the package:
from torch_geometric.nn import RGCNConv
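Model parameters:
- in_channels: the size of each input node feature. If no input features are given, it should be the number of nodes in the graph.
- out_channels: the output channels. For node classification, the output channels of the last RGCNConv layer equal the number of classes.
- num_relations: the number of relation (edge) types.
- num_bases: if set, the layer uses basis-decomposition regularization, and num_bases is the number of bases to use.
- num_blocks: if set, the layer uses block-diagonal-decomposition regularization, and num_blocks is the number of blocks to use.
- aggr: the aggregation scheme; the default is "mean".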
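With num_bases=B, PyG's RGCNConv keeps B shared basis matrices plus an R×B coefficient matrix and forms each relation weight as a linear combination W_r = Σ_b a_rb V_b. The plain-PyTorch sketch below (the helper names are hypothetical, not PyG API) rebuilds the relation weights and compares parameter counts for the sizes used in this post (128 → 64, 6 relations, 30 bases); it counts only the relation weights, not the root weight or bias.

```python
import torch


def relation_weights_from_bases(bases: torch.Tensor, comp: torch.Tensor) -> torch.Tensor:
    """Combine B basis matrices [B, in, out] with coefficients [R, B] into R matrices [R, in, out]."""
    return torch.einsum('rb,bio->rio', comp, bases)


def relation_param_counts(in_ch: int, out_ch: int, num_relations: int, num_bases: int):
    """Return (params with one full matrix per relation, params with basis decomposition)."""
    full = num_relations * in_ch * out_ch
    basis = num_bases * in_ch * out_ch + num_relations * num_bases  # shared bases + coefficients
    return full, basis


bases = torch.randn(30, 128, 64)  # V_b
comp = torch.randn(6, 30)         # a_rb
W = relation_weights_from_bases(bases, comp)
print(W.shape)                                # torch.Size([6, 128, 64])
print(relation_param_counts(128, 64, 6, 30))  # (49152, 245940)
```

With num_bases=30 larger than num_relations=6, the basis decomposition actually uses more relation parameters than one full matrix per relation; it only saves parameters when num_bases is clearly smaller than num_relations.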
The model is then built as follows:
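import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import RGCNConv


class RGCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Two R-GCN layers with basis decomposition (30 bases);
        # num_relations, node_types and init_sizes are the globals defined above
        self.conv1 = RGCNConv(in_channels, hidden_channels,
                              num_relations=num_relations, num_bases=30)
        self.conv2 = RGCNConv(hidden_channels, out_channels,
                              num_relations=num_relations, num_bases=30)
        # One linear projection per node type, mapping raw features to in_channels
        self.lins = torch.nn.ModuleList()
        for i in range(len(node_types)):
            lin = nn.Linear(init_sizes[i], in_channels)
            self.lins.append(lin)

    def trans_dimensions(self, g):
        # Work on a copy so the original graph keeps its raw features
        data = copy.deepcopy(g)
        for node_type, lin in zip(node_types, self.lins):
            data[node_type].x = lin(data[node_type].x)
        return data

    def forward(self, data):
        data = self.trans_dimensions(data)
        # All node types now share one feature size, so x survives the conversion
        homogeneous_data = data.to_homogeneous()
        edge_index, edge_type = homogeneous_data.edge_index, homogeneous_data.edge_type
        x = self.conv1(homogeneous_data.x, edge_index, edge_type)
        x = self.conv2(x, edge_index, edge_type)
        # Node types are stacked in order and author comes first,
        # so the first num_nodes rows are the author nodes
        x = x[:num_nodes]
        # Class probabilities; CrossEntropyLoss applies log_softmax itself,
        # so returning raw logits would usually be preferable
        x = F.softmax(x, dim=1)
        return x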
Print the model:
model = RGCN(in_feats, hidden_feats, num_classes).to(device)
print(model)
RGCN(
(conv1): RGCNConv(128, 64, num_relations=6)
(conv2): RGCNConv(64, 4, num_relations=6)
(lins): ModuleList(
(0): Linear(in_features=334, out_features=128, bias=True)
(1): Linear(in_features=4231, out_features=128, bias=True)
(2): Linear(in_features=50, out_features=128, bias=True)
(3): Linear(in_features=50, out_features=128, bias=True)
)
)
1. Forward Propagation
Check the input and output requirements of RGCNConv in the official documentation:
As shown there, RGCNConv takes as input the node features x, the edge index edge_index, and the edge types edge_type.
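PyG documents RGCNConv as x'_i = Θ_root·x_i + Σ_r Σ_{j∈N_r(i)} (1/|N_r(i)|)·Θ_r·x_j. The plain-PyTorch sketch below (the helper rgcn_layer is hypothetical, uses full per-relation weights without bases, and is meant for intuition rather than as PyG's implementation) shows how x, edge_index and edge_type enter that computation on a 3-node toy graph.

```python
import torch


def rgcn_layer(x, edge_index, edge_type, weight, root, bias):
    """One R-GCN layer with per-relation mean aggregation.

    x: [N, in]; edge_index: [2, E] with row 0 = source j, row 1 = target i;
    edge_type: [E] relation ids in [0, R); weight: [R, in, out]; root: [in, out]; bias: [out].
    """
    num_nodes, out_ch = x.size(0), weight.size(2)
    out = x @ root + bias  # self-connection term
    src, dst = edge_index
    for r in range(weight.size(0)):
        mask = edge_type == r
        s, d = src[mask], dst[mask]
        msg = x[s] @ weight[r]  # transform neighbours with the relation-specific weight
        agg = torch.zeros(num_nodes, out_ch, dtype=x.dtype).index_add_(0, d, msg)
        deg = torch.zeros(num_nodes, dtype=x.dtype).index_add_(0, d, torch.ones_like(d, dtype=x.dtype))
        # Mean over N_r(i); nodes without neighbours of relation r receive 0
        out = out + agg / deg.clamp(min=1).unsqueeze(-1)
    return out


x = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
edge_index = torch.tensor([[0, 1, 0], [2, 2, 1]])   # edges 0->2, 1->2, 0->1
edge_type = torch.tensor([0, 0, 1])
weight = torch.tensor([[[1.0], [2.0]], [[3.0], [4.0]]])  # R=2, in=2, out=1
root, bias = torch.zeros(2, 1), torch.zeros(1)
print(rgcn_layer(x, edge_index, edge_type, weight, root, bias))
# expected values: [[0.0], [3.0], [1.5]]
```

Node 2 averages its two relation-0 neighbours (1 and 2 → 1.5), node 1 receives one relation-1 message (3), and node 0 has no incoming edges.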
HeteroData(
author={
x=[4057, 334],
y=[4057],
train_mask=[4057],
val_mask=[4057],
test_mask=[4057]
},
paper={
x=[14328, 4231] },
term={
x=[7723, 50] },
conference={
num_nodes=20,
x=[20, 50]
},
(author, to, paper)={
edge_index=[2, 19645] },
(paper, to, author)={
edge_index=[2, 19645] },
(paper, to, term)={
edge_index=[2, 85810] },
(paper, to, conference)={
edge_index=[2, 14328] },
(term, to, paper)={
edge_index=[2, 85810] },
(conference, to, paper)={
edge_index=[2, 14328] }
)
As can be seen, the DBLP HeteroData object does not directly provide the three inputs required above. Therefore, we first need to convert it into a homogeneous graph:
homogeneous_graph = graph.to_homogeneous()
Data(node_type=[26128], edge_index=[2, 239566], edge_type=[239566])
After the conversion we have edge_index and edge_type, but no node features x. This is because, when a heterogeneous graph is converted into a homogeneous one, node features can only be merged if all node types have the same feature dimension. Therefore, we first need to project the features of every node type to a common dimension (128 here, as an example):
def trans_dimensions(self, g):
    # Work on a copy so the original graph keeps its raw features
    data = copy.deepcopy(g)
    for node_type, lin in zip(node_types, self.lins):
        data[node_type].x = lin(data[node_type].x)
    return data
After the transformation, every node type in data has 128-dimensional features, so we can now convert it into a homogeneous graph:
data = self.trans_dimensions(data)
homogeneous_data = data.to_homogeneous()
Data(node_type=[26128], x=[26128, 128], edge_index=[2, 239566], edge_type=[239566])
Now we can feed homogeneous_data into RGCNConv:
x = self.conv1(homogeneous_data.x, edge_index, edge_type)
x = self.conv2(x, edge_index, edge_type)
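The output x contains the representations of all nodes. We only need the first 4057 rows, i.e., the author nodes, because node types are concatenated in order and author comes first: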
x = x[:num_nodes]
x = F.softmax(x, dim=1)
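Slicing with x[:num_nodes] works because author is the first node type. A slightly more order-independent alternative is to select rows by the node_type vector that to_homogeneous() produces. The sketch below assumes that homogeneous_data.node_type is available inside forward and that the author type id equals node_types.index('author'); select_node_type is a hypothetical helper shown on toy tensors.

```python
import torch


def select_node_type(x: torch.Tensor, node_type: torch.Tensor, type_id: int) -> torch.Tensor:
    """Return the rows of x that belong to node type `type_id`."""
    return x[node_type == type_id]


# Toy homogeneous output: 2 nodes of type 0 followed by 3 nodes of type 1
x = torch.arange(10.0).view(5, 2)
node_type = torch.tensor([0, 0, 1, 1, 1])
print(select_node_type(x, node_type, 0))  # same rows as x[:2]
```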
2. Backpropagation
During training, we first compute the output with a forward pass:
f = model(graph)
f contains the 4 class probabilities of each author node. During training we only need the loss on the training set, so the loss is computed as:
loss = loss_function(f[train_mask], y[train_mask])
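Note that torch.nn.CrossEntropyLoss expects unnormalized logits and applies log_softmax internally, so feeding it the softmax output of the model applies softmax twice. Predictions are unaffected because softmax preserves the argmax, but the loss values are squashed into a narrow range around log(4). The small check below uses random toy tensors (6 nodes, 4 classes), not DBLP data.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 4)  # 6 toy "author" nodes, 4 classes
y = torch.tensor([0, 3, 1, 2, 0, 1])
train_mask = torch.tensor([True, True, False, True, False, True])

ce = torch.nn.CrossEntropyLoss()
# Intended usage: raw logits of the training nodes
loss_logits = ce(logits[train_mask], y[train_mask])
# Same value computed by hand: log_softmax followed by negative log-likelihood
manual = F.nll_loss(F.log_softmax(logits[train_mask], dim=1), y[train_mask])
# Softmax applied before CrossEntropyLoss, i.e. softmax twice
loss_double = ce(F.softmax(logits, dim=1)[train_mask], y[train_mask])
print(loss_logits.item(), manual.item(), loss_double.item())
```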
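Then compute the gradients and update the parameters.
3. Training
The training function tracks the epoch with the best validation accuracy (after a warm-up of min_epochs) and returns the test accuracy of that epoch: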
def train():
    model = RGCN(in_feats, hidden_feats, num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01,
                                 weight_decay=1e-4)
    loss_function = torch.nn.CrossEntropyLoss().to(device)
    min_epochs = 5
    best_val_acc = 0
    final_best_acc = 0
    for epoch in range(100):
        # test() switches the model to eval mode, so re-enable training mode each epoch
        model.train()
        f = model(graph)
        loss = loss_function(f[train_mask], y[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # validation
        val_acc, val_loss = test(model, val_mask)
        test_acc, test_loss = test(model, test_mask)
        # after the warm-up, keep the test accuracy of the best-validation epoch
        if epoch + 1 > min_epochs and val_acc > best_val_acc:
            best_val_acc = val_acc
            final_best_acc = test_acc
        print('Epoch{:3d} train_loss {:.5f} val_acc {:.3f} test_acc {:.3f}'.
              format(epoch, loss.item(), val_acc, test_acc))
    return final_best_acc
4. Testing
@torch.no_grad()
def test(model, mask):
    model.eval()
    out = model(graph)
    loss_function = torch.nn.CrossEntropyLoss().to(device)
    loss = loss_function(out[mask], y[mask])
    _, pred = out.max(dim=1)
    correct = int(pred[mask].eq(y[mask]).sum().item())
    # divide by the size of the mask being evaluated (validation or test)
    acc = correct / int(mask.sum())
    return acc, loss.item()
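Accuracy on a split is the number of correct predictions among the masked nodes divided by the size of that same mask. A self-contained sketch on toy tensors (masked_accuracy is a hypothetical helper):

```python
import torch


def masked_accuracy(logits: torch.Tensor, y: torch.Tensor, mask: torch.Tensor) -> float:
    """Fraction of correctly classified nodes among those selected by the boolean mask."""
    pred = logits.argmax(dim=1)
    correct = (pred[mask] == y[mask]).sum().item()
    return correct / int(mask.sum())


logits = torch.tensor([[2.0, 0.1], [0.3, 1.0], [1.5, 0.2], [0.1, 0.9]])
y = torch.tensor([0, 1, 1, 1])
val_mask = torch.tensor([True, True, True, False])
print(masked_accuracy(logits, y, val_mask))  # 2 of 3 masked nodes correct
```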
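Experimental Results
On the DBLP network, after 100 training epochs, the test accuracy (taken at the epoch with the best validation accuracy) is 93.77%: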
RGCN Accuracy: 0.9376727049431992
Complete Code
Code repository: GNNs-for-Node-Classification. Original work takes effort, so if you download it, please give it a follow and a star. Thanks!