当前位置:网站首页>The first simple case of GNN: Cora classification
The first simple case of GNN: Cora classification
2022-07-06 11:59:00 【Want to be a kite】
GNN–Cora classification
Cora is a classic dataset for GNNs. It contains 2708 papers divided into seven categories: 1) Case-Based, 2) Genetic Algorithms, 3) Neural Networks, 4) Probabilistic Methods, 5) Reinforcement Learning, 6) Rule Learning, 7) Theory. Each paper is treated as a node, and each node has 1433 features.
import os
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.datasets import Planetoid
import torch_geometric.nn as pyg_nn
#load Cora dataset
def get_data(root_dir=r'D:\Python\python_dataset\GNN_Dataset\Cora', data_name='Cora'):
    """Build a Planetoid dataset, print it, and return it.

    Args:
        root_dir: Directory handed to ``Planetoid`` as ``root``.
        data_name: Dataset name handed to ``Planetoid`` as ``name``.

    Returns:
        The ``Planetoid`` dataset object that was constructed.
    """
    # The default path is written as a raw string: sequences such as '\P'
    # and '\G' are not valid escapes in a normal string literal and trigger
    # a warning on recent Python versions. The runtime value is unchanged.
    Cora_dataset = Planetoid(root=root_dir, name=data_name)
    print(Cora_dataset)
    return Cora_dataset
# Load the dataset, then report class count, node-feature count and
# edge-feature count on one line, followed by the underlying data object.
Cora_dataset = get_data()
print(*(
    getattr(Cora_dataset, attr_name)
    for attr_name in ('num_classes', 'num_node_features', 'num_edge_features')
))
print(Cora_dataset.data)
Cora()
7 1433 0
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
The code below gives simple implementations of GCN, GAT, SGConv, ChebConv, and SAGEConv.
import os
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.datasets import Planetoid
import torch_geometric.nn as pyg_nn
#load Cora dataset
def get_data(root_dir=r'D:\Python\python_dataset\GNN_Dataset\Cora', data_name='Cora'):
    """Build a Planetoid dataset, print it, and return it.

    Args:
        root_dir: Directory handed to ``Planetoid`` as ``root``.
        data_name: Dataset name handed to ``Planetoid`` as ``name``.

    Returns:
        The ``Planetoid`` dataset object that was constructed.
    """
    # The default path is written as a raw string: sequences such as '\P'
    # and '\G' are not valid escapes in a normal string literal and trigger
    # a warning on recent Python versions. The runtime value is unchanged.
    Cora_dataset = Planetoid(root=root_dir, name=data_name)
    print(Cora_dataset)
    return Cora_dataset
#create the Graph cnn model
""" 2-GATConv """
# class GATConv(nn.Module):
# def __init__(self,in_c,hid_c,out_c):
# super(GATConv,self).__init__()
# self.GATConv1 = pyg_nn.GATConv(in_channels=in_c,out_channels=hid_c)
# self.GATConv2 = pyg_nn.GATConv(in_channels=hid_c, out_channels=out_c)
#
# def forward(self,data):
# x = data.x
# edge_index = data.edge_index
# hid = self.GATConv1(x=x,edge_index=edge_index)
# hid = F.relu(hid)
#
# out = self.GATConv2(hid,edge_index=edge_index)
# out = F.log_softmax(out,dim=1)
#
# return out
""" 2-SAGE 0.788 """
# class SAGEConv(nn.Module):
# def __init__(self,in_c,hid_c,out_c):
# super(SAGEConv,self).__init__()
# self.SAGEConv1 = pyg_nn.SAGEConv(in_channels=in_c,out_channels=hid_c)
# self.SAGEConv2 = pyg_nn.SAGEConv(in_channels=hid_c, out_channels=out_c)
#
# def forward(self,data):
# x = data.x
# edge_index = data.edge_index
# hid = self.SAGEConv1(x=x,edge_index=edge_index)
# hid = F.relu(hid)
#
# out = self.SAGEConv2(hid,edge_index=edge_index)
# out = F.log_softmax(out,dim=1)
#
# return out
""" 2-SGConv 0.79 """
class SGConv(nn.Module):
    """Two-layer node classifier built from ``pyg_nn.SGConv`` layers.

    Args:
        in_c: Number of input features per node.
        hid_c: Number of channels produced by the first layer.
        out_c: Number of channels produced by the second layer, i.e. the
            number of classes to score.
    """

    def __init__(self, in_c, hid_c, out_c):
        super(SGConv, self).__init__()
        self.SGConv1 = pyg_nn.SGConv(in_channels=in_c, out_channels=hid_c)
        # The output layer must project to out_c. Using hid_c here left
        # out_c unused and made the model emit hid_c scores per node, so
        # the softmax below was taken over the wrong number of classes.
        self.SGConv2 = pyg_nn.SGConv(in_channels=hid_c, out_channels=out_c)

    def forward(self, data):
        """Return per-node log-probabilities for the graph in ``data``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            The second layer's output after ``log_softmax`` over dim 1.
        """
        x = data.x
        edge_index = data.edge_index

        hid = self.SGConv1(x=x, edge_index=edge_index)
        hid = F.relu(hid)

        out = self.SGConv2(hid, edge_index=edge_index)
        # Log-probabilities, to pair with F.nll_loss in the training loop.
        out = F.log_softmax(out, dim=1)
        return out
