PyG: Building a GCN for Link Prediction
2022-07-25 04:35:00 [Cyril_KI]
Preface
For an introduction to link prediction and to how datasets are split for link prediction, see: Splitting the training, validation, and test sets in link prediction (with PyG's RandomLinkSplit as an example).
1. Data processing
Here we take the CiteSeer network as an example. CiteSeer is a citation network whose nodes are papers, 3327 in total, divided into six classes: Agents, AI (artificial intelligence), DB (databases), IR (information retrieval), ML (machine learning), and HCI. If one paper cites another, there is a link between them.
Load the data:
from torch_geometric.datasets import Planetoid

dataset = Planetoid('data', name='CiteSeer')
print(dataset[0])
Output:
Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])
x=[3327, 3703] means there are 3327 nodes, each with a 3703-dimensional feature vector; the features correspond to the 3703 unique words left after removing stop words and words that appear fewer than 10 times in the corpus. edge_index=[2, 9104] means there are 9104 edges in total; the two rows hold the source and target node indices of each edge. Since the graph is undirected, each edge is stored in both directions, so there are 4552 distinct undirected edges.
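These numbers can be confirmed with the standard PyG Data attributes (a quick sanity check):

data = dataset[0]
print(data.num_nodes)        # 3327
print(data.num_edges)        # 9104: each undirected edge is stored in both directions
print(data.is_undirected())  # True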
Using PyG's built-in RandomLinkSplit transform, we can easily split the dataset:
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.1, num_test=0.1, is_undirected=True,
                      add_negative_train_samples=False),
])
dataset = Planetoid('data', name='CiteSeer', transform=transform)
train_data, val_data, test_data = dataset[0]
This yields train_data, val_data, and test_data. Printing the original dataset and the three splits gives:
Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])
Data(x=[3327, 3703], edge_index=[2, 7284], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327], edge_label=[3642], edge_label_index=[2, 3642])
Data(x=[3327, 3703], edge_index=[2, 7284], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327], edge_label=[910], edge_label_index=[2, 910])
Data(x=[3327, 3703], edge_index=[2, 8194], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327], edge_label=[910], edge_label_index=[2, 910])
From top to bottom: the original dataset, the training set, the validation set, and the test set. The training set contains 3642 positive samples (edge_label=[3642]); the validation and test sets each contain 455 positive and 455 negative samples (edge_label=[910]). The edge_index counts also follow from the split: the 4552 undirected edges are divided into 3642 training, 455 validation, and 455 test edges; the training and validation sets both use the 3642 training edges for message passing (edge_index=[2, 7284], each edge stored in both directions), while the test set additionally uses the validation edges (edge_index=[2, 8194] = 2 × (3642 + 455)).
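These counts can be verified directly (a small sketch using the variables from the split above):

print(train_data.edge_label_index.size(1))  # 3642 positive training edges
print(val_data.edge_label.size(0))          # 910 = 455 positive + 455 negative
print(test_data.edge_index.size(1))         # 8194 = 2 * (3642 + 455) message-passing edges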
2. GCN for link prediction
In this experiment we use a GCN for link prediction. First, a GCN encodes the nodes of the training graph to obtain node embeddings; then the positive and negative samples of the training set (negative samples are resampled in every training epoch) are used for supervised learning. Concretely, the inner product of the two node embeddings of each sample edge is taken as its score, the loss against the edge labels is computed, and the parameters are updated by backpropagation.
2.1 Negative sampling
In each training epoch we need to sample as many negative edges from the training graph as there are positive edges. The validation and test sets were already negatively sampled when the dataset was split, so no further sampling is needed for them.
The negative sampling function:

from torch_geometric.utils import negative_sampling

def negative_sample():
    # Sample as many negative edges as there are positive edges in the training set
    neg_edge_index = negative_sampling(
        edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
        num_neg_samples=train_data.edge_label_index.size(1), method='sparse')
    # print(neg_edge_index.size(1))  # 3642 negative edges, the same number as the positive edges, resampled each epoch
    edge_label_index = torch.cat(
        [train_data.edge_label_index, neg_edge_index],
        dim=-1,
    )
    edge_label = torch.cat([
        train_data.edge_label,
        train_data.edge_label.new_zeros(neg_edge_index.size(1))
    ], dim=0)
    return edge_label, edge_label_index
The negative_sampling method samples num_neg_samples edges that do not exist in the given edge_index; num_nodes specifies the number of nodes, and method selects the sampling strategy, either 'sparse' or 'dense'. The sampled neg_edge_index is then concatenated with the original positive samples train_data.edge_label_index to form the complete sample set, and the corresponding number of zeros is appended to train_data.edge_label to label the negative samples.
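To sanity-check the sampler's output (a small sketch, assuming the train_data split from Section 1):

edge_label, edge_label_index = negative_sample()
print(edge_label_index.size(1))  # 7284 = 3642 positive + 3642 negative edges
print(int(edge_label.sum()))     # 3642, the number of positive labels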
2.2 Model construction
The GCN link prediction model is built as follows:
from torch_geometric.nn import GCNConv

class GCN_LP(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x

    def decode(self, z, edge_label_index):
        # z holds the embedding of every node
        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        # print(dst.size())  # (7284, 64)
        r = (src * dst).sum(dim=-1)
        # print(r.size())  # (7284,)
        return r

    def forward(self, x, edge_index, edge_label_index):
        z = self.encode(x, edge_index)
        return self.decode(z, edge_label_index)
The encoder consists of two GCN layers and produces the embeddings of the nodes in the training graph; the decoder computes the inner product between the two node embeddings of each sample edge. As shown above, the training set contains 3642 positive samples; negative_sample produces another 3642 negative samples, for 7284 samples in total, so the decoder returns 7284 inner products.
The loss function is BCEWithLogitsLoss. To understand BCEWithLogitsLoss, we first need to understand BCELoss. BCELoss is the binary cross-entropy loss:
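For a predicted probability $x \in (0, 1)$ and a label $y \in \{0, 1\}$, the per-sample loss takes the standard binary cross-entropy form

$\ell(x, y) = -\left[\, y \log x + (1 - y) \log(1 - x) \,\right]$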
BCEWithLogitsLoss is BCELoss with a built-in Sigmoid: the input is first passed through a sigmoid and then BCELoss is computed. Combining the two in one step is also more numerically stable than applying them separately.
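A minimal sketch illustrating the equivalence on a few hand-picked logits:

import torch

logits = torch.tensor([0.5, -1.2, 2.0])
labels = torch.tensor([1.0, 0.0, 1.0])

loss_a = torch.nn.BCEWithLogitsLoss()(logits, labels)
loss_b = torch.nn.BCELoss()(logits.sigmoid(), labels)
print(torch.allclose(loss_a, loss_b))  # True, up to floating-point precision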
The evaluation metric is AUC, computed with scikit-learn:

from sklearn.metrics import roc_auc_score

roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())
2.3 Model training and testing
The code is as follows:
def test(model, data):
    model.eval()
    with torch.no_grad():
        z = model.encode(data.x, data.edge_index)
        out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
    model.train()
    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())


def train():
    model = GCN_LP(dataset.num_features, 128, 64).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
    criterion = torch.nn.BCEWithLogitsLoss().to(device)
    min_epochs = 10
    best_val_auc = 0
    final_test_auc = 0
    model.train()
    for epoch in range(100):
        optimizer.zero_grad()
        # resample negative edges for this epoch
        edge_label, edge_label_index = negative_sample()
        out = model(train_data.x, train_data.edge_index, edge_label_index).view(-1)
        loss = criterion(out, edge_label)
        loss.backward()
        optimizer.step()
        # validation
        val_auc = test(model, val_data)
        test_auc = test(model, test_data)
        if epoch + 1 > min_epochs and val_auc > best_val_auc:
            best_val_auc = val_auc
            final_test_auc = test_auc
        print('epoch {:03d} train_loss {:.8f} val_auc {:.4f} test_auc {:.4f}'
              .format(epoch, loss.item(), val_auc, test_auc))
    return final_test_auc
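To run the training loop and produce the printout below, a minimal driver is enough (the original post omits this part, so the exact print format is an assumption):

final_test_auc = train()
print(f'final best auc: {final_test_auc}')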
The final AUC on the test set:
final best auc: 0.9076681560198044