当前位置:网站首页>GNN的第一个简单案例:Cora分类
GNN的第一个简单案例:Cora分类
2022-07-06 09:16:00 【想成为风筝】
GNN–Cora分类
Cora数据集是GNN中一个经典的数据集,将2708篇论文分为七类:1)基于案例、2)遗传算法、3)神经网络、4)概率方法、5)强化学习、6)规则学习、7)理论。每一篇论文看作是一个节点,每个节点有1433个特征。
import os
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.datasets import Planetoid
import torch_geometric.nn as pyg_nn
#load Cora dataset
def get_data(root_dir=r'D:\Python\python_dataset\GNN_Dataset\Cora', data_name='Cora'):
    """Load a Planetoid dataset (Cora by default), print it and return it.

    Args:
        root_dir: Directory handed to ``Planetoid`` as its ``root`` argument.
        data_name: Dataset name handed to ``Planetoid`` as its ``name`` argument.

    Returns:
        The ``Planetoid`` dataset object that was created.
    """
    # The default path is a raw string: written as a normal literal, the
    # backslashes before P, p, G and C are invalid escape sequences and
    # trigger a SyntaxWarning on recent Python versions. The path value
    # itself is unchanged.
    Cora_dataset = Planetoid(root=root_dir, name=data_name)
    print(Cora_dataset)
    return Cora_dataset
# Load the dataset once, then report its class / node-feature / edge-feature
# counts on a single line (space-separated, as print does by default).
Cora_dataset = get_data()
dataset_summary = (
    Cora_dataset.num_classes,
    Cora_dataset.num_node_features,
    Cora_dataset.num_edge_features,
)
print(*dataset_summary)
# Print the graph object stored on the dataset.
print(Cora_dataset.data)
Cora()
7 1433 0
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
代码中给出GCN、GAT、SGConv、ChebConv、SAGEConv的简单实现
import os
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.datasets import Planetoid
import torch_geometric.nn as pyg_nn
#load Cora dataset
def get_data(root_dir=r'D:\Python\python_dataset\GNN_Dataset\Cora', data_name='Cora'):
    """Load a Planetoid dataset (Cora by default), print it and return it.

    Args:
        root_dir: Directory handed to ``Planetoid`` as its ``root`` argument.
        data_name: Dataset name handed to ``Planetoid`` as its ``name`` argument.

    Returns:
        The ``Planetoid`` dataset object that was created.
    """
    # The default path is a raw string: written as a normal literal, the
    # backslashes before P, p, G and C are invalid escape sequences and
    # trigger a SyntaxWarning on recent Python versions. The path value
    # itself is unchanged.
    Cora_dataset = Planetoid(root=root_dir, name=data_name)
    print(Cora_dataset)
    return Cora_dataset
