
Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models

This repository contains the code for our paper VDA (published at the EMNLP 2021 main conference).

Quick Links

  • Overview
  • Train VDA
  • Evaluation
  • Citation

Overview

We propose a general framework, Virtual Data Augmentation (VDA), for robustly fine-tuning pre-trained language models on downstream tasks. VDA uses a masked language model with Gaussian noise to augment virtual examples that improve robustness, and also adopts regularized training to further guarantee semantic relevance and diversity.
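
For intuition, here is a minimal sketch of that augmentation step. This is an assumption-laden illustration, not the authors' exact implementation: the placement of the noise, the variance value, and the function name virtual_augment are ours.

import torch
from transformers import BertForMaskedLM, BertTokenizerFast

tok = BertTokenizerFast.from_pretrained("bert-base-uncased")
mlm = BertForMaskedLM.from_pretrained("bert-base-uncased")
mlm.eval()

def virtual_augment(input_ids, variance=1e-3):
    # Predict a substitution distribution over the vocabulary for each token.
    with torch.no_grad():
        logits = mlm(input_ids=input_ids).logits               # (batch, seq_len, vocab)
    # Gaussian noise on the logits diversifies the virtual neighborhood.
    probs = torch.softmax(logits + variance * torch.randn_like(logits), dim=-1)
    # Take the expected embedding of the noisy distribution as a "virtual"
    # example in embedding space, rather than sampling hard tokens.
    emb = mlm.get_input_embeddings().weight                    # (vocab, hidden)
    return probs @ emb                                         # (batch, seq_len, hidden)

ids = tok("the movie was great", return_tensors="pt").input_ids
virtual = virtual_augment(ids)  # could be fed to a classifier via inputs_embeds

The classifier is then fine-tuned on both the original and the virtual examples, with the regularized training mentioned above keeping the augmented examples semantically relevant.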

Train VDA

In the following section, we describe how to train a model with VDA using our code.

Training

Data

For the evaluation of VDA, we use six text classification datasets: Yelp, IMDB, AGNews, MR, QNLI, and MRPC. These datasets can be downloaded from Google Drive.

After downloading the two zipped files, unzip them: the data folder contains the training, validation, and test data of the six datasets, while the Robust folder contains the examples for testing robustness.

Training scripts

We release VDA with four base models. For single-sentence classification tasks, we use the text_classifier_xxx.py files, while for sentence-pair classification tasks, we use the text_pair_classifier_xxx.py files:

  • text_classifier.py and text_pair_classifier.py: BERT-base+VDA.

  • text_classifier_freelb.py and text_pair_classifier_freelb.py: FreeLB+VDA on BERT-base.

  • text_classifier_smart.py and text_pair_classifier_smart.py: SMART+VDA on BERT-base, where we only use the smoothness-inducing adversarial regularization.

  • text_classifier_smix.py and text_pair_classifier_smix.py: Smix+VDA on BERT-base, where we remove the adversarial data augmentation for a fair comparison.

We provide example scripts for both training and testing of VDA on the six datasets. In run_train.sh, we provide six examples for training on the Yelp and QNLI datasets. This script calls text_classifier_xxx.py for training (xxx refers to the base model). We explain the arguments below; an example invocation follows the list:

  • --dataset: Training file path.
  • --mlm_path: Pre-trained checkpoint to start from. For now we support BERT-style models (bert-base-uncased, bert-large-uncased, etc.).
  • --save_path: Path where the fine-tuned checkpoints are saved.
  • --max_length: Maximum sequence length. (We use 512 for Yelp/IMDB/AGNews and 256 for MR/QNLI/MRPC.)
  • --max_epoch: The maximum number of training epochs. (For most datasets and models, we use 10.)
  • --batch_size: The batch size. (We set it to the maximum that fits in GPU memory; note that too small a batch size may cause model collapse.)
  • --num_label: The number of labels. (4 for AGNews, 2 for the other datasets.)
  • --lr: Learning rate.
  • --num_warmup: The ratio of warm-up steps.
  • --variance: The variance of the Gaussian noise.
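
For example, a BERT-base+VDA run on QNLI (a sentence-pair task) could look like the following. The paths, learning rate, warm-up ratio, and variance below are illustrative placeholders; the authoritative commands are in run_train.sh.

python text_pair_classifier.py \
    --dataset data/QNLI \
    --mlm_path bert-base-uncased \
    --save_path checkpoints/qnli_vda \
    --max_length 256 \
    --max_epoch 10 \
    --batch_size 32 \
    --num_label 2 \
    --lr 2e-5 \
    --num_warmup 0.1 \
    --variance 1e-3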

For the results in the paper, we use Nvidia Tesla V100 32GB and Nvidia RTX 3090 24GB GPUs to train our models. Using different devices or different versions of CUDA and other software may lead to slightly different performance.

Evaluation

During training, our code reports the original accuracy on the test sets of the six datasets, which measures the model's classification performance. Our evaluation code for robustness is based on a modified version of BERT-Attack. It reports the Attack Accuracy, Query Number, and Perturbation Ratio metrics.

Before evaluation, please download the robustness evaluation datasets from Google Drive. Then, following commonly used settings, download and process the cosine similarity matrix following TextFooler.

Given the checkpoints of the fine-tuned models, we use the run_test.sh script to test robustness on the Yelp and QNLI datasets. It is based on the bert_robust.py file. We explain the arguments below; an example invocation follows the list:

  • --data_path: Evaluation (robustness) data file path.
  • --mlm_path: Pre-trained checkpoint to start from. For now we support BERT-style models (bert-base-uncased, bert-large-uncased, etc.).
  • --tgt_path: The fine-tuned checkpoint file.
  • --num_label: The number of labels. (4 for AGNews, 2 for the other datasets.)
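
For example (the data and checkpoint paths below are placeholders; the authoritative commands are in run_test.sh):

python bert_robust.py \
    --data_path Robust/qnli \
    --mlm_path bert-base-uncased \
    --tgt_path checkpoints/qnli_vda \
    --num_label 2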

Running the script is expected to output results such as:

original accuracy is 0.960000, attack accuracy is 0.533333, query num is 687.680556, perturb rate is 0.177204

Citation

Please cite our paper if you use VDA in your work:

@inproceedings{zhou2021vda,
  author    = {Kun Zhou and Wayne Xin Zhao and Sirui Wang and Fuzheng Zhang and Wei Wu and Ji-Rong Wen},
  title     = {Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models},
  booktitle = {{EMNLP} 2021},
  year      = {2021},
  publisher = {The Association for Computational Linguistics},
}