Code for KDD'20 "Generative Pre-Training of Graph Neural Networks"

Overview

GPT-GNN: Generative Pre-Training of Graph Neural Networks



GPT-GNN is a pre-training framework to initialize GNNs by generative pre-training. It can be applied to large-scale and heterogensous graphs.

You can see our KDD 2020 paper Generative Pre-Training of Graph Neural Networks for more details.

Overview

The key package is GPT_GNN, which contains the the high-level GPT-GNN pretraining framework, base GNN models, and base graph structure and data loader.

To illustrate how to apply the GPT_GNN framework for arbitrary graphs, we provide examples of pre-training on both hetergeneous (OAG) and homogeneous graphs (reddit). Both of them are of large-scale.

Within each example_* package, there is a pretrain_*.py file for pre-training a GNN on the given graph, and also multiple finetune_*.py files for training and validating on downstream tasks.

DataSet

For Open Academic Graph (OAG), we provide a heterogeneous graph containing highly-cited CS papers (8.1G) spanning from 1900-2020. You can download the preprocessed graph via this link. We split the data by their time: Pre-training ( t < 2014 ); Training ( 2014 <= t < 2017); Validation ( t = 2017 ); Testing ( 2018 <= t ). As we use the raw-text as attribute generation task for OAG, we provide a pre-trained word2vec model via this link.

If you want to directly process from raw data, you can download via this link. After downloading it, run preprocess_OAG.py to extract features and store them in our data structure.

For Reddit, we simply download the preprocessed graph using pyG.datasets API, and then turn it into our own data structure using preprocess_reddit.py. We randomly split the data into different sets.

Setup

This implementation is based on pytorch_geometric. To run the code, you need the following dependencies:

You can simply run pip install -r requirements.txt to install all the necessary packages.

Usage

We first introduce the arguments to control hyperparameters. There are mainly three types of arguments, for pre-training; for dataset; for model and optimization.

For pre-training, we provide arguments to control different modules for attribute and edge generation tasks:

  --attr_ratio                     FLOAT   The ratio (0~1) of attribute generation loss .       Default is 0.5.
  --attr_type                      STR     type of attribute decoder ['text' or 'vec']          Default is 'vec'
  --neg_samp_num                   INT     Whether to use layer-norm on the last layer.         Default is False.
  --queue_size                     INT     Max size of adaptive embedding queue.                Default is 256.

For datasets, we provide arguments to control mini-batch sampling:

  --data_dir                       STR     The address of preprocessed graph.
  --pretrain_model_dir             STR     The address for storing the pre-trained models.
  --sample_depth                   INT     How many layers within a mini-batch subgraph         Default is 6.
  --sample_width                   INT     How many nodes to be sampled per layer per type      Default is 128.

For both pre-training and fine-tuning, we provide arguments to control model and optimizer hyperparameters. We highlight some key arguments below:

  --conv_name                      STR     Name of GNN filter (model)                           Default is hgt.
  --scheduler                      STR     Name of learning rate scheduler                      Default is cycle (for pretrain) and cosine (for fine-tuning)
  --n_hid                          INT     Number of hidden dimension                           Default is 400.
  --n_layers                       INT     Number of GNN layers                                 Default is 3.
  --prev_norm                      BOOL    Whether to use layer-norm on previous layers.        Default is False.
  --last_norm                      BOOL    Whether to use layer-norm on the last layer.         Default is False.
  --max_lr                         FLOAT   Maximum learning rate.                               Default is 1e-3 (for pretrain) and 5e-4 (for fine-tuning).  

The following commands pretrain a 3-layer HGT over OAG-CS:

python pretrain_OAG.py --attr_type text --conv_name hgt --n_layers 3 --pretrain_model_dir /datadrive/models/gta_all_cs3

The following commands use the pre-trained model as initialization and finetune on the paper-field classification task using 10% of training and validation data:

python finetune_OAG_PF.py --use_pretrain --pretrain_model_dir /datadrive/models/gta_all_cs3 --n_layer 3 --data_percentage 0.1

Pre-trained Models

  1. The 3-layer HGT model pre-trained over OAG-CS under Time-Transfer Setting via this link
  2. The 3-layer HGT model pre-trained over Reddit via this link

Citation

Please consider citing the following paper when using our code for your application.

