A PyTorch implementation of "Graph Classification Using Structural Attention" (KDD 2018).


GAM



Abstract

Graph classification is a problem with practical applications in many different domains. To solve this problem, one usually calculates certain graph statistics (i.e., graph features) that help discriminate between graphs of different classes. When calculating such features, most existing approaches process the entire graph. In a graphlet-based approach, for instance, the entire graph is processed to get the total count of different graphlets or subgraphs. In many real-world applications, however, graphs can be noisy with discriminative patterns confined to certain regions in the graph only. In this work, we study the problem of attention-based graph classification. The use of attention allows us to focus on small but informative parts of the graph, avoiding noise in the rest of the graph. We present a novel RNN model, called the Graph Attention Model (GAM), that processes only a portion of the graph by adaptively selecting a sequence of “informative” nodes. Experimental results on multiple real-world datasets show that the proposed method is competitive against various well-known methods in graph classification even though our method is limited to only a portion of the graph.

This repository provides an implementation for GAM as described in the paper:

Graph Classification using Structural Attention. John Boaz Lee, Ryan Rossi, and Xiangnan Kong. KDD, 2018. [Paper]

Requirements

The codebase is implemented in Python 3.5.2. The package versions used for development are listed below.

networkx           2.4
tqdm               4.28.1
numpy              1.15.4
pandas             0.23.4
texttable          1.5.0
argparse           1.1.0
sklearn            0.20.0
torch              1.2.0
torchvision        0.3.0

Datasets

The code reads the training graphs from an input folder where each graph is stored as a separate JSON file. The graphs used for testing are stored as JSON files in the same way. Every node id, node label and class has to be indexed from 0. Keys of dictionaries and nested dictionaries are stored as strings in order to make JSON serialization possible.
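Because JSON object keys are always strings, a reader of this format has to turn node ids and labels back into integers. The sketch below shows one way to load a folder of graph files and check that node ids run from 0 to n-1. The helpers `decode_graph`, `is_zero_indexed` and `load_graph_folder` are hypothetical illustrations, not functions from `src/`, and the `*.json` file-name pattern is an assumption.

```python
import json
from pathlib import Path


def decode_graph(data):
    """Turn the string keys of a parsed graph dict back into ints (hypothetical helper)."""
    return {
        "target": int(data["target"]),
        "edges": [tuple(edge) for edge in data["edges"]],
        # JSON object keys are always strings, so node ids and labels come back as str.
        "labels": {int(node): int(label) for node, label in data["labels"].items()},
        "inverse_labels": {
            int(label): [int(node) for node in nodes]
            for label, nodes in data["inverse_labels"].items()
        },
    }


def is_zero_indexed(graph):
    """True if the node ids form the contiguous range 0..n-1."""
    nodes = {node for edge in graph["edges"] for node in edge} | set(graph["labels"])
    return nodes == set(range(len(nodes)))


def load_graph_folder(folder):
    """Load and decode every *.json file in `folder`, sorted by file name (assumed layout)."""
    graphs = []
    for path in sorted(Path(folder).glob("*.json")):
        with open(str(path)) as handle:
            graph = decode_graph(json.load(handle))
        if not is_zero_indexed(graph):
            raise ValueError("{}: node ids must be indexed from 0".format(path))
        graphs.append(graph)
    return graphs
```

Node labels are not checked per graph, because a single graph (like the example that follows) may use only some of the dataset-wide label ids.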

For example, these JSON files have the following key-value structure:

{"target": 1,
 "edges": [[0, 1], [0, 4], [1, 3], [1, 4], [2, 3], [2, 4], [3, 4]],
 "labels": {"0": 2, "1": 3, "2": 2, "3": 3, "4": 4},
 "inverse_labels": {"2": [0, 2], "3": [1, 3], "4": [4]}}

The **target key** has an integer value, which is the ID of the target class (e.g. Carcinogenicity). The **edges key** has an edge list value for the graph of interest. The **labels key** has a dictionary value that maps each node to its label (e.g. node-atom pairs). The **inverse_labels key** has a dictionary value with a key for each node label; each value is the list of nodes that carry that label.
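The `inverse_labels` entry is derived entirely from `labels`, so it can be computed when a graph is written out. Below is a minimal sketch, assuming the graph is given as an integer edge list plus a node-to-label dict; `graph_to_json` is a hypothetical helper, not part of this repository.

```python
import json
from collections import defaultdict


def graph_to_json(edges, labels, target):
    """Serialize one graph into the input layout described above (hypothetical helper).

    edges  -- iterable of (u, v) pairs with integer node ids starting at 0
    labels -- dict mapping node id -> integer node label
    target -- integer class id
    """
    # Group nodes by label to build the inverse mapping.
    inverse = defaultdict(list)
    for node in sorted(labels):
        inverse[labels[node]].append(node)
    data = {
        "target": target,
        "edges": [[u, v] for u, v in edges],
        # json.dumps would stringify int keys anyway; doing it explicitly keeps
        # the in-memory dict identical to what ends up on disk.
        "labels": {str(node): label for node, label in sorted(labels.items())},
        "inverse_labels": {str(label): nodes for label, nodes in sorted(inverse.items())},
    }
    return json.dumps(data)
```

Calling it with the edge list, labels and target of the example above reproduces that JSON object; the returned string can then be written to a file in the training or testing folder.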

Options

Training a GAM model is handled by the src/main.py script, which provides the following command line arguments.

Input and output options

  --train-graph-folder   STR    Training graphs folder.      Default is `input/train/`.
  --test-graph-folder    STR    Testing graphs folder.       Default is `input/test/`.
  --prediction-path      STR    Path to store labels.        Default is `output/erdos_predictions.csv`.
  --log-path             STR    Log json path.               Default is `logs/erdos_gam_logs.json`. 

Model options

  --repetitions          INT         Number of scoring runs.                  Default is 10. 
  --batch-size           INT         Number of graphs processed per batch.    Default is 32. 
  --time                 INT         Time budget.                             Default is 20. 
  --step-dimensions      INT         Neurons in step layer.                   Default is 32. 
  --combined-dimensions  INT         Neurons in shared layer.                 Default is 64. 
  --epochs               INT         Number of GAM training epochs.           Default is 10. 
  --learning-rate        FLOAT       Learning rate.                           Default is 0.001.
  --gamma                FLOAT       Discount rate.                           Default is 0.99. 
  --weight-decay         FLOAT       Weight decay.                            Default is 10^-5. 
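For readers who want to reuse these settings programmatically, the options above map naturally onto an `argparse` parser. The sketch below is a hypothetical reconstruction built only from the documented flags and defaults; it is not the actual parser in src/main.py, which may differ in help text or structure.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the documented options of src/main.py."""
    parser = argparse.ArgumentParser(description="Run GAM.")
    # Input and output options.
    parser.add_argument("--train-graph-folder", type=str, default="input/train/")
    parser.add_argument("--test-graph-folder", type=str, default="input/test/")
    parser.add_argument("--prediction-path", type=str, default="output/erdos_predictions.csv")
    parser.add_argument("--log-path", type=str, default="logs/erdos_gam_logs.json")
    # Model options.
    parser.add_argument("--repetitions", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--time", type=int, default=20)
    parser.add_argument("--step-dimensions", type=int, default=32)
    parser.add_argument("--combined-dimensions", type=int, default=64)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--learning-rate", type=float, default=0.001)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--weight-decay", type=float, default=1e-5)  # 10^-5
    return parser
```

`argparse` turns dashes into underscores, so `--batch-size 512` is read back as `args.batch_size`.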

Examples

The following commands train a neural network, make predictions, create logs, and write the predictions and logs to disk.

Training a GAM model on the default dataset. Saving predictions and logs at default paths.

python src/main.py

Training a GAM model for 100 epochs with a batch size of 512.

python src/main.py --epochs 100 --batch-size 512

Setting a high time budget for the agent.

python src/main.py --time 128
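The time budget is presumably the number of steps the agent takes on each graph, and `--gamma` is the discount rate. As a generic illustration (standard reinforcement-learning bookkeeping, assumed here rather than taken from src/main.py), this is how a discount rate turns the per-step rewards of one episode into discounted returns. A longer budget with a reward only at the last step means the earliest steps receive a smaller share, gamma^(T-1) of it for the first step. `discounted_returns` is a hypothetical helper.

```python
def discounted_returns(rewards, gamma=0.99):
    """Return G_t = sum_k gamma**k * r_{t+k} for every step t of one episode."""
    returns = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each return reuses the one after it: G_t = r_t + gamma * G_{t+1}.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns
```

For example, `discounted_returns([0.0, 0.0, 1.0], gamma=0.5)` gives `[0.25, 0.5, 1.0]`.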

Training a model with a custom learning rate and number of epochs.

python src/main.py --learning-rate 0.001 --epochs 200

License


Owner
Benedek Rozemberczki
Machine Learning Engineer at AstraZeneca | PhD from The University of Edinburgh.