#create the Graph cnn model
""" 2-GATConv """
# class GATConv(nn.Module):
# def __init__(self,in_c,hid_c,out_c):
# super(GATConv,self).__init__()
# self.GATConv1 = pyg_nn.GATConv(in_channels=in_c,out_channels=hid_c)
# self.GATConv2 = pyg_nn.GATConv(in_channels=hid_c, out_channels=out_c)
#
# def forward(self,data):
# x = data.x
# edge_index = data.edge_index
# hid = self.GATConv1(x=x,edge_index=edge_index)
# hid = F.relu(hid)
#
# out = self.GATConv2(hid,edge_index=edge_index)
# out = F.log_softmax(out,dim=1)
#
# return out
""" 2-SAGE 0.788 """
# class SAGEConv(nn.Module):
# def __init__(self,in_c,hid_c,out_c):
# super(SAGEConv,self).__init__()
# self.SAGEConv1 = pyg_nn.SAGEConv(in_channels=in_c,out_channels=hid_c)
# self.SAGEConv2 = pyg_nn.SAGEConv(in_channels=hid_c, out_channels=out_c)
#
# def forward(self,data):
# x = data.x
# edge_index = data.edge_index
# hid = self.SAGEConv1(x=x,edge_index=edge_index)
# hid = F.relu(hid)
#
# out = self.SAGEConv2(hid,edge_index=edge_index)
# out = F.log_softmax(out,dim=1)
#
# return out
""" 2-SGConv 0.79 """
class SGConv(nn.Module):
    """Two-layer node classifier built from ``pyg_nn.SGConv`` layers.

    Args:
        in_c: Number of input features per node.
        hid_c: Number of channels produced by the first layer.
        out_c: Number of channels produced by the second layer, i.e. the
            number of classes scored for every node.
    """

    def __init__(self, in_c, hid_c, out_c):
        super(SGConv, self).__init__()
        self.SGConv1 = pyg_nn.SGConv(in_channels=in_c, out_channels=hid_c)
        # The output layer has to project onto out_c. It previously used
        # hid_c, which silently ignored out_c and made the network emit
        # hid_c scores per node instead of one score per class.
        self.SGConv2 = pyg_nn.SGConv(in_channels=hid_c, out_channels=out_c)

    def forward(self, data):
        """Return log-softmax scores (over dim 1) for the nodes in ``data``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``
                (graph connectivity), both forwarded to the SGConv layers.

        Returns:
            The second layer's output after ``log_softmax`` along dim 1,
            suitable for ``F.nll_loss``.
        """
        x = data.x
        edge_index = data.edge_index

        hid = self.SGConv1(x=x, edge_index=edge_index)
        hid = F.relu(hid)

        out = self.SGConv2(hid, edge_index=edge_index)
        out = F.log_softmax(out, dim=1)
        return out