""" 2-ChebConv """
# class ChebConv(nn.Module):
# def __init__(self,in_c,hid_c,out_c):
# super(ChebConv,self).__init__()
#
# self.ChebConv1 = pyg_nn.ChebConv(in_channels=in_c,out_channels=hid_c,K=1)
# self.ChebConv2 = pyg_nn.ChebConv(in_channels=hid_c,out_channels=out_c,K=1)
#
# def forward(self,data):
# x = data.x
# edge_index = data.edge_index
# hid = self.ChebConv1(x=x,edge_index=edge_index)
# hid = F.relu(hid)
#
# out = self.ChebConv2(hid,edge_index=edge_index)
# out = F.log_softmax(out,dim=1)
#
# return out
""" 2-GCN """
# class GraphCNN(nn.Module):
# def __init__(self, in_c,hid_c,out_c):
# super(GraphCNN,self).__init__()
#
# self.conv1 = pyg_nn.GCNConv(in_channels=in_c,out_channels=hid_c)
# self.conv2 = pyg_nn.GCNConv(in_channels=hid_c,out_channels=out_c)
#
# def forward(self,data):
# #data.x,data.edge_index
# x = data.x # [N,C]
# edge_index = data.edge_index # [2,E]
# hid = self.conv1(x=x,edge_index=edge_index) #[N,D]
# hid = F.relu(hid)
#
# out = self.conv2(hid,edge_index=edge_index) # [N,out_c]
#
# out = F.log_softmax(out,dim=1)
#
# return out
def main():
    """Train the SGConv model on Cora and print the test accuracy.

    Runs 500 full-batch Adam steps on the nodes selected by ``train_mask``,
    printing the loss each epoch, then evaluates on ``test_mask``.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    Cora_dataset = get_data()

    # Alternative models (enable the matching class definition first):
    # my_net = GATConv(in_c=Cora_dataset.num_node_features, hid_c=100, out_c=Cora_dataset.num_classes)
    # my_net = SAGEConv(in_c=Cora_dataset.num_node_features, hid_c=40, out_c=Cora_dataset.num_classes)
    my_net = SGConv(in_c=Cora_dataset.num_node_features, hid_c=100, out_c=Cora_dataset.num_classes)
    # my_net = ChebConv(in_c=Cora_dataset.num_node_features, hid_c=20, out_c=Cora_dataset.num_classes)
    # my_net = GraphCNN(in_c=Cora_dataset.num_node_features, hid_c=12, out_c=Cora_dataset.num_classes)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    my_net = my_net.to(device)
    data = Cora_dataset[0].to(device)

    optimizer = torch.optim.Adam(my_net.parameters(), lr=1e-3)

    # model train
    my_net.train()
    for epoch in range(500):
        optimizer.zero_grad()
        output = my_net(data)
        # Only the training nodes contribute to the loss.
        loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        print("Epoch", epoch + 1, "Loss", loss.item())

    # model test
    my_net.eval()
    # No gradients are needed for evaluation, so skip building the
    # autograd graph for this forward pass.
    with torch.no_grad():
        prediction = my_net(data).argmax(dim=1)
    target = data.y
    test_correct = prediction[data.test_mask].eq(target[data.test_mask]).sum().item()
    test_number = data.test_mask.sum().item()
    print("Accuracy of Test Sample:", test_correct / test_number)


if __name__ == '__main__':
    main()
Cora()
Epoch 1 Loss 4.600048542022705
Epoch 2 Loss 4.569146156311035
Epoch 3 Loss 4.535804271697998
Epoch 4 Loss 4.498434543609619
Epoch 5 Loss 4.456351280212402
Epoch 6 Loss 4.409425258636475
Epoch 7 Loss 4.357522964477539
Epoch 8 Loss 4.3007612228393555
Epoch 9 Loss 4.2392096519470215
Epoch 10 Loss 4.172731876373291
Epoch 11 Loss 4.101400375366211
Epoch 12 Loss 4.025243282318115
...............
Epoch 494 Loss 0.004426263272762299
Epoch 495 Loss 0.004407935775816441
Epoch 496 Loss 0.004389731213450432
Epoch 497 Loss 0.004371633753180504
Epoch 498 Loss 0.004353662021458149
Epoch 499 Loss 0.0043357922695577145
Epoch 500 Loss 0.004318032879382372
Accuracy of Test Sample: 0.794
边栏推荐
- Basic use of pytest
- B tree and b+ tree of MySQL index implementation
- Mysql的索引实现之B树和B+树
- arduino UNO R3的寄存器写法(1)-----引脚电平状态变化
- R & D thinking 01 ----- classic of embedded intelligent product development process
- 【presto】presto 参数配置优化
- Some concepts often asked in database interview
- 分布式节点免密登录
- 电商数据分析--用户行为分析
- 機器學習--線性回歸(sklearn)
猜你喜欢
4. Install and deploy spark (spark on Yan mode)
MongoDB
数据分析之缺失值填充(重点讲解多重插值法Miceforest)
ToggleButton实现一个开关灯的效果
Comparison of solutions of Qualcomm & MTK & Kirin mobile platform USB3.0
Detailed explanation of Union [C language]
MongoDB
Redis interview questions
Mysql database interview questions
机器学习--线性回归(sklearn)
随机推荐
TypeScript
Apprentissage automatique - - régression linéaire (sklearn)
【yarn】Yarn container 日志清理
SQL time injection
Basic use of pytest
Machine learning -- linear regression (sklearn)
[CDH] modify the default port 7180 of cloudera manager in cdh/cdp environment
【CDH】CDH5.16 配置 yarn 任务集中分配设置不生效问题
数据库面试常问的一些概念
open-mmlab labelImg mmdetection
机器学习--决策树(sklearn)
[Flink] cdh/cdp Flink on Yan log configuration
I2C bus timing explanation
Characteristics, task status and startup of UCOS III
OPPO VOOC快充电路和协议
C language callback function [C language]
C语言回调函数【C语言】
OSPF message details - LSA overview
JS object and event learning notes
互联网协议详解