Official PyTorch Implementation of SSMix (Findings of ACL 2021)


SSMix: Saliency-based Span Mixup for Text Classification (Findings of ACL 2021)


Abstract

Data augmentation with mixup has been shown to be effective on various computer vision tasks. Despite its great success, applying mixup to NLP tasks has been difficult because text consists of discrete tokens of variable length. In this work, we propose SSMix, a novel mixup method that operates on the input text rather than on hidden vectors as previous approaches do. SSMix synthesizes a sentence while preserving the locality of the two original texts through span-based mixing, and it keeps more of the tokens related to the prediction by relying on saliency information. With extensive experiments, we empirically validate that our method outperforms hidden-level mixup methods on a wide range of text classification benchmarks, including textual entailment, sentiment classification, and question-type classification.
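
To make the span-based, saliency-guided mixing described above more concrete, here is a minimal pure-Python sketch. It assumes per-token saliency scores are already available and follows one reading of the method: replace the least salient span of sentence A with the equally long, most salient span of sentence B, then weight the labels by the share of tokens each sentence contributes. The helper names (`best_span`, `ssmix_tokens`, `mix_labels`) are hypothetical and are not the repository's API; the actual code in augmentation/ssmix.py works on tokenized batches and also handles paired-sentence inputs.

```python
from typing import List, Sequence, Tuple


def best_span(saliency: Sequence[float], length: int, largest: bool) -> int:
    """Start index of the contiguous window of `length` tokens whose summed
    saliency is the largest (largest=True) or the smallest (largest=False)."""
    if not 0 < length <= len(saliency):
        raise ValueError("span length must be between 1 and the sequence length")
    window = sum(saliency[:length])
    best_start, best_sum = 0, window
    for start in range(1, len(saliency) - length + 1):
        # Slide the window one token to the right.
        window += saliency[start + length - 1] - saliency[start - 1]
        better = window > best_sum if largest else window < best_sum
        if better:
            best_start, best_sum = start, window
    return best_start


def ssmix_tokens(tokens_a: Sequence[str], sal_a: Sequence[float],
                 tokens_b: Sequence[str], sal_b: Sequence[float],
                 span_len: int) -> Tuple[List[str], float]:
    """Hypothetical helper: swap A's least salient span for B's most salient
    span of the same length. Returns the mixed tokens and lam, the share of
    tokens that still come from A."""
    start_a = best_span(sal_a, span_len, largest=False)
    start_b = best_span(sal_b, span_len, largest=True)
    mixed = (list(tokens_a[:start_a])
             + list(tokens_b[start_b:start_b + span_len])
             + list(tokens_a[start_a + span_len:]))
    lam = 1.0 - span_len / len(tokens_a)
    return mixed, lam


def mix_labels(y_a: Sequence[float], y_b: Sequence[float], lam: float) -> List[float]:
    """Interpolate one-hot (or soft) labels with the token-ratio weight."""
    return [lam * a + (1.0 - lam) * b for a, b in zip(y_a, y_b)]


if __name__ == "__main__":
    a = ["the", "movie", "was", "very", "good"]
    b = ["a", "terrible", "plot", "overall"]
    mixed, lam = ssmix_tokens(a, [0.1, 0.5, 0.05, 0.02, 0.9],
                              b, [0.1, 0.8, 0.6, 0.1], span_len=2)
    print(mixed, lam)  # ['the', 'movie', 'terrible', 'plot', 'good'] 0.6
    print(mix_labels([1.0, 0.0], [0.0, 1.0], lam))
```

Taking a contiguous span (rather than scattered tokens) is what preserves local word order from both sentences; the sliding-window sum keeps the span search linear in sentence length.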

Code Structure

|__ augmentation/ --> augmentation methods, grouped by method type
    |__ __init__.py --> wrapper for all augmentation methods. Contains the metrics used for single- and paired-sentence tasks
    |__ saliency.py --> Calculates saliency as the L2 norm of the gradient obtained by backpropagation
    |__ ssmix.py --> Outputs the SSMix sentence given two input sentences and additional information, with options such as no span and no saliency
    |__ unk.py --> Outputs a sentence in which tokens are randomly replaced with unk
|__ read_data/ --> Module used for loading data
    |__ __init__.py --> wrapper function for getting the data split into train and valid sets, depending on the dataset type
    |__ dataset.py --> Class to get the NLU dataset
    |__ preprocess.py --> preprocessor that builds the input, label, and accuracy metric depending on the dataset type
|__ trainer.py --> Code that performs the actual training
|__ run_train.py --> Loads hyperparameters and initiates the training pipeline
|__ classifiation_model.py --> Adapted from huggingface modeling_bert.py. Defines BERT architectures that can handle multiple inputs for TMix
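
The saliency.py entry above computes saliency from gradients obtained by backpropagation. A minimal PyTorch sketch of that idea follows, assuming saliency is the L2 norm (over the hidden dimension) of the gradient of the classification loss with respect to each token's input embedding. The toy embedding-plus-linear classifier and the function name `token_saliency` are hypothetical stand-ins for the BERT model defined in classifiation_model.py.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def token_saliency(embedding: nn.Embedding, head: nn.Module,
                   input_ids: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Per-token saliency: L2 norm of dLoss/dEmbedding, shape (batch, seq_len)."""
    # Treat the embedding output as a leaf so autograd returns its gradient.
    embeds = embedding(input_ids).detach().requires_grad_(True)
    logits = head(embeds)                      # (batch, num_classes)
    loss = F.cross_entropy(logits, labels)
    (grad,) = torch.autograd.grad(loss, embeds)
    return grad.norm(p=2, dim=-1)              # collapse the hidden dimension


if __name__ == "__main__":
    torch.manual_seed(0)
    vocab, seq_len, hidden, classes = 100, 6, 16, 2
    embedding = nn.Embedding(vocab, hidden)
    # Hypothetical toy classifier: flatten all token embeddings, then a linear layer.
    head = nn.Sequential(nn.Flatten(), nn.Linear(seq_len * hidden, classes))
    ids = torch.randint(0, vocab, (3, seq_len))
    labels = torch.tensor([0, 1, 1])
    sal = token_saliency(embedding, head, ids, labels)
    print(sal.shape)  # torch.Size([3, 6])
```

Because the embeddings are detached and `torch.autograd.grad` is used instead of `loss.backward()`, computing saliency this way leaves the `.grad` fields of the model parameters untouched.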

Part of the code is adapted from the MixText implementation.

Getting Started

pip install -r requirements.txt

The code runs on both CPU and GPU, but we highly recommend running it on a GPU. To execute our code successfully without errors, we recommend strictly following the versions specified in the requirements.txt file.

Model Training

python run_train.py --batch_size ${BSZ} --seed ${SEED} --dataset ${DATASET} --optimizer_lr ${LR} ${MODE}

For all our experiments, we use a batch size (BSZ) of 32 and perform five runs, changing the seed (SEED) from 0 to 4. We experiment on a wide range of text classification datasets (DATASET): 'sst2', 'qqp', 'mnli', 'qnli', 'rte', 'mrpc', 'trec-coarse', 'trec-fine', and 'anli'. For the ANLI dataset, you should also set the --anli_round argument to 1, 2, or 3.
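
Running every configuration described above by hand is tedious. The sketch below only builds the argument lists for the template command (batch size 32, seeds 0 to 4, every dataset, and one set of runs per ANLI round); it does not launch anything. It assumes MODE is passed as the trailing token exactly as written in the template and places --anli_round just before it; the function name `build_commands` is hypothetical.

```python
import shlex
from typing import Iterable, List

DATASETS = ["sst2", "qqp", "mnli", "qnli", "rte", "mrpc",
            "trec-coarse", "trec-fine", "anli"]


def build_commands(mode: str, lr: float, batch_size: int = 32,
                   seeds: Iterable[int] = range(5),
                   anli_rounds: Iterable[int] = (1, 2, 3)) -> List[List[str]]:
    """Argument lists for run_train.py, one per (dataset, ANLI round, seed)."""
    seeds = list(seeds)  # materialize so it can be reused for every dataset
    commands = []
    for dataset in DATASETS:
        # ANLI needs an explicit round; other datasets get no extra flag.
        extras = ([["--anli_round", str(r)] for r in anli_rounds]
                  if dataset == "anli" else [[]])
        for extra in extras:
            for seed in seeds:
                commands.append([
                    "python", "run_train.py",
                    "--batch_size", str(batch_size),
                    "--seed", str(seed),
                    "--dataset", dataset,
                    "--optimizer_lr", str(lr),
                    *extra,
                    mode,
                ])
    return commands


if __name__ == "__main__":
    cmds = build_commands("normal", 5e-5)
    print(len(cmds))            # 55
    print(shlex.join(cmds[0]))
```

Each argument list can be passed directly to `subprocess.run(cmd, check=True)`; keeping the arguments as lists avoids shell-quoting issues.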

Once you run the code, trained checkpoints are created under the checkpoints directory. To train a model without mixup, set MODE to 'normal'. To train with a mixup approach, including our SSMix, set MODE to the name of the method ('ssmix', 'tmix', 'embedmix', or 'unk'). Before training with mixup, we load the checkpoint trained without mixup. As the learning rate (LR), we use 5e-5 for the normal mode and 1e-5 for the mixup methods.
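
For the 'unk' mode, the code-structure description of unk.py only says that it outputs a sentence with randomly replaced unk tokens. The sketch below shows one plausible form of such a baseline; the replacement ratio, the `[UNK]` spelling, the protected special tokens, and the function name `unk_replace` are all assumptions, not the repository's settings.

```python
import random
from typing import List, Optional, Sequence


def unk_replace(tokens: Sequence[str], ratio: float = 0.1,
                unk_token: str = "[UNK]",
                protected: Sequence[str] = ("[CLS]", "[SEP]", "[PAD]"),
                rng: Optional[random.Random] = None) -> List[str]:
    """Replace a random `ratio` of the non-special tokens with `unk_token`."""
    rng = rng if rng is not None else random.Random()
    candidates = [i for i, tok in enumerate(tokens) if tok not in protected]
    if not candidates:
        return list(tokens)
    # Replace at least one token, and never more than there are candidates.
    k = min(len(candidates), max(1, round(len(candidates) * ratio)))
    chosen = set(rng.sample(candidates, k))
    return [unk_token if i in chosen else tok for i, tok in enumerate(tokens)]


if __name__ == "__main__":
    sent = ["[CLS]", "the", "movie", "was", "very", "good", "[SEP]"]
    print(unk_replace(sent, ratio=0.4, rng=random.Random(0)))
```

Passing a seeded `random.Random` instance keeps the augmentation reproducible across the five seeded runs without touching the global random state.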

You can modify argument values (e.g., embed_alpha, hidden_alpha) to suit your training hyperparameter needs. For the ablation study of SSMix, you can exclude the saliency constraint (--ss_no_saliency) or the span constraint (--ss_no_span). Run python run_train.py --help or check run_train.py to see the full list of available hyperparameters. For debugging or analysis, you can turn on the verbose options (--verbose and --verbose_show_augment_example).
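
The two ablation switches can be illustrated with a small pure-Python sketch, assuming they mean what their names suggest: without the span constraint, the n least salient tokens of A are replaced one by one with the n most salient tokens of B, wherever they occur; without the saliency constraint, a span of the same length is chosen uniformly at random in each sentence. The function names are hypothetical, n is assumed not to exceed either sentence length, and the exact behavior behind --ss_no_span and --ss_no_saliency should be checked in augmentation/ssmix.py.

```python
import random
from typing import List, Sequence, Tuple


def mix_no_span(tokens_a: Sequence[str], sal_a: Sequence[float],
                tokens_b: Sequence[str], sal_b: Sequence[float],
                n: int) -> Tuple[List[str], float]:
    """Saliency without contiguity: swap individual tokens, not a span."""
    # Positions of A's n least salient and B's n most salient tokens, in order.
    pos_a = sorted(sorted(range(len(tokens_a)), key=lambda i: sal_a[i])[:n])
    pos_b = sorted(sorted(range(len(tokens_b)), key=lambda i: -sal_b[i])[:n])
    mixed = list(tokens_a)
    for ia, ib in zip(pos_a, pos_b):
        mixed[ia] = tokens_b[ib]
    return mixed, 1.0 - n / len(tokens_a)


def mix_no_saliency(tokens_a: Sequence[str], tokens_b: Sequence[str],
                    n: int, rng: random.Random) -> Tuple[List[str], float]:
    """Span without saliency: a random span of B replaces a random span of A."""
    start_a = rng.randrange(len(tokens_a) - n + 1)
    start_b = rng.randrange(len(tokens_b) - n + 1)
    mixed = (list(tokens_a[:start_a]) + list(tokens_b[start_b:start_b + n])
             + list(tokens_a[start_a + n:]))
    return mixed, 1.0 - n / len(tokens_a)


if __name__ == "__main__":
    a = ["the", "movie", "was", "very", "good"]
    b = ["a", "terrible", "plot", "overall"]
    print(mix_no_span(a, [0.1, 0.02, 0.5, 0.05, 0.9], b, [0.1, 0.8, 0.6, 0.1], n=2))
    print(mix_no_saliency(a, b, n=2, rng=random.Random(0)))
```

In the first call the two replaced positions in A are not adjacent, which is exactly the locality the span constraint is meant to preserve.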

License

Copyright 2021-present NAVER Corp.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Owner: Clova AI Research (open source repository of Clova AI Research, NAVER & LINE)