
RevLib

Simple and efficient RevNet-Library with DeepSpeed support

Features

  • Half the constant memory usage of, and faster than, other RevNet libraries
  • Less memory than gradient checkpointing (1 * output_size instead of n_layers * output_size)
  • Same speed as activation checkpointing
  • Extensible
  • Trivial code (<100 lines)

Getting started

Installation

python3 -m pip install revlib
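
To get a feel for the API, here is a minimal sketch (the sizes below are illustrative assumptions): ReversibleSequential splits its input into two streams, so it expects twice the channels of each wrapped module.

import torch
from torch import nn
import revlib

channels = 32

# Each wrapped module maps one stream to a tensor of the same shape.
blocks = [nn.Conv2d(channels, channels, (3, 3), padding=1) for _ in range(8)]

# The input is split into two streams along the channel dimension,
# so the model takes 2 * channels as input and returns 2 * channels.
model = revlib.ReversibleSequential(*blocks)

inp = torch.randn((2, channels * 2, 32, 32))
out = model(inp)
out.mean().backward()
assert out.size() == inp.size()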

Examples

iRevNet

iRevNet is not just partially reversible but a fully invertible model. Its reference source code looks complex at first glance and doesn't exploit all the memory savings it could, since RevNet requires custom autograd functions that are hard to maintain. Using revlib, an iRevNet can be implemented like this:

import torch
from torch import nn
import revlib

channels = 64
channel_multiplier = 4
depth = 3
classes = 1000


# Create a basic function that's executed reversibly multiple times (like f() in a ResNet)
def conv(in_channels, out_channels):
    return nn.Conv2d(in_channels, out_channels, (3, 3), padding=1)


def block_conv(in_channels, out_channels):
    return nn.Sequential(conv(in_channels, out_channels),
                         nn.Dropout(0.2),
                         nn.BatchNorm2d(out_channels),
                         nn.ReLU())


def block():
    return nn.Sequential(block_conv(channels, channels * channel_multiplier),
                         block_conv(channels * channel_multiplier, channels),
                         nn.Conv2d(channels, channels, (3, 3), padding=1))


# Create a reversible model. f() is invoked depth times with different weights.
rev_model = revlib.ReversibleSequential(*[block() for _ in range(depth)])

# Wrap reversible model with non-reversible layers
model = nn.Sequential(conv(3, 2 * channels), rev_model, conv(2 * channels, classes))

# Use it like you would a regular PyTorch model
inp = torch.randn((1, 3, 224, 224))
out = model(inp)
out.mean().backward()
assert out.size() == (1, 1000, 224, 224)
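
Note that the input convolution produces 2 * channels feature maps: ReversibleSequential splits its input into two streams along the channel dimension, so each block operates on channels features while the surrounding non-reversible layers see 2 * channels.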

MomentumNet

MomentumNet is another recent paper that made significant advances in memory-efficient networks. Instead of a second model output, it propagates a momentum stream (see the illustration in the MomentumNet paper). Implementing this with revlib requires a custom coupling operation (the functional analogue of MemCNN) that merges the input and output streams.

import torch
from torch import nn
import revlib

channels = 64
depth = 16
momentum_ema_beta = 0.99


# Compute y2 from x2 and f(x1) by merging x2 and f(x1) in the forward pass.
def momentum_coupling_forward(other_stream: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
    return other_stream * momentum_ema_beta + fn_out * (1 - momentum_ema_beta)


# Calculate x2 from y2 and f(x1) by manually computing the inverse of momentum_coupling_forward.
def momentum_coupling_inverse(output: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
    return (output - fn_out * (1 - momentum_ema_beta)) / momentum_ema_beta


# Pass in coupling functions which will be used instead of x2 + f(x1) and y2 - f(x1)
rev_model = revlib.ReversibleSequential(*[layer for _ in range(depth)
                                          for layer in [nn.Conv2d(channels, channels, (3, 3), padding=1),
                                                        nn.Identity()]],
                                        coupling_forward=[momentum_coupling_forward, revlib.additive_coupling_forward],
                                        coupling_inverse=[momentum_coupling_inverse, revlib.additive_coupling_inverse])

inp = torch.randn((16, channels * 2, 224, 224))
out = rev_model(inp)
assert out.size() == (16, channels * 2, 224, 224)
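
In effect, each pair of layers performs one momentum step: the convolution, combined with the momentum coupling, updates the velocity stream as an exponential moving average of f(x1), and the following nn.Identity with the default additive coupling adds that velocity back onto the feature stream.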

Reformer

Reformer uses RevNet together with chunking and LSH attention to train a transformer efficiently. With revlib, standard implementations, such as lucidrains' Reformer, can be improved upon to use even less memory. The example below still uses the basic building blocks from lucidrains' code so that the models stay comparable.

import torch
from torch import nn
from reformer_pytorch.reformer_pytorch import LSHSelfAttention, Chunk, FeedForward, AbsolutePositionalEmbedding
import revlib


class Reformer(torch.nn.Module):
    def __init__(self, sequence_length: int, features: int, depth: int, heads: int, bucket_size: int = 64,
                 lsh_hash_count: int = 8, ff_chunks: int = 16, input_classes: int = 256, output_classes: int = 256):
        super(Reformer, self).__init__()
        self.token_embd = nn.Embedding(input_classes, features * 2)
        self.pos_embd = AbsolutePositionalEmbedding(features * 2, sequence_length)

        self.core = revlib.ReversibleSequential(*[nn.Sequential(nn.LayerNorm(features), layer)
                                                  for _ in range(depth)
                                                  for layer in
                                                  [LSHSelfAttention(features, heads, bucket_size, lsh_hash_count),
                                                   Chunk(ff_chunks, FeedForward(features, activation=nn.GELU),
                                                         along_dim=-2)]],
                                                split_dim=-1)
        self.out_norm = nn.LayerNorm(features * 2)
        self.out_linear = nn.Linear(features * 2, output_classes)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        return self.out_linear(self.out_norm(self.core(self.token_embd(inp) + self.pos_embd(inp))))


sequence = 1024
classes = 16
model = Reformer(sequence, 256, 6, 8, output_classes=classes)
out = model(torch.ones((16, sequence), dtype=torch.long))
assert out.size() == (16, sequence, classes)
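
Note that the embeddings have features * 2 dimensions: with split_dim=-1, ReversibleSequential splits the last dimension into two streams of features each, which is what LSHSelfAttention and FeedForward operate on, while out_norm and out_linear see the re-concatenated features * 2.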

Explanation

Most other RevNet libraries, such as MemCNN and RevTorch, calculate both f() and g() in one go, creating one large computation graph. RevLib, on the other hand, brings Mesh TensorFlow's "reversible half residual and swap" to PyTorch: reversible_half_residual_and_swap computes only one of f() and g() at a time and swaps the inputs and gradients. This way, the library only has to store one output, since it can recover the other output during the backward pass.
Following Mesh TensorFlow's example, revlib also keeps separate x1 and x2 tensors instead of concatenating and splitting at every step, reducing the cost of memory-bound operations.
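
Conceptually, one such half-step looks like the sketch below. This is a plain-Python illustration of additive coupling, not revlib's actual internals, which implement the same idea as a custom autograd function:

import torch

def half_residual_and_swap(x1, x2, f):
    # Couple f(x1) into x2, then swap the streams. x2 never has to be
    # stored: the backward pass can recover it as y1 - f(y2).
    y1 = x2 + f(x1)
    y2 = x1
    return y1, y2

def inverse(y1, y2, f):
    # Recover both inputs from the outputs alone.
    x1 = y2
    x2 = y1 - f(y2)
    return x1, x2

f = torch.nn.Linear(16, 16)
x1, x2 = torch.randn(4, 16), torch.randn(4, 16)
y1, y2 = half_residual_and_swap(x1, x2, f)
r1, r2 = inverse(y1, y2, f)
assert torch.allclose(x1, r1) and torch.allclose(x2, r2)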

RevNet's memory consumption doesn't scale with its depth, so it's significantly more memory-efficient for deep models. One problem in most implementations is that two tensors need to be stored in the output, quadrupling the required memory. This high memory consumption rendered RevNet nearly useless for small networks, such as a six-layer BERT.
RevLib works around this problem by storing only one output and two inputs for each forward pass, giving a model as small as BERT a >2x improvement!
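
As a back-of-envelope illustration (hypothetical sizes, purely for illustration), the stored activations grow linearly with depth under gradient checkpointing but stay constant with revlib:

# Hypothetical activation sizes, not measurements.
batch, channels, height, width = 16, 128, 224, 224
n_layers = 12
bytes_per_element = 4  # fp32

output_size = batch * channels * height * width * bytes_per_element
checkpointing = n_layers * output_size  # one stored output per layer
revlib_storage = 1 * output_size        # a single stored output

print(f"checkpointing: {checkpointing / 2 ** 30:.2f} GiB")
print(f"revlib:        {revlib_storage / 2 ** 30:.2f} GiB")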

Even ignoring the dual-path structure of a RevNet, reversible networks used to be much slower than gradient checkpointing. RevLib, however, uses minimal coupling functions and has no overhead between sequence items, allowing it to train as fast as a comparable model with gradient checkpointing.
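
One rough way to sanity-check the speed claim is to time comparable stacks side by side. A sketch with arbitrary sizes (the numbers depend heavily on hardware, and the two models are only roughly comparable):

import time

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential
import revlib

channels, depth = 64, 16

def make_blocks():
    return [nn.Conv2d(channels, channels, (3, 3), padding=1) for _ in range(depth)]

rev = revlib.ReversibleSequential(*make_blocks())
ckpt = nn.Sequential(*make_blocks())

def bench(run, steps: int = 5) -> float:
    # Average seconds per forward + backward pass.
    start = time.time()
    for _ in range(steps):
        run().mean().backward()
    return (time.time() - start) / steps

inp_rev = torch.randn((4, channels * 2, 64, 64), requires_grad=True)
inp_ckpt = torch.randn((4, channels, 64, 64), requires_grad=True)
print("revlib:       ", bench(lambda: rev(inp_rev)))
print("checkpointing:", bench(lambda: checkpoint_sequential(ckpt, 4, inp_ckpt)))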

Owner

Lucas Nestler, German AI researcher