PyG Tutorial (3): Neighbor Sampling

2022-06-21 06:43:00 Si Xi is towering

1. Why is neighbor sampling needed?

In the GNN field, large graphs are very common, but limited GPU memory means a large graph often cannot be loaded onto the GPU for training. Neighbor sampling addresses this and lets GNNs scale to large graphs. PyG provides several neighbor sampling methods; see torch_geometric.loader for details. This article uses the neighbor sampling from GraphSage as an example, which PyG implements as NeighborLoader.

NeighborSampler is also a PyG implementation of GraphSage neighbor sampling, but it has been deprecated and will be removed in a future release.

2. NeighborLoader in detail

2.1 How GraphSage neighbor sampling works

Suppose the number of sampled layers is $K$ and the number of neighbors sampled in layer $k$ is $S_k$. GraphSage neighbor sampling then proceeds as follows:

  • Step 1: Start from a mini-batch of nodes $\mathcal{B}$ whose neighbors are to be sampled;
  • Step 2: Sample the 1-hop neighbors of $\mathcal{B}$ to obtain $\mathcal{B}_1$, then sample the 1-hop neighbors of $\mathcal{B}_1$ (that is, 2-hop neighbors of the initial node set) to obtain $\mathcal{B}_2$. Repeating this $K$ times yields a subgraph associated with the initial mini-batch of nodes.
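The two steps above can be sketched in plain Python. This is only an illustrative sketch under stated assumptions, not PyG's implementation: the adjacency-list dictionary `adj`, the helper name `sample_k_hop`, and the toy graph are all made up for this example.

```python
import random


def sample_k_hop(adj, seeds, num_neighbors, rng=None):
    """Hypothetical helper: layer-wise neighbor sampling in the GraphSage style.

    adj:           dict mapping each node to a list of its neighbors
    seeds:         the initial mini-batch of nodes (the set B)
    num_neighbors: [S_1, ..., S_K], max neighbors sampled per node per layer
    Returns the list of node lists [B, B_1, ..., B_K].
    """
    if rng is None:
        rng = random.Random()
    frontier = list(dict.fromkeys(seeds))  # deduplicate, keep order
    layers = [frontier]
    for s_k in num_neighbors:
        sampled = []
        for node in frontier:
            neighbors = adj.get(node, [])
            # Sample without replacement, at most s_k neighbors per node.
            sampled.extend(rng.sample(neighbors, min(s_k, len(neighbors))))
        frontier = list(dict.fromkeys(sampled))
        layers.append(frontier)
    return layers


# Toy undirected graph (made up for illustration).
adj = {
    0: [1, 2, 3],
    1: [0, 4],
    2: [0, 4, 5],
    3: [0],
    4: [1, 2],
    5: [2],
}
layers = sample_k_hop(adj, seeds=[0], num_neighbors=[2, 2], rng=random.Random(0))
print(layers)  # [B, B_1, B_2]; the exact nodes depend on the random seed
```

The union of all returned layers gives the node set of the sampled subgraph. A node may reappear in a later layer (for example, the seed is a neighbor of its own neighbors), which is why the final node set is usually deduplicated.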

The left panel of the figure shows a 2-layer neighbor sampling example given by GraphSage, in which the number of neighbors sampled per layer, $S_k$, is the same for every layer (3 in the figure).

[Figure: graphsage]

2.2 API overview

In PyG, GraphSage neighbor sampling is implemented by torch_geometric.loader.NeighborLoader, whose initializer takes the following parameters:

def __init__(
    self,
    data: Union[Data, HeteroData],
    num_neighbors: NumNeighbors,
    input_nodes: InputNodes = None,
    replace: bool = False,
    directed: bool = True,
    transform: Callable = None,
    neighbor_sampler: Optional[NeighborSampler] = None,
    **kwargs,
)

Commonly used parameters:

  • data: the graph to sample from, either a heterogeneous graph (HeteroData) or a homogeneous graph (Data);
  • num_neighbors: the maximum number of neighbors sampled for each node in each iteration (layer), of type List[int]; for example, [2, 2] means sampling 2 layers, with at most 2 neighbors per node in each layer;
  • input_nodes: the indices of the nodes in the original graph from which sampling starts, i.e. the initial set $\mathcal{B}$ in Section 2.1, given as a torch.Tensor;
  • directed: if set to False, all edges between all sampled nodes are included;
  • **kwargs: extra arguments for torch.utils.data.DataLoader, such as batch_size and shuffle (see its API documentation for details).
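To make the effect described for directed=False more concrete, the sketch below keeps every edge whose two endpoints were both sampled. It is plain Python rather than PyG code, and it only illustrates the idea; the helper `induced_edges` and the toy edge list are hypothetical.

```python
def induced_edges(edges, sampled_nodes):
    """Hypothetical helper: keep every edge whose endpoints were both sampled."""
    keep = set(sampled_nodes)
    return [(u, v) for (u, v) in edges if u in keep and v in keep]


# Toy graph with each undirected edge stored in both directions (illustrative).
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (0, 2), (2, 0), (2, 3), (3, 2)]
result = induced_edges(edges, sampled_nodes=[0, 1, 2])
print(result)
# [(0, 1), (1, 0), (1, 2), (2, 1), (0, 2), (2, 0)]
```

Edges touching node 3 are dropped because node 3 was not sampled, while edges among the sampled nodes are kept even if they were not traversed during sampling.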

2.3 Sampling in practice

To keep the visualization readable, this section uses the KarateClub dataset provided by PyG. It describes the social relationships among the members of a karate club: the 34 nodes are the members, and an edge connects two nodes if the two members also interact outside the club. The dataset is visualized below:

[Figure: karateclub]

The source code below loads the dataset, visualizes it, and performs neighbor sampling:

import torch
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.loader import NeighborLoader


def draw(graph):
    # n_id holds the original (global) index of each node in the subgraph.
    nids = graph.n_id
    graph = to_networkx(graph)
    # Label each local node i with its global index for display.
    for i, nid in enumerate(nids):
        graph.nodes[i]['txt'] = str(nid.item())
    node_labels = nx.get_node_attributes(graph, 'txt')
    # print(node_labels)
    # {0: '14', 1: '32', 2: '33', 3: '18', 4: '30', 5: '28', 6: '20'}
    nx.draw_networkx(graph, labels=node_labels, node_color='#00BFFF')
    plt.axis("off")
    plt.show()


dataset = KarateClub()
g = dataset[0]
# print(g)
# Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])
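# Store each node's global index so it can be recovered after sampling.
g.n_id = torch.arange(g.num_nodes)

# Sample 2 layers with at most 2 neighbors per node, starting from node 14.
# input_nodes expects an integer (long) index tensor or a boolean mask.
for s in NeighborLoader(g, num_neighbors=[2, 2], input_nodes=torch.tensor([14])):
    draw(s)
    break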

In the source code above, sampling uses 2 layers with at most 2 neighbors per node in each layer, and the initial node set is {14}. The sampling result is shown below:

[Figure: graphsage_sampler]

As the figure above shows, the first iteration samples two 1-hop neighbors of node 14, namely {32, 33}; the second iteration then samples from 32 and 33 separately, obtaining {2, 8} and {18, 30}.
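As a quick sanity check on the size of the sampled subgraph, the layer sizes can be bounded from the sampling settings. This is a worked bound derived from the definitions in Section 2.1, not a figure stated by the author: since each node contributes at most $S_i$ neighbors in layer $i$, the $k$-th layer satisfies

$$
|\mathcal{B}_k| \le |\mathcal{B}| \prod_{i=1}^{k} S_i .
$$

With a single seed node ($|\mathcal{B}| = 1$) and $S_1 = S_2 = 2$, the subgraph therefore contains at most

$$
|\mathcal{B}| + |\mathcal{B}_1| + |\mathcal{B}_2| \le 1 + 1 \cdot 2 + 1 \cdot 2 \cdot 2 = 7
$$

nodes. The bound is reached only when no node is sampled twice and every node has at least 2 neighbors.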

Note that in the subgraph returned by NeighborLoader, global node indices are remapped to local indices of the subgraph. To map nodes of the sampled subgraph back to the corresponding nodes of the original graph, you can create an attribute on the original graph that records the mapping, as in the sampling code above:

g.n_id = torch.arange(g.num_nodes)

After sampling, the nodes of the subgraph then also carry the n_id attribute, so they can be mapped back to the original graph. The visualization in the example above relies on this; the corresponding mapping (local index to original index) is:

{0: '14', 1: '32', 2: '33', 3: '18', 4: '30', 5: '28', 6: '20'}
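The relabeling idea behind this mapping can be sketched in plain Python. This is a minimal sketch, not PyG's internal code: the helper `relabel` is hypothetical, the global indices are taken from the mapping above, and the edge list is illustrative only.

```python
def relabel(global_nodes, global_edges):
    """Hypothetical helper: relabel global node ids to local ids 0..n-1."""
    n_id = list(global_nodes)                      # local index -> global id
    local_of = {g: i for i, g in enumerate(n_id)}  # global id -> local index
    local_edges = [(local_of[u], local_of[v]) for u, v in global_edges]
    return n_id, local_edges


# Global ids from the mapping above; the edges are made up for illustration.
global_edges = [(32, 14), (33, 14), (18, 32), (30, 32), (28, 33), (20, 33)]
n_id, local_edges = relabel([14, 32, 33, 18, 30, 28, 20], global_edges)
print(local_edges)
# [(1, 0), (2, 0), (3, 1), (4, 1), (5, 2), (6, 2)]

# Map the local edges back to global ids through n_id.
restored = [(n_id[u], n_id[v]) for u, v in local_edges]
print(restored == global_edges)
# True
```

Storing `n_id` alongside the subgraph is what makes the round trip possible: the local indices alone carry no information about where a node came from.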

Conclusion

PyG offers more neighbor sampling implementations than the one covered here; see the official documentation of torch_geometric.loader for details.
