Learned Token Pruning for Transformers

Overview

LTP: Learned Token Pruning for Transformers

Check our paper for more details.

Installation

We follow the same installation procedure as the original Huggingface transformers repository.

pip install sklearn scipy datasets torch
pip install -e .  # in the top directory

Prepare Checkpoints

LTP is implemented on top of Huggingface transformers' I-BERT implementation. Therefore, we first need to generate a checkpoint of ibert finetuned on the target downstream task. While you can do this with the original Huggingface repository, we also provide our base branch ltp/base, where you can run the following code to finetune ibert on the GLUE tasks.

git checkout ltp/base
cd examples/text-classification
python run_glue.py --model_name_or_path kssteven/ibert-roberta-base --output_dir {CKPT} --task {TASK} --do_train --do_eval {--some_more_arguments}
  • {TASK}: RTE, MRPC, STSB, SST2, QNLI, QQP, MNLI
  • Please refer to the Huggingface tutorial and the official documentation for more details on arguments and hyperparameters.
  • Note that by default ibert behaves the same as roberta (see this tutorial), hence the resulting model will be the same as roberta-base finetuned on the target GLUE task.

The final model will be checkpointed in {CKPT}.

  • Remove {CKPT}/trainer_state.json.
  • In the configuration file {CKPT}/config.json, change (1) "architectures" to ["LTPForSequenceClassification"] and (2) "model_type" to "ltp".
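If you prefer to make these edits with a script, the following is a minimal sketch, assuming ckpt points to your {CKPT} directory (the path value is a placeholder):

import json
import os

ckpt = "path/to/ckpt"  # placeholder for {CKPT}

# Remove the stale trainer state.
os.remove(os.path.join(ckpt, "trainer_state.json"))

# Switch the checkpoint over to the LTP model class.
config_path = os.path.join(ckpt, "config.json")
with open(config_path) as f:
    config = json.load(f)
config["architectures"] = ["LTPForSequenceClassification"]
config["model_type"] = "ltp"
with open(config_path, "w") as f:
    json.dump(config, f, indent=2)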

Run Learned Token Pruning

Add the following lines to the configuration file {CKPT}/config.json.

"prune_mode": "absolute_threshold",
"final_token_threshold": 0.01, 
"scoring_mode": "mean",

final_token_threshold determines the token threshold of the last layer, and the thresholds of the remaining layers are scaled linearly. For instance, when final_token_threshold, i.e., the threshold for the last (12th) layer, is set to 0.01, the thresholds for the 3rd, 6th, and 9th layers will be 0.0025, 0.005, and 0.0075, respectively. This value is a hyperparameter, and we found that 0.01 works well in many cases.
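As a sanity check, the linear schedule above can be reproduced with a short sketch (this assumes the threshold of layer l scales as l / num_layers times the final threshold over a 12-layer model, which matches the example values):

# Per-layer thresholds under linear scaling (layer indices are 1-based).
final_token_threshold = 0.01
num_layers = 12

thresholds = [final_token_threshold * layer / num_layers for layer in range(1, num_layers + 1)]
print(thresholds[2], thresholds[5], thresholds[8], thresholds[11])  # ~0.0025, 0.005, 0.0075, 0.01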

The learnable mode consists of 2 stages: soft threshold and hard threshold. Please refer to our paper for more details.

1. Soft Threshold

We first train the model using the soft threshold mode. This trains the thresholds as well as the model parameters to search for the best threshold configuration.

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT} --lr 2e-5 --temperature {T}\
  --lambda 0.1 --weight_decay 0 --bs 64 --masking_mode soft --epoch {epoch} --save_step 100 --no_load
  • {TASK}: RTE, MRPC, STSB, SST2, QNLI, QQP, MNLI
  • You can assign a different learning rate for lr, but 2e-5 worked well.
  • We set {epoch} to 10 for smaller datasets (e.g., RTE, MRPC) and 1 for larger datasets (e.g., SST2, QNLI, MNLI).
  • The --no_load flag skips loading the best model at the end of training (i.e., the final checkpoint will be the one at the end of training).
  • lambda is an important hyperparameter that controls the pruning level: the higher the value, the more tokens we prune. 0.01 ~ 0.2 worked well in many cases, but we recommend empirically searching for the best value for your task (see the sketch after this list).
  • temperature is another hyperparameter, and 1e-3 ~ 1e-5 worked well. In the paper, we searched over {1e-4, 2e-4, 5e-4, 1e-3, 2e-3}.
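Because the best lambda is task-dependent, one convenient way to search for it is to sweep a small grid and compare the resulting accuracy/pruning trade-offs. Below is a minimal sketch of such a sweep; the task, checkpoint path, grid values, and temperature are placeholders, not recommendations:

import subprocess

task = "RTE"           # placeholder task
ckpt = "path/to/ckpt"  # placeholder for {CKPT}

# Launch one soft-threshold run per lambda value; each run checkpoints into its own directory.
for lam in [0.01, 0.05, 0.1, 0.2]:
    subprocess.run([
        "python", "run.py",
        "--arch", "ltp-base", "--task", task, "--restore", ckpt,
        "--lr", "2e-5", "--temperature", "2e-3",
        "--lambda", str(lam), "--weight_decay", "0", "--bs", "64",
        "--masking_mode", "soft", "--epoch", "10",
        "--save_step", "100", "--no_load",
    ], check=True)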

The final model will be checkpointed in {CKPT_soft} = checkpoints/base/{TASK}/absolute_threshold/rate_{final_token_threshold}/temperature_{T}/lambda_{lambda}/lr_{lr}. Remove trainer_state.json from the checkpoint directory {CKPT_soft}.

2. Hard Threshold

Once the thresholds are learned, we fix their values, switch to the hard threshold mode, and finetune the model parameters only.

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT_soft} --lr {LR} --bs 64 --masking_mode hard --epoch 5 
  • We used {LR} {0.5, 1, 2}e-5 in the paper.
  • You can additionally set --save_step 500 for more frequent evaluation/logging. The default setting evaluates once per epoch.

The final model will be checkpointed in {CKPT_soft}/hard/lr_{LR}.

Run Baseline Methods

We additionally provide code to reproduce the baseline methods used in our paper (i.e., top-k and manual threshold).

Top-k Token Pruning

Add the following lines to {CKPT}/config.json.

"prune_mode": "topk",
"token_keep_rate": 0.2,

The token keep rates of the first three layers and the last layer are 1 and token_keep_rate, respectively, and the keep rates of the remaining layers are scaled linearly. The smaller token_keep_rate is, the more aggressively we prune tokens. You can also assign a negative number to token_keep_rate; in that case, the keep rate of each layer will be assigned as max(0, keep_rate).
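A minimal sketch of this schedule follows; it assumes a 12-layer model and linear interpolation from a keep rate of 1 at the third layer down to token_keep_rate at the last layer, i.e., a reconstruction of the description above rather than the exact implementation:

# Hypothetical reconstruction of the per-layer keep-rate schedule (1-based layer indices).
token_keep_rate = 0.2
num_layers = 12

keep_rates = []
for layer in range(1, num_layers + 1):
    if layer <= 3:
        rate = 1.0
    else:
        # Interpolate linearly from 1.0 at layer 3 to token_keep_rate at the last layer.
        rate = 1.0 + (token_keep_rate - 1.0) * (layer - 3) / (num_layers - 3)
    keep_rates.append(max(0.0, rate))  # negative values are clamped to 0
print(keep_rates)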

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT} --lr {LR} --bs 64 --masking_mode hard --epoch 5
  • We used {LR} {0.5, 1, 2}e-5 in the paper.
  • You can additionally set --save_step 500 for more frequent evaluation/logging. The default setting evaluates once per epoch.

The final model will be checkpointed in {CKPT}/topk/lr_{LR}.

Manual (Non-learnable) Threshold Pruning

Add the following lines to {CKPT}/config.json.

"prune_mode": "absolute_threshold",
"final_token_threshold": 0.01, 
"scoring_mode": "mean",

Run the following command:

python run.py --arch ltp-base --task {TASK} --restore {CKPT} --lr {LR} --bs 64 --masking_mode hard --epoch 5 --save_step 500
  • We used {LR} {0.5, 1, 2}e-5 in the paper.
  • You can additionally set --save_step 500 for more frequent evaluation/logging. The default setting evaluates once per epoch.
  • Note that the only difference from the learned token pruning mode is that we run the hard threshold mode from the beginning.

The final model will be checkpointed in {CKPT}/hard/lr_{LR}.