@inproceedings{gpt_gnn,
  title={GPT-GNN: Generative Pre-Training of Graph Neural Networks},
  author={Ziniu Hu and Yuxiao Dong and Kuansan Wang and Kai-Wei Chang and Yizhou Sun},
  booktitle={Proceedings of the 26th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
  year={2020}
}

This implementation is mainly based on pyHGT API.

Owner
Ziniu Hu
CS PhD student at UCLA
Ziniu Hu
Grounding Representation Similarity with Statistical Testing

Grounding Representation Similarity with Statistical Testing This repo contains code to replicate the results in our paper, which evaluates representa

26 Dec 02, 2022
Task-based end-to-end model learning in stochastic optimization

Task-based End-to-end Model Learning in Stochastic Optimization This repository is by Priya L. Donti, Brandon Amos, and J. Zico Kolter and contains th

CMU Locus Lab 164 Dec 29, 2022
Simple reimplemetation experiments about FcaNet

FcaNet-CIFAR An implementation of the paper FcaNet: Frequency Channel Attention Networks on CIFAR10/CIFAR100 dataset. how to run Code: python Cifar.py

76 Feb 04, 2021
An Exact Solver for Semi-supervised Minimum Sum-of-Squares Clustering

PC-SOS-SDP: an Exact Solver for Semi-supervised Minimum Sum-of-Squares Clustering PC-SOS-SDP is an exact algorithm based on the branch-and-bound techn

Antonio M. Sudoso 1 Nov 13, 2022
Dynamic Neural Representational Decoders for High-Resolution Semantic Segmentation

Dynamic Neural Representational Decoders for High-Resolution Semantic Segmentation Requirements This repository needs mmsegmentation Training To train

Adelaide Intelligent Machines (AIM) Group 7 Sep 12, 2022
PhysCap: Physically Plausible Monocular 3D Motion Capture in Real Time

PhysCap: Physically Plausible Monocular 3D Motion Capture in Real Time The implementation is based on SIGGRAPH Aisa'20. Dependencies Python 3.7 Ubuntu

soratobtai 124 Dec 08, 2022
Example repository for custom C++/CUDA operators for TorchScript

Custom TorchScript Operators Example This repository contains examples for writing, compiling and using custom TorchScript operators. See here for the

106 Dec 14, 2022
A CROSS-MODAL FUSION NETWORK BASED ON SELF-ATTENTION AND RESIDUAL STRUCTURE FOR MULTIMODAL EMOTION RECOGNITION

CFN-SR A CROSS-MODAL FUSION NETWORK BASED ON SELF-ATTENTION AND RESIDUAL STRUCTURE FOR MULTIMODAL EMOTION RECOGNITION The audio-video based multimodal

skeleton 15 Sep 26, 2022
Running AlphaFold2 (from ColabFold) in Azure Machine Learning

Running AlphaFold2 (from ColabFold) in Azure Machine Learning Colby T. Ford, Ph.D. Companion repository for Medium Post: How to predict many protein s

Colby T. Ford 3 Feb 18, 2022
DeepLM: Large-scale Nonlinear Least Squares on Deep Learning Frameworks using Stochastic Domain Decomposition (CVPR 2021)

DeepLM DeepLM: Large-scale Nonlinear Least Squares on Deep Learning Frameworks using Stochastic Domain Decomposition (CVPR 2021) Run Please install th

Jingwei Huang 130 Dec 02, 2022
Unsupervised phone and word segmentation using dynamic programming on self-supervised VQ features.

Unsupervised Phone and Word Segmentation using Vector-Quantized Neural Networks Overview Unsupervised phone and word segmentation on speech data is pe

Herman Kamper 13 Dec 11, 2022
A Comprehensive Empirical Study of Vision-Language Pre-trained Model for Supervised Cross-Modal Retrieval

CLIP4CMR A Comprehensive Empirical Study of Vision-Language Pre-trained Model for Supervised Cross-Modal Retrieval The original data and pre-calculate

24 Dec 26, 2022
Paper Title: Heterogeneous Knowledge Distillation for Simultaneous Infrared-Visible Image Fusion and Super-Resolution

HKDnet Paper Title: "Heterogeneous Knowledge Distillation for Simultaneous Infrared-Visible Image Fusion and Super-Resolution" Email:

wasteland 11 Nov 12, 2022
RipsNet: a general architecture for fast and robust estimation of the persistent homology of point clouds

RipsNet: a general architecture for fast and robust estimation of the persistent homology of point clouds This repository contains the code asscoiated

Felix Hensel 14 Dec 12, 2022
Discriminative Condition-Aware PLDA

DCA-PLDA This repository implements the Discriminative Condition-Aware Backend described in the paper: L. Ferrer, M. McLaren, and N. Brümmer, "A Speak

Luciana Ferrer 31 Aug 05, 2022
PyBullet CartPole and Quadrotor environments—with CasADi symbolic a priori dynamics—for learning-based control and reinforcement learning

safe-control-gym Physics-based CartPole and Quadrotor Gym environments (using PyBullet) with symbolic a priori dynamics (using CasADi) for learning-ba

Dynamic Systems Lab 300 Dec 28, 2022
Tensorflow/Keras Plug-N-Play Deep Learning Models Compilation

DeepBay This project was created with the objective of compile Machine Learning Architectures created using Tensorflow or Keras. The architectures mus

Whitman Bohorquez 4 Sep 26, 2022
Extracting knowledge graphs from language models as a diagnostic benchmark of model performance.

Interpreting Language Models Through Knowledge Graph Extraction Idea: How do we interpret what a language model learns at various stages of training?

EPFL Machine Learning and Optimization Laboratory 9 Oct 25, 2022
Global-Local Attention for Emotion Recognition

Global-Local Attention for Emotion Recognition Requirements Python 3 Install tensorflow (or tensorflow-gpu) = 2.0.0 Install some other packages pip i

Minh Nhat Le 15 Apr 21, 2022
Official PyTorch implementation of "ArtFlow: Unbiased Image Style Transfer via Reversible Neural Flows"

ArtFlow Official PyTorch implementation of the paper: ArtFlow: Unbiased Image Style Transfer via Reversible Neural Flows Jie An*, Siyu Huang*, Yibing

123 Dec 27, 2022