torch_geometric: About Mini-batches

2022-06-12 13:06:00 Dongxuan

Mini-batches

PyG batches several graphs by stacking their adjacency matrices diagonally into one sparse block-diagonal adjacency matrix (defined by edge_index), and by concatenating node features and node-level labels along the node dimension.
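Written out as matrices, this scheme looks as follows for a batch of n graphs, assuming (for illustration only) that graph i has adjacency matrix A_i, node feature matrix X_i and node-level targets Y_i:

$$
\mathbf{A} = \begin{bmatrix} \mathbf{A}_1 & & \\ & \ddots & \\ & & \mathbf{A}_n \end{bmatrix}, \qquad
\mathbf{X} = \begin{bmatrix} \mathbf{X}_1 \\ \vdots \\ \mathbf{X}_n \end{bmatrix}, \qquad
\mathbf{Y} = \begin{bmatrix} \mathbf{Y}_1 \\ \vdots \\ \mathbf{Y}_n \end{bmatrix}
$$

All off-diagonal blocks of A are zero, so no edge connects nodes of two different graphs, and message passing on the batched graph stays inside each original graph.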

As a result, the number of nodes differs from one batch to the next. This is unlike the batching we are used to (e.g. for images), where the data is cut into equally sized pieces: with this library, every batch can contain a different total number of nodes.
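To make the index shifting behind the block-diagonal matrix concrete, here is a minimal plain-Python sketch. The helper `collate_graphs` and the dict layout (`"x"` as a list of feature rows, `"edge_index"` as a pair of source/target lists) are hypothetical and only illustrate the principle; PyG's real collation works on tensors and handles many more attribute types.

```python
# Hypothetical plain-Python illustration of block-diagonal batching (not PyG's code).

def collate_graphs(graphs):
    """Merge several graphs into one large, disconnected graph."""
    xs, src, dst = [], [], []
    offset = 0  # number of nodes already placed into the batch
    for g in graphs:
        s, d = g["edge_index"]
        # Shift node indices so they point at this graph's rows in the stacked x.
        src.extend(i + offset for i in s)
        dst.extend(j + offset for j in d)
        xs.extend(g["x"])  # concatenate node features along the node dimension
        offset += len(g["x"])
    return {"x": xs, "edge_index": [src, dst]}


g1 = {"x": [[1.0], [2.0]], "edge_index": [[0, 1], [1, 0]]}         # 2 nodes
g2 = {"x": [[3.0], [4.0], [5.0]], "edge_index": [[0, 1], [1, 2]]}  # 3 nodes
out = collate_graphs([g1, g2])
print(out["edge_index"])  # [[0, 1, 2, 3], [1, 0, 3, 4]]
print(len(out["x"]))      # 5
```

The edges of the second graph are shifted by 2, the node count of the first graph, which is exactly what places its adjacency block below and to the right of the first one.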

This special way of building mini-batches is handled by a dedicated class: torch_geometric.loader.DataLoader.

from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for batch in loader:
    print(batch)
    # e.g. DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])

    print(batch.num_graphs)
    # 32
    break  # only inspect the first batch

The loader has concatenated the 32 graphs into a single batched graph with 1082 nodes in total, 21-dimensional node features and 4066 edges.

torch_geometric.data.Batch inherits from torch_geometric.data.Data and contains an additional attribute called batch: a vector that maps each node to the index of the graph it belongs to.

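$$
\mathrm{batch} = \begin{bmatrix} 0 & \cdots & 0 & 1 & \cdots & n-2 & n-1 & \cdots & n-1 \end{bmatrix}^{\top}
$$

Here n is the number of graphs in the batch; in the example above n = 32, so the entries run from 0 to 31.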
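The batch vector can be built directly from the node counts of the individual graphs, and the node counts can be recovered from it. The plain-Python sketch below shows both directions; the helper names are hypothetical. With tensors, the first helper corresponds to `torch.arange(len(sizes)).repeat_interleave(torch.tensor(sizes))`.

```python
# Hypothetical helpers illustrating the batch vector (not part of PyG).

def make_batch_vector(num_nodes_per_graph):
    """Return the graph index of every node, e.g. [2, 3] -> [0, 0, 1, 1, 1]."""
    batch = []
    for graph_id, n in enumerate(num_nodes_per_graph):
        batch.extend([graph_id] * n)  # n nodes belong to graph `graph_id`
    return batch


def count_nodes(batch, num_graphs):
    """Invert make_batch_vector: how many nodes does each graph have?"""
    counts = [0] * num_graphs
    for graph_id in batch:
        counts[graph_id] += 1
    return counts


b = make_batch_vector([2, 3, 1])
print(b)                  # [0, 0, 1, 1, 1, 2]
print(count_nodes(b, 3))  # [2, 3, 1]
```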

You can use this vector, for example, to average the node features of each graph individually (one mean value per feature dimension and graph):

from torch_scatter import scatter_mean
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for data in loader:
    print(data)
    # e.g. DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])

    print(data.num_graphs)
    # 32

    # Average node features per graph: row i of x is the mean over the nodes of graph i.
    x = scatter_mean(data.x, data.batch, dim=0)
    print(x.size())
    # torch.Size([32, 21])
    break  # only inspect the first batch
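For intuition, `scatter_mean(data.x, data.batch, dim=0)` groups the rows of `data.x` by their graph index in `data.batch` and averages each group, which yields one row per graph. Below is a plain-Python equivalent for lists; the helper `scatter_mean_rows` is hypothetical and only illustrates the semantics, it is not a replacement for the optimised library function. Within PyG, the same per-graph average is also available as `torch_geometric.nn.global_mean_pool(x, batch)`.

```python
# Hypothetical plain-Python version of a row-wise scatter mean (dim=0).

def scatter_mean_rows(x, index, dim_size):
    """Average the rows of x that share the same index value."""
    num_features = len(x[0])
    sums = [[0.0] * num_features for _ in range(dim_size)]
    counts = [0] * dim_size
    for row, i in zip(x, index):
        for f, v in enumerate(row):
            sums[i][f] += v  # accumulate feature f into group i
        counts[i] += 1
    # Groups that received no rows are left at 0.0 in this sketch.
    return [
        [s / counts[i] if counts[i] else 0.0 for s in sums[i]]
        for i in range(dim_size)
    ]


x = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]]  # 3 nodes, 2 features
batch = [0, 0, 1]                          # nodes 0 and 1 -> graph 0, node 2 -> graph 1
print(scatter_mean_rows(x, batch, 2))      # [[2.0, 3.0], [10.0, 20.0]]
```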

You can learn more about the internal batching procedure of PyG, e.g., how to modify its behaviour, in the PyG documentation. For documentation of scatter operations, we refer the interested reader to the torch-scatter documentation.

Data Transforms

Transforms are data preprocessing methods (data augmentation, data transformation). Several of them can be chained together, analogous to image processing, where you might first crop an image and then normalize it.

Transforms are a common way in torchvision to transform images and perform augmentation. PyG comes with its own transforms, which expect a Data object as input and return a new transformed Data object. Transforms can be chained together using torch_geometric.transforms.Compose and are applied before saving a processed dataset on disk (pre_transform) or before accessing a graph in a dataset (transform).
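The chaining idea behind Compose is simply "apply each transform to the output of the previous one". The sketch below mimics that in plain Python on a dict with a `"pos"` entry; `ComposeSketch`, `center` and `scale` are hypothetical stand-ins, not PyG classes.

```python
# Hypothetical mimic of transform chaining (the idea behind T.Compose).

class ComposeSketch:
    """Apply a list of transforms in order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, data):
        for t in self.transforms:
            data = t(data)  # each transform returns a new data object
        return data


def center(data):
    """Shift all positions so that their mean lies at the origin."""
    n, dim = len(data["pos"]), len(data["pos"][0])
    mean = [sum(p[d] for p in data["pos"]) / n for d in range(dim)]
    return {**data, "pos": [[p[d] - mean[d] for d in range(dim)] for p in data["pos"]]}


def scale(factor):
    """Return a transform that multiplies all positions by `factor`."""
    def _scale(data):
        return {**data, "pos": [[c * factor for c in p] for p in data["pos"]]}
    return _scale


transform = ComposeSketch([center, scale(2.0)])
out = transform({"pos": [[0.0, 0.0], [2.0, 4.0]]})
print(out["pos"])  # [[-2.0, -4.0], [2.0, 4.0]]
```

Order matters: centering first and then scaling is not the same as scaling a non-centered cloud. In PyG itself, a comparable chain could be written as `T.Compose([T.Center(), T.NormalizeScale()])`.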

Let's look at an example where we apply transforms to the ShapeNet dataset (containing 17,000 3D shape point clouds and per-point labels from 16 shape categories).

The following loads only the samples of the Airplane category:

from torch_geometric.datasets import ShapeNet

# Load only the 'Airplane' category
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'])

print(dataset[0])
# Data(pos=[2518, 3], y=[2518])

We can convert the point cloud dataset into a graph dataset by generating nearest neighbor graphs from the point clouds via transforms:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

# Build a 6-nearest-neighbour graph once, before the processed data is saved to disk
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                   pre_transform=T.KNNGraph(k=6))

print(dataset[0])
# Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])
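The edge count is consistent with a k-NN graph where every point receives exactly k edges: 2518 points × 6 = 15108 directed edges. The brute-force sketch below shows what such a graph contains, assuming no self-loops and the convention that an edge points from a neighbour j to the centre node i. The helper `knn_edge_index` is hypothetical and quadratic in the number of points; the library uses an optimised implementation.

```python
import math

# Hypothetical brute-force k-NN graph construction (illustration only).

def knn_edge_index(pos, k):
    """Edge j -> i for each of the k nearest neighbours j of every node i."""
    src, dst = [], []
    for i, p in enumerate(pos):
        # Distances from node i to every other node (self-loops excluded).
        dists = sorted((math.dist(p, q), j) for j, q in enumerate(pos) if j != i)
        for _, j in dists[:k]:
            src.append(j)  # neighbour j ...
            dst.append(i)  # ... sends a message to node i
    return [src, dst]


pos = [[0.0, 0.0], [1.0, 0.0], [3.0, 0.0], [7.0, 0.0]]
ei = knn_edge_index(pos, k=2)
print(len(ei[0]))  # 8 == 4 nodes * k
```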

Note

We use the pre_transform to convert the data before saving it to disk (leading to faster loading times). Note that the next time the dataset is initialized it will already contain graph edges, even if you do not pass any transform. If the pre_transform does not match with the one from the already processed dataset, you will be given a warning.
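The caching behaviour described in this note can be sketched in plain Python: apply the pre-transform once, save both the result and a description of the pre-transform, and on later loads reuse the saved result and warn if the description differs. Everything below (`load_or_process`, the JSON file names, the `AddOne`/`Double` transforms) is hypothetical and only illustrates the idea; it is not PyG's implementation.

```python
import json
import os
import tempfile
import warnings


def load_or_process(raw, root, pre_transform=None):
    """Hypothetical cache: apply pre_transform once, save the result, reuse it later."""
    data_path = os.path.join(root, "processed.json")
    meta_path = os.path.join(root, "pre_transform.json")
    current = repr(pre_transform)  # remember which pre_transform produced the cache
    if os.path.exists(data_path):
        with open(meta_path) as f:
            saved = json.load(f)
        if saved != current:
            warnings.warn(f"pre_transform {current} differs from cached {saved}; "
                          f"delete {root} to reprocess")
        with open(data_path) as f:
            return json.load(f)  # cached result: pre_transform is NOT re-applied
    processed = [pre_transform(d) if pre_transform else d for d in raw]
    with open(data_path, "w") as f:
        json.dump(processed, f)
    with open(meta_path, "w") as f:
        json.dump(current, f)
    return processed


class AddOne:
    def __call__(self, v):
        return v + 1

    def __repr__(self):
        return "AddOne()"


class Double:
    def __call__(self, v):
        return v * 2

    def __repr__(self):
        return "Double()"


root = tempfile.mkdtemp()
first = load_or_process([1, 2, 3], root, AddOne())  # computed and written to disk
again = load_or_process([1, 2, 3], root, AddOne())  # read back from disk
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    stale = load_or_process([1, 2, 3], root, Double())  # still the AddOne result
print(first, again, stale, len(caught))  # [2, 3, 4] [2, 3, 4] [2, 3, 4] 1
```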

In addition, we can use the transform argument to randomly augment a Data object, e.g., translating each node position by a small number:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet

dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                   pre_transform=T.KNNGraph(k=6),         # applied once, result cached on disk
                   transform=T.RandomTranslate(0.01))     # applied on every access

print(dataset[0])
# Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])
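Because transform runs each time a graph is accessed, dataset[0] returns slightly different positions on every call while the shapes stay the same. A minimal plain-Python sketch of such a random translation, assuming an independent uniform offset in [-t, t] for every coordinate of every node, could look like this (`random_translate` is a hypothetical helper, not the PyG class):

```python
import random

# Hypothetical illustration of a per-node random translation (augmentation).

def random_translate(pos, translate, rng=random):
    """Shift each coordinate of each node by an independent U(-t, t) offset."""
    return [[c + rng.uniform(-translate, translate) for c in p] for p in pos]


rng = random.Random(0)  # seeded generator for reproducibility
pos = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
jittered = random_translate(pos, 0.01, rng)
jittered_again = random_translate(pos, 0.01, rng)  # a fresh draw, like a new access
print(jittered)
```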

You can find a complete list of all implemented transforms at torch_geometric.transforms.

Copyright notice
This article was written by [Dongxuan]. Please include a link to the original article when reposting. Thank you.
https://yzsam.com/2022/163/202206121250499444.html