Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in Pytorch

Overview

COCO LM Pretraining (wip)

Implementation of COCO-LM, Correcting and Contrasting Text Sequences for Language Model Pretraining, in Pytorch. They were able to make contrastive learning work in a self-supervised manner for language model pretraining. Seems like a solid successor to Electra.

Install

$ pip install coco-lm-pytorch

Usage

An example using the x-transformers library

$ pip install x-transformers

Then

import torch
from coco_lm_pytorch import COCO

# (1) instantiate the generator and discriminator, making sure that the generator is roughly a quarter to a half of the size of the discriminator

from x_transformers import TransformerWrapper, Encoder

generator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 256,         # smaller hidden dimension
        heads = 4,         # less heads
        ff_mult = 2,       # smaller feedforward dimension
        depth = 1
    )
)

discriminator = TransformerWrapper(
    num_tokens = 20000,
    emb_dim = 128,
    max_seq_len = 1024,
    attn_layers = Encoder(
        dim = 1024,
        heads = 16,
        ff_mult = 4,
        depth = 12
    )
)

# (2) weight tie the token and positional embeddings of generator and discriminator

generator.token_emb = discriminator.token_emb
generator.pos_emb = discriminator.pos_emb

# weight tie any other embeddings if available, token type embeddings, etc.

# (3) instantiate COCO

trainer = COCO(
    generator,
    discriminator,
    discr_dim = 1024,            # the embedding dimension of the discriminator
    discr_layer = 'norm',        # the layer name in the discriminator, whose output would be used for predicting token is still the same or replaced
    cls_token_id = 1,            # a token id must be reserved for [CLS], which is prepended to the sequence for contrastive learning
    mask_token_id = 2,           # the token id reserved for masking
    pad_token_id = 0,            # the token id for padding
    mask_prob = 0.15,            # masking probability for masked language modeling
    mask_ignore_token_ids = [],  # ids of tokens to ignore for mask modeling ex. (cls, sep)
    cl_weight = 1.,              # weight for the contrastive learning loss
    disc_weight = 1.,            # weight for the corrective learning loss
    gen_weight = 1.              # weight for the MLM loss
)

# (4) train

data = torch.randint(0, 20000, (1, 1024))

loss = trainer(data)
loss.backward()

# after much training, the discriminator should have improved

torch.save(discriminator, f'./pretrained-model.pt')

Citations