""" 2-ChebConv """
# class ChebConv(nn.Module):
# def __init__(self,in_c,hid_c,out_c):
# super(ChebConv,self).__init__()
#
# self.ChebConv1 = pyg_nn.ChebConv(in_channels=in_c,out_channels=hid_c,K=1)
# self.ChebConv2 = pyg_nn.ChebConv(in_channels=hid_c,out_channels=out_c,K=1)
#
# def forward(self,data):
# x = data.x
# edge_index = data.edge_index
# hid = self.ChebConv1(x=x,edge_index=edge_index)
# hid = F.relu(hid)
#
# out = self.ChebConv2(hid,edge_index=edge_index)
# out = F.log_softmax(out,dim=1)
#
# return out
""" 2-GCN """
# class GraphCNN(nn.Module):
# def __init__(self, in_c,hid_c,out_c):
# super(GraphCNN,self).__init__()
#
# self.conv1 = pyg_nn.GCNConv(in_channels=in_c,out_channels=hid_c)
# self.conv2 = pyg_nn.GCNConv(in_channels=hid_c,out_channels=out_c)
#
# def forward(self,data):
# #data.x,data.edge_index
# x = data.x # [N,C]
# edge_index = data.edge_index # [2,E]
# hid = self.conv1(x=x,edge_index=edge_index) #[N,D]
# hid = F.relu(hid)
#
# out = self.conv2(hid,edge_index=edge_index) # [N,out_c]
#
# out = F.log_softmax(out,dim=1)
#
# return out
def main(epochs=500, lr=1e-3):
    """Train the SGConv classifier on Cora and print its test accuracy.

    Args:
        epochs: Number of full-graph optimisation steps (default 500, the
            value that used to be hard-coded).
        lr: Learning rate given to the Adam optimiser (default 1e-3, the
            value that used to be hard-coded).
    """
    # Restrict CUDA to device 0 for this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    Cora_dataset = get_data()

    # Alternative models; each needs its (currently commented-out) class.
    #my_net = GATConv(in_c=Cora_dataset.num_node_features, hid_c=100, out_c=Cora_dataset.num_classes)
    #my_net = SAGEConv(in_c=Cora_dataset.num_node_features, hid_c=40, out_c=Cora_dataset.num_classes)
    my_net = SGConv(in_c=Cora_dataset.num_node_features, hid_c=100, out_c=Cora_dataset.num_classes)
    #my_net = ChebConv(in_c=Cora_dataset.num_node_features,hid_c=20,out_c=Cora_dataset.num_classes)
    # my_net = GraphCNN(in_c=Cora_dataset.num_node_features,hid_c=12,out_c=Cora_dataset.num_classes)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    my_net = my_net.to(device)
    data = Cora_dataset[0].to(device)
    optimizer = torch.optim.Adam(my_net.parameters(), lr=lr)

    # model train: every step uses the whole graph, but the loss is computed
    # only on the nodes selected by train_mask.
    my_net.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        output = my_net(data)
        loss = F.nll_loss(output[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        print("Epoch", epoch+1, "Loss", loss.item())

    # model test: no gradients are needed for inference, so skip building
    # the autograd graph.
    my_net.eval()
    with torch.no_grad():
        _, prediction = my_net(data).max(dim=1)
    target = data.y
    test_correct = prediction[data.test_mask].eq(target[data.test_mask]).sum().item()
    test_number = data.test_mask.sum().item()
    print("Accuracy of Test Sample:", test_correct/test_number)
# Script entry point: start training/evaluation only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    main()
Cora()
Epoch 1 Loss 4.600048542022705
Epoch 2 Loss 4.569146156311035
Epoch 3 Loss 4.535804271697998
Epoch 4 Loss 4.498434543609619
Epoch 5 Loss 4.456351280212402
Epoch 6 Loss 4.409425258636475
Epoch 7 Loss 4.357522964477539
Epoch 8 Loss 4.3007612228393555
Epoch 9 Loss 4.2392096519470215
Epoch 10 Loss 4.172731876373291
Epoch 11 Loss 4.101400375366211
Epoch 12 Loss 4.025243282318115
...............
Epoch 494 Loss 0.004426263272762299
Epoch 495 Loss 0.004407935775816441
Epoch 496 Loss 0.004389731213450432
Epoch 497 Loss 0.004371633753180504
Epoch 498 Loss 0.004353662021458149
Epoch 499 Loss 0.0043357922695577145
Epoch 500 Loss 0.004318032879382372
Accuracy of Test Sample: 0.794
边栏推荐
猜你喜欢
error C4996: ‘strcpy‘: This function or variable may be unsafe. Consider using strcpy_ s instead
【CDH】CDH5.16 配置 yarn 任务集中分配设置不生效问题
Mall project -- day09 -- order module
Vert. x: A simple TCP client and server demo
FTP file upload file implementation, regularly scan folders to upload files in the specified format to the server, C language to realize FTP file upload details and code case implementation
ToggleButton实现一个开关灯的效果
【flink】flink学习
2019 Tencent summer intern formal written examination
Mysql的索引实现之B树和B+树
sklearn之feature_extraction.text.CountVectorizer / TfidVectorizer
随机推荐
[NPUCTF2020]ReadlezPHP
Using LinkedHashMap to realize the caching of an LRU algorithm
使用LinkedHashMap实现一个LRU算法的缓存
Reading notes of difficult career creation
L2-007 family real estate (25 points)
【Flink】CDH/CDP Flink on Yarn 日志配置
Integration test practice (1) theoretical basis
C语言读取BMP文件
Pytoch Foundation
4、安装部署Spark(Spark on Yarn模式)
【CDH】CDH/CDP 环境修改 cloudera manager默认端口7180
常用正则表达式整理
double转int精度丢失问题
Reading BMP file with C language
Principle and implementation of MySQL master-slave replication
TypeScript
[CDH] cdh5.16 configuring the setting of yarn task centralized allocation does not take effect
Encodermappreduce notes
STM32型号与Contex m对应关系
Yarn installation and use