SGLKT-VisDial
PyTorch implementation for the paper:
Reasoning Visual Dialog with Sparse Graph Learning and Knowledge Transfer
Gi-Cheon Kang, Junseok Park, Hwaran Lee, Byoung-Tak Zhang*, and Jin-Hwa Kim* (* corresponding authors)
In EMNLP 2021 Findings
Setup and Dependencies
This code is implemented using PyTorch v1.0+ and provides out-of-the-box support for CUDA 9+ and CuDNN 7+. Anaconda or Miniconda is the recommended way to set up this codebase:
- Install the Anaconda or Miniconda distribution based on Python 3+ from their download site.
- Clone this repository and create an environment:
git clone https://www.github.com/gicheonkang/sglkt-visdial
conda create -n sglkt python=3.6
# activate the environment and install all dependencies
conda activate sglkt
cd sglkt-visdial/
pip install -r requirements.txt
# install this codebase as a package in development version
python setup.py develop
Download Data
- We used Faster-RCNN pre-trained on Visual Genome to extract image features. Download the image features below and put each file under the $PROJECT_ROOT/data/{SPLIT_NAME}_feature directory. We need an image_id to RCNN bounding box index file ({SPLIT_NAME}_imgid2idx.pkl) because the number of bounding boxes per image is not fixed (ranging from 10 to 100).
  - train_btmup_f.hdf5: Bottom-up features of 10 to 100 proposals from images of the train split (32GB).
  - val_btmup_f.hdf5: Bottom-up features of 10 to 100 proposals from images of the validation split (0.5GB).
  - test_btmup_f.hdf5: Bottom-up features of 10 to 100 proposals from images of the test split (2GB).
- Download the pre-trained, pre-processed word vectors from here (glove840b_init_300d.npy) and keep them under the $PROJECT_ROOT/data/ directory. You can also extract the vectors manually by executing data/init_glove.py.
- Download the Visual Dialog dataset from here (visdial_1.0_train.json, visdial_1.0_val.json, visdial_1.0_test.json, and visdial_1.0_val_dense_annotations.json) and place the files under the $PROJECT_ROOT/data/ directory.
- Download the additional data for Sparse Graph Learning and Knowledge Transfer and place it under the $PROJECT_ROOT/data/ directory.
  - visdial_1.0_train_coref_structure.json: structural supervision for the train split.
  - visdial_1.0_val_coref_structure.json: structural supervision for the val split.
  - visdial_1.0_test_coref_structure.json: structural supervision for the test split.
  - visdial_1.0_train_dense_labels.json: pseudo labels for knowledge transfer.
  - visdial_1.0_word_counts_train.json: word counts for the train split.
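The need for an image_id to bounding box index can be illustrated with a small sketch. It assumes, purely for illustration, that the proposals of all images are stored back to back in one flat array and that the index maps each image_id to a (start, end) row range. The real layout of {SPLIT_NAME}_imgid2idx.pkl and of the HDF5 files may differ, and the names `build_index` and `features_for` are hypothetical.

```python
def build_index(boxes_per_image):
    """Map image_id -> (start, end) rows in a flat feature array.

    boxes_per_image: dict of image_id -> number of proposals.
    This layout is an assumption for illustration, not the repository's format.
    """
    index, start = {}, 0
    for image_id, num_boxes in boxes_per_image.items():
        index[image_id] = (start, start + num_boxes)
        start += num_boxes
    return index


def features_for(image_id, flat_features, index):
    """Return the rows (one per proposal) that belong to a single image."""
    start, end = index[image_id]
    return flat_features[start:end]


# Toy data: three images with 10, 37, and 100 proposals, 8-dim features.
boxes_per_image = {101: 10, 202: 37, 303: 100}
index = build_index(boxes_per_image)
total_boxes = sum(boxes_per_image.values())
flat_features = [[0.0] * 8 for _ in range(total_boxes)]

print(len(features_for(202, flat_features, index)))  # 37
```

Because each image contributes a different number of rows, a fixed stride cannot locate an image's features; a lookup table of this kind can.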
Training
Train the model provided in this repository as:
python train.py --gpu-ids 0 1 # provide more ids for multi-GPU execution other args...
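Before launching training, it may help to confirm that the files listed under Download Data are in place. The sketch below is not part of this repository; `missing_data_files` is a hypothetical helper, and it only checks the files that are expected directly under $PROJECT_ROOT/data/ (the image features are not checked).

```python
from pathlib import Path

# File names as listed in the Download Data section; all are expected
# directly under $PROJECT_ROOT/data/. Image features are not checked here.
EXPECTED_FILES = [
    "glove840b_init_300d.npy",
    "visdial_1.0_train.json",
    "visdial_1.0_val.json",
    "visdial_1.0_test.json",
    "visdial_1.0_val_dense_annotations.json",
    "visdial_1.0_train_coref_structure.json",
    "visdial_1.0_val_coref_structure.json",
    "visdial_1.0_test_coref_structure.json",
    "visdial_1.0_train_dense_labels.json",
    "visdial_1.0_word_counts_train.json",
]


def missing_data_files(data_dir):
    """Return the expected file names that are absent from data_dir."""
    data_dir = Path(data_dir)
    return [name for name in EXPECTED_FILES if not (data_dir / name).is_file()]


if __name__ == "__main__":
    # Assumes the script is run from $PROJECT_ROOT.
    missing = missing_data_files("data")
    if missing:
        print("Missing files under data/:")
        for name in missing:
            print("  " + name)
    else:
        print("All expected data files are present.")
```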
Saving model checkpoints
This script saves a model checkpoint at every epoch to the path specified by --save-dirpath. The default path is $PROJECT_ROOT/checkpoints.
Evaluation
Evaluation of a trained model checkpoint can be done as follows:
python evaluate.py --load-pthpath /path/to/checkpoint.pth --split val --gpu-ids 0 1
Validation scores can be computed offline. To obtain scores on the test split, you have to submit a JSON file to the EvalAI online evaluation server. You can generate the JSON file with the --save-ranks True option.
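For orientation, the sketch below shows what one entry of such a submission might look like. It assumes the format used by the Visual Dialog Challenge Starter Code (a list of objects with `image_id`, `round_id`, and `ranks`, where rank 1 is the highest-scored candidate); the file written by this repository may differ in detail. The helpers `scores_to_ranks` and `make_entry` are hypothetical names.

```python
import json


def scores_to_ranks(scores):
    """Convert candidate scores to 1-based ranks (highest score gets rank 1)."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    ranks = [0] * len(scores)
    for position, candidate in enumerate(order, start=1):
        ranks[candidate] = position
    return ranks


def make_entry(image_id, round_id, scores):
    """Build one submission entry (assumed starter-code format)."""
    return {
        "image_id": image_id,
        "round_id": round_id,
        "ranks": scores_to_ranks(scores),
    }


# Toy example with 4 candidates instead of the 100 used in VisDial v1.0.
entry = make_entry(image_id=12345, round_id=3, scores=[0.1, 0.7, 0.05, 0.15])
print(json.dumps([entry]))
# [{"image_id": 12345, "round_id": 3, "ranks": [3, 1, 4, 2]}]
```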
Pre-trained model & Results
We provide the pre-trained models for SGL+KT and SGL.
To reproduce the results reported in the paper, please run the command below.
python evaluate.py --load-pthpath SGL+KT.pth --split test --gpu-ids 0 1 --save-ranks True
Performance on v1.0 test-std (trained on v1.0 train):
Model | Overall | NDCG | MRR | R@1 | R@5 | R@10 | Mean |
---|---|---|---|---|---|---|---|
SGL+KT | 65.31 | 72.60 | 58.01 | 46.20 | 71.01 | 83.20 | 5.85 |
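As a reading aid for the table, the sketch below computes MRR, R@k, and Mean from the rank assigned to the ground-truth answer in each dialog round, using the standard definitions of these retrieval metrics. It is not the evaluation code of this repository; `retrieval_metrics` is a hypothetical helper, and it returns fractions whereas the table reports MRR and R@k as percentages. NDCG is omitted because it requires the dense relevance annotations. The Overall value in the table is consistent with the average of NDCG and MRR, (72.60 + 58.01) / 2 ≈ 65.31, which is an observation about the numbers rather than a documented definition.

```python
def retrieval_metrics(gt_ranks, ks=(1, 5, 10)):
    """Compute MRR, mean rank, and recall@k from 1-based ground-truth ranks."""
    n = len(gt_ranks)
    metrics = {
        "mrr": sum(1.0 / rank for rank in gt_ranks) / n,  # mean reciprocal rank
        "mean": sum(gt_ranks) / n,                        # mean rank (lower is better)
    }
    for k in ks:
        # Fraction of rounds where the ground truth is ranked within the top k.
        metrics["r@%d" % k] = sum(rank <= k for rank in gt_ranks) / n
    return metrics


# Toy example: ground-truth ranks for five dialog rounds.
metrics = retrieval_metrics([1, 2, 4, 10, 20])
for name, value in metrics.items():
    print(name, round(value, 4))
```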
Citation
If you use this code in your published research, please consider citing:
@article{kang2021reasoning,
title={Reasoning Visual Dialog with Sparse Graph Learning and Knowledge Transfer},
author={Kang, Gi-Cheon and Park, Junseok and Lee, Hwaran and Zhang, Byoung-Tak and Kim, Jin-Hwa},
journal={arXiv preprint arXiv:2004.06698},
year={2021}
}
License
MIT License
Acknowledgements
We use Visual Dialog Challenge Starter Code and MCAN-VQA as reference code.