@misc{meng2021cocolm,
    title   = {COCO-LM: Correcting and Contrasting Text Sequences for Language Model Pretraining}, 
    author  = {Yu Meng and Chenyan Xiong and Payal Bajaj and Saurabh Tiwary and Paul Bennett and Jiawei Han and Xia Song},
    year    = {2021},
    eprint  = {2102.08473},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
You might also like...
Big Bird: Transformers for Longer Sequences

BigBird, is a sparse-attention based transformer which extends Transformer based models, such as BERT to much longer sequences. Moreover, BigBird comes along with a theoretical understanding of the capabilities of a complete transformer that the sparse model can handle.

Beyond Paragraphs: NLP for Long Sequences

Beyond Paragraphs: NLP for Long Sequences

Text-Summarization-using-NLP - Text Summarization using NLP  to fetch BBC News Article and summarize its text and also it includes custom article Summarization PyTorch implementation of Microsoft's text-to-speech system FastSpeech 2: Fast and High-Quality End-to-End Text to Speech.
PyTorch implementation of Microsoft's text-to-speech system FastSpeech 2: Fast and High-Quality End-to-End Text to Speech.

An implementation of Microsoft's "FastSpeech 2: Fast and High-Quality End-to-End Text to Speech"

MILES is a multilingual text simplifier inspired by LSBert - A BERT-based lexical simplification approach proposed in 2018. Unlike LSBert, MILES uses the bert-base-multilingual-uncased model, as well as simple language-agnostic approaches to complex word identification (CWI) and candidate ranking. PyTorch implementation of the paper:  Text is no more Enough! A Benchmark for Profile-based Spoken Language Understanding
PyTorch implementation of the paper: Text is no more Enough! A Benchmark for Profile-based Spoken Language Understanding

Text is no more Enough! A Benchmark for Profile-based Spoken Language Understanding This repository contains the official PyTorch implementation of th

Integrating the Best of TF into PyTorch, for Machine Learning, Natural Language Processing, and Text Generation.  This is part of the CASL project: http://casl-project.ai/
Integrating the Best of TF into PyTorch, for Machine Learning, Natural Language Processing, and Text Generation. This is part of the CASL project: http://casl-project.ai/

Texar-PyTorch is a toolkit aiming to support a broad set of machine learning, especially natural language processing and text generation tasks. Texar

In this repository, I have developed an end to end Automatic speech recognition project. I have developed the neural network model for automatic speech recognition with PyTorch and used MLflow to manage the ML lifecycle, including experimentation, reproducibility, deployment, and a central model registry.
Official PyTorch code for ClipBERT, an efficient framework for end-to-end learning on image-text and video-text tasks

Official PyTorch code for ClipBERT, an efficient framework for end-to-end learning on image-text and video-text tasks. It takes raw videos/images + text as inputs, and outputs task predictions. ClipBERT is designed based on 2D CNNs and transformers, and uses a sparse sampling strategy to enable efficient end-to-end video-and-language learning.

Comments
  • Question about corrective LM loss

    Question about corrective LM loss

    Hi @lucidrains ,

    Thanks for your great repo!

    I looked at your code: coco_lm_pytorch.py. I see there are three losses in line 242. weighted_loss = self.cl_weight * cl_loss + self.gen_weight * mlm_loss + self.disc_weight * disc_loss

    cl_loss is the contrastive loss, mlm_loss is the loss of the auxiliary generator, and disc_loss is the loss of binary discrimination. I wonder where the LM loss of corrective language modeling loss is. Could you point me?

    Best, Abdul.

    opened by elmadany 0
  • What can v0.0.2 do?

    What can v0.0.2 do?

    I'm quite excited to give COCO-LM a try! Thanks as always for the great speedy open source repo @lucidrains .

    Quick question: has this repository been tried on real data, and if so - loosely what type of setup? Trying to figure out whether jumping in coco-lm-pytorch I should have the expectation of being a first beta-tester, or I'm looking at something that is already stable. Thanks!

    opened by dginev 0
Releases(0.0.2)
Owner
Phil Wang
Working with Attention. It's all we need.
Phil Wang
Snips Python library to extract meaning from text

Snips NLU Snips NLU (Natural Language Understanding) is a Python library that allows to extract structured information from sentences written in natur

Snips 3.7k Dec 30, 2022
The simple project to separate mixed voice (2 clean voices) to 2 separate voices.

Speech Separation The simple project to separate mixed voice (2 clean voices) to 2 separate voices. Result Example (Clisk to hear the voices): mix ||

vuthede 31 Oct 30, 2022
Curso práctico: NLP de cero a cien 🤗

Curso Práctico: NLP de cero a cien Comprende todos los conceptos y arquitecturas clave del estado del arte del NLP y aplícalos a casos prácticos utili

Somos NLP 147 Jan 06, 2023
vits chinese, tts chinese, tts mandarin

vits chinese, tts chinese, tts mandarin 史上训练最简单,音质最好的语音合成系统

AmorTX 12 Dec 14, 2022
Residual2Vec: Debiasing graph embedding using random graphs

Residual2Vec: Debiasing graph embedding using random graphs This repository contains the code for S. Kojaku, J. Yoon, I. Constantino, and Y.-Y. Ahn, R

SADAMORI KOJAKU 5 Oct 12, 2022
Textlesslib - Library for Textless Spoken Language Processing

textlesslib Textless NLP is an active area of research that aims to extend NLP t

Meta Research 379 Dec 27, 2022
Pre-Training with Whole Word Masking for Chinese BERT

Pre-Training with Whole Word Masking for Chinese BERT

Yiming Cui 7.7k Dec 31, 2022
Reproducing the Linear Multihead Attention introduced in Linformer paper (Linformer: Self-Attention with Linear Complexity)

Linear Multihead Attention (Linformer) PyTorch Implementation of reproducing the Linear Multihead Attention introduced in Linformer paper (Linformer:

Kui Xu 58 Dec 23, 2022
DeepSpeech - Easy-to-use Speech Toolkit including SOTA ASR pipeline, influential TTS with text frontend and End-to-End Speech Simultaneous Translation.

(简体中文|English) Quick Start | Documents | Models List PaddleSpeech is an open-source toolkit on PaddlePaddle platform for a variety of critical tasks i

5.6k Jan 03, 2023
TensorFlow code and pre-trained models for BERT

BERT ***** New March 11th, 2020: Smaller BERT Models ***** This is a release of 24 smaller BERT models (English only, uncased, trained with WordPiece

Google Research 32.9k Jan 08, 2023
chaii - hindi & tamil question answering

chaii - hindi & tamil question answering This is the solution for rank 5th in Kaggle competition: chaii - Hindi and Tamil Question Answering. The comp

abhishek thakur 33 Dec 18, 2022
The aim of this task is to predict someone's English proficiency based on a text input.

English_proficiency_prediction_NLP The aim of this task is to predict someone's English proficiency based on a text input. Using the The NICT JLE Corp

1 Dec 13, 2021
原神抽卡记录数据集-Genshin Impact gacha data

提要 持续收集原神抽卡记录中 可以使用抽卡记录导出工具导出抽卡记录的json,将json文件发送至[email protected],我会在清除个人信息后

117 Dec 27, 2022
Recognition of 38 speech commands in russian. Based on Yandex Cup 2021 ML Challenge: ASR

Speech_38_ru_commands Recognition of 38 speech commands in russian. Based on Yandex Cup 2021 ML Challenge: ASR Программа умеет распознавать 38 ключевы

Andrey 9 May 05, 2022
[EMNLP 2021] LM-Critic: Language Models for Unsupervised Grammatical Error Correction

LM-Critic: Language Models for Unsupervised Grammatical Error Correction This repo provides the source code & data of our paper: LM-Critic: Language M

Michihiro Yasunaga 98 Nov 24, 2022
Prompt tuning toolkit for GPT-2 and GPT-Neo

mkultra mkultra is a prompt tuning toolkit for GPT-2 and GPT-Neo. Prompt tuning injects a string of 20-100 special tokens into the context in order to

61 Jan 01, 2023
Unsupervised text tokenizer for Neural Network-based text generation.

SentencePiece SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation systems where the vocabu

Google 6.4k Jan 01, 2023
Sequence modeling benchmarks and temporal convolutional networks

Sequence Modeling Benchmarks and Temporal Convolutional Networks (TCN) This repository contains the experiments done in the work An Empirical Evaluati

CMU Locus Lab 3.5k Jan 03, 2023
Framework for fine-tuning pretrained transformers for Named-Entity Recognition (NER) tasks

NERDA Not only is NERDA a mesmerizing muppet-like character. NERDA is also a python package, that offers a slick easy-to-use interface for fine-tuning

Ekstra Bladet 141 Dec 30, 2022
Web mining module for Python, with tools for scraping, natural language processing, machine learning, network analysis and visualization.

Pattern Pattern is a web mining module for Python. It has tools for: Data Mining: web services (Google, Twitter, Wikipedia), web crawler, HTML DOM par

Computational Linguistics Research Group 8.4k Dec 30, 2022