Learning from graph data using Keras

Overview

Steps to run:

  • Download the Cora dataset from https://linqs.soe.ucsc.edu/data
  • Unzip the files into the folder input/cora
  • cd code
  • python eda.py
  • python word_features_only.py # baseline model, 53.28% accuracy
  • python graph_embedding.py # model_1, 73.06% accuracy
  • python graph_features_embedding.py # model_2, 76.35% accuracy

Learning from Graph data using Keras and TensorFlow

Cora Data set Citation Graph

Motivation :

Many real-world applications produce data that can be represented as a graph: citation networks, social networks (follower graphs, friend networks, …), biological networks, or telecommunications networks.
Features extracted from the graph can boost the performance of predictive models by exploiting the flow of information between nearby nodes. However, representing graph data is not straightforward, especially if we do not want to hand-craft features.
In this post we explore some ways to handle generic graphs and do node classification based on graph representations learned directly from data.

Dataset :

The Cora citation network dataset serves as the basis for the implementations and experiments throughout this post. Each node represents a scientific paper, and an edge between two nodes represents a citation relation between the two papers.
Each node is described by a set of binary features (bag of words) as well as by the set of edges that link it to other nodes.
The dataset has 2708 nodes, each classified into one of seven classes, and the network has 5429 links. Each binary feature indicates the presence of the corresponding word in the paper, for a total of 1433 sparse binary features per node. In what follows we use only 140 samples for training and the rest for validation/test.
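To make the data layout concrete, here is a minimal parsing sketch. It assumes the usual layout of the two Cora files: `cora.content` holds one tab-separated line per paper (paper id, binary word flags, class label) and `cora.cites` holds one `cited<TAB>citing` pair per line. The function names, the toy strings, and the random choice of training nodes are assumptions made for illustration; the loading code in the repository may differ.

```python
import random
from collections import defaultdict


def parse_content(lines):
    """Parse cora.content-style lines into {paper_id: (features, label)}."""
    nodes = {}
    for line in lines:
        parts = line.strip().split("\t")
        if len(parts) < 3:
            continue  # skip blank or malformed lines
        paper_id, *flags, label = parts
        nodes[paper_id] = ([int(flag) for flag in flags], label)
    return nodes


def parse_cites(lines, nodes):
    """Parse cora.cites-style lines into an undirected adjacency map."""
    adjacency = defaultdict(set)
    for line in lines:
        parts = line.split()
        if len(parts) != 2:
            continue
        cited, citing = parts
        # Keep only edges between known papers and drop self-citations.
        if cited in nodes and citing in nodes and cited != citing:
            adjacency[cited].add(citing)
            adjacency[citing].add(cited)
    return adjacency


def split_train_rest(node_ids, n_train, seed=0):
    """Pick n_train ids at random for training; the rest are held out."""
    ids = sorted(node_ids)
    random.Random(seed).shuffle(ids)
    return ids[:n_train], ids[n_train:]


# Toy data in the assumed format (3 word features instead of 1433).
content = [
    "p1\t1\t0\t1\tTheory",
    "p2\t0\t1\t0\tTheory",
    "p3\t1\t1\t0\tNeural_Networks",
]
cites = ["p1\tp2", "p2\tp3"]

nodes = parse_content(content)
adjacency = parse_cites(cites, nodes)
train_ids, rest_ids = split_train_rest(nodes, n_train=2)
```

On the real files the same functions would be called with the opened file objects and `n_train=140`.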

Problem Setting :

Problem : Assigning a class label to nodes in a graph while having few training samples.
Intuition/Hypothesis : Nodes that are close in the graph are more likely to have similar labels.
Solution : Find a way to extract features from the graph to help classify new nodes.
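One simple way to check the intuition above on a labeled graph is to measure the fraction of edges whose two endpoints share a label. The sketch below is an illustration only; `edge_homophily` is a hypothetical helper, and the toy edges and labels are made up.

```python
def edge_homophily(edges, labels):
    """Fraction of edges whose two endpoints carry the same label."""
    labeled = [(a, b) for a, b in edges if a in labels and b in labels]
    if not labeled:
        return 0.0
    same = sum(labels[a] == labels[b] for a, b in labeled)
    return same / len(labeled)


# Toy example: two of the three edges connect papers of the same class.
edges = [("p1", "p2"), ("p2", "p3"), ("p3", "p4")]
labels = {
    "p1": "Theory",
    "p2": "Theory",
    "p3": "Neural_Networks",
    "p4": "Neural_Networks",
}
homophily = edge_homophily(edges, labels)
```

A value well above what random labeling would give supports the hypothesis that graph structure carries label information.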

Proposed Approach :


Baseline Model :

Simple Baseline Model

We first experiment with the simplest model, which learns to predict node classes using only the binary features and discards all graph information.
This model is a fully connected neural network that takes the binary features as input and outputs the class probabilities for each node.
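A baseline of this kind could be sketched in Keras as follows. The hidden size, dropout rate, optimizer, and the name `build_baseline` are assumptions chosen for illustration, not the settings used to obtain the reported accuracy.

```python
from tensorflow import keras
from tensorflow.keras import layers

N_FEATURES = 1433  # binary bag-of-words features per node
N_CLASSES = 7      # paper categories in Cora


def build_baseline(hidden_units=64, dropout=0.5):
    """Fully connected classifier that ignores the graph entirely."""
    model = keras.Sequential([
        keras.Input(shape=(N_FEATURES,)),
        layers.Dense(hidden_units, activation="relu"),
        layers.Dropout(dropout),
        layers.Dense(N_CLASSES, activation="softmax"),
    ])
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",  # integer class ids as targets
        metrics=["accuracy"],
    )
    return model


model = build_baseline()
```

Training would then be a call to `model.fit` on the feature rows of the 140 training nodes, with the remaining nodes used for validation and test.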

Baseline model Accuracy : 53.28%

This is the initial accuracy that we will try to improve on by adding graph-based features.

Adding Graph features :

One way to learn graph features automatically is to embed each node into a vector by training a network on an auxiliary task: predicting the inverse of the shortest path length between two input nodes, as illustrated in the figure below :

Learning an embedding vector for each node
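The training targets for this auxiliary task can be sketched with a breadth-first search. The helper names are hypothetical, and assigning a target of 0.0 to unreachable pairs is an assumption of this sketch rather than something stated in the post.

```python
from collections import deque


def shortest_path_lengths(adjacency, source):
    """Breadth-first search: hop distance from source to every reachable node."""
    distances = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        for neighbor in adjacency.get(node, ()):
            if neighbor not in distances:
                distances[neighbor] = distances[node] + 1
                queue.append(neighbor)
    return distances


def inverse_distance_pairs(adjacency):
    """Build (node_a, node_b, target) triples with target = 1 / path length.

    Unreachable pairs get target 0.0 (an assumption of this sketch).
    """
    nodes = sorted(adjacency)
    triples = []
    for a in nodes:
        distances = shortest_path_lengths(adjacency, a)
        for b in nodes:
            if a == b:
                continue
            target = 1.0 / distances[b] if b in distances else 0.0
            triples.append((a, b, target))
    return triples


# Toy graph: a path 0 - 1 - 2 plus an isolated node 3.
adjacency = {0: [1], 1: [0, 2], 2: [1], 3: []}
triples = inverse_distance_pairs(adjacency)
```

Direct neighbors get a target of 1.0, nodes two hops apart get 0.5, and so on, so the embedding network is pushed to place nearby nodes close together. On a graph the size of Cora, sampling pairs would likely be preferable to enumerating all of them.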

The next step is to use the pre-trained node embedding as input to the classification model. We also add an additional input: the average of the binary features of each node's neighbors, where neighbors are selected by distance between the learned embedding vectors.
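The neighbor-averaged input could be computed along these lines. This sketch assumes Euclidean distance and a fixed number `k` of nearest nodes; both choices, and the function name, are assumptions, since the post does not specify them.

```python
import numpy as np


def mean_neighbor_features(embeddings, features, k=2):
    """For each node, average the features of its k nearest nodes in embedding space."""
    # Squared Euclidean distances via the Gram matrix, shape (n_nodes, n_nodes).
    squared_norms = (embeddings ** 2).sum(axis=1)
    distances = (
        squared_norms[:, None] + squared_norms[None, :] - 2.0 * embeddings @ embeddings.T
    )
    np.fill_diagonal(distances, np.inf)  # a node is not its own neighbor
    nearest = np.argsort(distances, axis=1)[:, :k]  # (n_nodes, k) indices
    return features[nearest].mean(axis=1)


# Toy example: two tight clusters in embedding space with distinct features.
embeddings = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
features = np.array([[1, 0], [1, 0], [0, 1], [0, 1]], dtype=float)
averaged = mean_neighbor_features(embeddings, features, k=1)
```

Each row of `averaged` can then be concatenated with the node's own features and its embedding before being fed to the classifier.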

The resulting classification network is described in the following figure :

Using pretrained embeddings to do node classification

Graph embedding classification model Accuracy : 73.06%

Adding the learned graph features as input to the classification model significantly improves accuracy over the baseline, from **53.28% to 73.06%** 😄 .

Improving Graph feature learning :

We can improve the previous model further by pushing the pre-training one step further: the binary features are also fed to the node embedding network, and the classifier then reuses the weights pre-trained on the binary features in addition to the node embedding vector. The resulting model relies on more useful representations of the binary features, learned from the graph structure.
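One way this weight reuse could be wired up with the Keras functional API is sketched below. The architecture (a shared `Dense` feature encoder concatenated with the node embedding, a dot product for the pair score), the layer sizes, and all names are assumptions for illustration; the actual model in the repository may be built differently.

```python
from tensorflow import keras
from tensorflow.keras import layers

N_NODES, N_FEATURES, N_CLASSES = 2708, 1433, 7
EMBEDDING_DIM = 32  # hypothetical size, not taken from the post

# Layers created once so that both models share the same weights.
node_embedding = layers.Embedding(N_NODES, EMBEDDING_DIM, name="node_embedding")
feature_encoder = layers.Dense(EMBEDDING_DIM, activation="relu", name="feature_encoder")


def encode(node_id, features):
    """Concatenate the node's embedding with its encoded binary features."""
    embedded = layers.Flatten()(node_embedding(node_id))
    return layers.Concatenate()([embedded, feature_encoder(features)])


def node_inputs(suffix):
    """One (node id, binary features) input pair."""
    node_id = keras.Input(shape=(1,), dtype="int32", name=f"id_{suffix}")
    features = keras.Input(shape=(N_FEATURES,), name=f"features_{suffix}")
    return node_id, features


# Pre-training model: predict the inverse shortest path length for a node pair.
id_a, features_a = node_inputs("a")
id_b, features_b = node_inputs("b")
similarity = layers.Dot(axes=1)([encode(id_a, features_a), encode(id_b, features_b)])
inverse_distance = layers.Dense(1, activation="sigmoid")(similarity)
pretrain_model = keras.Model([id_a, features_a, id_b, features_b], inverse_distance)
pretrain_model.compile(optimizer="adam", loss="mse")

# Classifier: reuses node_embedding and feature_encoder after pre-training.
id_c, features_c = node_inputs("c")
probabilities = layers.Dense(N_CLASSES, activation="softmax")(encode(id_c, features_c))
classifier = keras.Model([id_c, features_c], probabilities)
classifier.compile(
    optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
```

Because both models hold the same layer objects, fitting `pretrain_model` first and `classifier` second means the classifier starts from feature weights already shaped by the graph structure.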

Improved Graph embedding classification model Accuracy : 76.35%

This refinement adds about 3.3 percentage points of accuracy over the previous approach (73.06% to 76.35%).

Conclusion :

In this post we saw that we can learn useful representations from graph-structured data and then use them to improve the generalization performance of a node classification model from **53.28% to 76.35%** 😎 .

Code to reproduce the results is available here : https://github.com/CVxTz/graph_classification

Owner
Mansar Youness