Usable implementation of "Bootstrap Your Own Latent" self-supervised learning, from DeepMind, in PyTorch


Bootstrap Your Own Latent (BYOL), in Pytorch


Practical implementation of an astoundingly simple method for self-supervised learning that achieves a new state of the art (surpassing SimCLR) without contrastive learning and without having to designate negative pairs.

This repository offers a module that can easily wrap any image-based neural network (residual network, discriminator, policy network) so that it can immediately start benefiting from unlabelled image data.

Update 1: There is now new evidence that batch normalization is key to making this technique work well

Update 2: A new paper has successfully replaced batch norm with group norm + weight standardization, refuting the claim that batch statistics are needed for BYOL to work

Update 3: Finally, we have some analysis for why this works

Yannic Kilcher's excellent explanation

Now go save your organization from having to pay for labels :)

Install

$ pip install byol-pytorch

Usage

Simply plug in your neural network, specifying (1) the image dimensions and (2) the name (or index) of the hidden layer whose output is used as the latent representation for self-supervised training.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')

That's pretty much it. After much training, the residual network should now perform better on its downstream supervised tasks.
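The update_moving_average call in the loop above keeps the target encoder as an exponential moving average of the online encoder. A minimal sketch of that update rule follows, with plain Python floats standing in for parameter tensors. The helper name ema_update is hypothetical and is not the library's own function; the decay of 0.5 is chosen only to make the numbers easy to follow.

```python
def ema_update(target, online, decay=0.99):
    """Return new target values: decay * target + (1 - decay) * online."""
    return [decay * t + (1.0 - decay) * o for t, o in zip(target, online)]

# Toy "parameters": the online values stay fixed, the target drifts toward them.
online = [1.0, 2.0]
target = [0.0, 0.0]

for _ in range(3):
    target = ema_update(target, online, decay=0.5)

print(target)
```

With a decay close to 1 the target moves slowly, which is why it lags behind the online encoder during training.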

BYOL → SimSiam

A new paper from Kaiming He suggests that BYOL does not even need the target encoder to be an exponential moving average of the online encoder. I've decided to build in this option so that you can easily use that variant for training, simply by setting the use_momentum flag to False. You will no longer need to invoke update_moving_average if you go this route as shown in the example below.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    use_momentum = False       # turn off momentum in the target encoder
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')

Advanced

While the hyperparameters have already been set to what the paper has found optimal, you can change them with extra keyword arguments to the base wrapper class.

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    projection_size = 256,           # the projection size
    projection_hidden_size = 4096,   # the hidden dimension of the MLP for both the projection and prediction
    moving_average_decay = 0.99      # the moving average decay factor for the target encoder, already set at what paper recommends
)
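The moving_average_decay above is a single fixed value. One of the issues listed further down quotes the paper as increasing the decay from its base value to 1 over training, using τ = 1 − (1 − τbase) · (cos(πk/K) + 1)/2. A minimal sketch of that schedule follows; the function name cosine_tau and its argument names are assumptions made for illustration, and how the resulting value is handed to the learner depends on the library version and is not shown here.

```python
import math

def cosine_tau(step, total_steps, base_tau=0.99):
    """Decay at training step `step` (k) out of `total_steps` (K).

    Implements tau = 1 - (1 - base_tau) * (cos(pi * k / K) + 1) / 2,
    which starts at base_tau and rises to 1 at the final step.
    """
    return 1.0 - (1.0 - base_tau) * (math.cos(math.pi * step / total_steps) + 1.0) / 2.0

total_steps = 1000
for k in (0, 500, 1000):
    print(k, round(cosine_tau(k, total_steps), 4))
```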

By default, this library uses the augmentations from the SimCLR paper (which are also used in the BYOL paper). However, if you would like to specify your own augmentation pipeline, you can pass in your own custom augmentation function with the augment_fn keyword.

import kornia
from torch import nn

# custom augmentation pipeline applied to each batch of images
augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn
)

In the paper, they seem to ensure that one of the augmentations has a higher Gaussian blur probability than the other. You can also adjust this to your heart's delight.

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

augment_fn2 = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip(),
    kornia.filters.GaussianBlur2d((3, 3), (1.5, 1.5))
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn,
    augment_fn2 = augment_fn2,
)
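As written, augment_fn2 appears to apply the blur on every call, whereas the paragraph above talks about a blur probability. One way to get a probability is to wrap the transform so that it only fires some of the time. The sketch below shows that idea in plain Python; the class RandomApply, the stand-in fake_blur, and the values of p are all illustrative assumptions, not settings taken from the paper or from this library.

```python
import random

class RandomApply:
    """Call fn on the input with probability p, otherwise return it unchanged."""

    def __init__(self, fn, p, rng=None):
        self.fn = fn
        self.p = p
        self.rng = rng if rng is not None else random.Random()

    def __call__(self, x):
        if self.rng.random() < self.p:
            return self.fn(x)
        return x

def fake_blur(x):
    # Stand-in for a real blur op; halves every value so the effect is visible.
    return [0.5 * v for v in x]

# Illustrative probabilities only: one view is blurred far more often than the other.
view1_blur = RandomApply(fake_blur, p=1.0)
view2_blur = RandomApply(fake_blur, p=0.1)

print(view1_blur([2.0, 4.0]))
```

To use the same logic inside nn.Sequential, the wrapper would need to be an nn.Module with this check in its forward method.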

To fetch the embeddings or the projections, simply pass a return_embedding = True flag when calling the BYOL learner instance.

import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

imgs = torch.randn(2, 3, 256, 256)
projection, embedding = learner(imgs, return_embedding = True)

Alternatives

If your downstream task involves segmentation, please look at the following repository, which extends BYOL to 'pixel'-level learning.

https://github.com/lucidrains/pixel-level-contrastive-learning

Citation

@misc{grill2020bootstrap,
    title = {Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning},
    author = {Jean-Bastien Grill and Florian Strub and Florent Altché and Corentin Tallec and Pierre H. Richemond and Elena Buchatskaya and Carl Doersch and Bernardo Avila Pires and Zhaohan Daniel Guo and Mohammad Gheshlaghi Azar and Bilal Piot and Koray Kavukcuoglu and Rémi Munos and Michal Valko},
    year = {2020},
    eprint = {2006.07733},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{chen2020exploring,
    title={Exploring Simple Siamese Representation Learning}, 
    author={Xinlei Chen and Kaiming He},
    year={2020},
    eprint={2011.10566},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
Comments
  • Negative Loss, Transfer Learning/Fine-Tuning Question

    Hi! Thanks for sharing this repo -- really clean and easy to use.

    When training using the PyTorch Lightning script from the repo, my loss is negative (and gets more negative over time). Is this expected? [screenshot]


    I'm curious to know if you've fine-tuned a pretrained model using this BYOL as the README example suggested. If yes, how were the results? Any intuition regarding how many epochs to fine-tune for?

    Thanks!

    opened by rsomani95 13
  • AssertionError: hidden layer never emitted an output with multi-gpu training

    I tried your library with a WideResnet40-2 model and used layer_index=-2.

    The Lightning example works fine on a single GPU, but I get the error with multiple GPUs.

    opened by reactivetype 7
  • How to transfer the trained ckpt to pytorch.pth model?

    I used the example script to train a model and got a ckpt file. How can I extract the trained resnet50.pth instead of the whole SelfSupervisedLearner? Sorry, I am new to the PyTorch Lightning library. What I want is the self-supervised resnet50.pth, because I want it to replace the original ImageNet-pretrained one. Thanks a lot.

    opened by knaffe 5
  • Training loss decreased and then increased

    Hi, I used your example on my own data. The training loss decreased and then increased after 100 epochs, which is weird. Have you run into similar situations? Is the model hard to train? The batch size is 128/256, lr is 0.1/0.2, and weight_decay is 1e-6.

    opened by easonyang1996 4
  • Can't load ckpt

    I used byol-pytorch-master/examples/lightning/train.py to generate a ckpt locally after training, but when I load the ckpt, I get the following errors. How should I load it? Thanks a lot! [screenshot of the errors]

    opened by AndrewTal 4
  • BYOL uses different augmentations for view1 and view2


    opened by OlivierDehaene 4
  • Transferring results on Cifar and other datasets

    Thanks for open-sourcing this!

    I notice that BYOL shows a large gap on downstream transfer datasets: e.g., SimCLR reaches 71.6% on CIFAR-100, while BYOL reaches 78.4%.

    I understand that this might depend on the downstream training protocol. Could you provide sample code for that, especially for the L-BFGS-optimized logistic regressor?

    opened by jacobswan1 4
  • The saved network is same as the initial one?


    Firstly, thank you so much for this clean implementation!!

    The self-supervised training process looks good, but the saved (i.e. improved) model is exactly the same as the initial one on my side. Have you observed the same problem?

    The code I tested:

    import torch
    from net.byol import BYOL
    from torchvision import models
     
           
    resnet = models.resnet50(pretrained=True)
    param_1 = resnet.parameters()
    
    learner = BYOL(
        resnet,
        image_size = 256,
        hidden_layer = 'avgpool'
    )
    
    opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
    
    def sample_unlabelled_images():
        return torch.randn(20, 3, 256, 256)
    
    for _ in range(2):
        images = sample_unlabelled_images()
        loss = learner(images)
        opt.zero_grad()
        loss.backward()
        opt.step()
        learner.update_moving_average() # update moving average of target encoder
    
    # save your improved network
    torch.save(resnet.state_dict(), './checkpoints/improved-net.pt')
    
    # restore the model      
    resnet2 = models.resnet50()
    resnet2.load_state_dict(torch.load('./checkpoints/improved-net.pt'))
    param_2 = resnet2.parameters()
    
    # test whether two models are the same 
    for p1, p2 in zip(param_1, param_2):
        if p1.data.ne(p2.data).sum() > 0:
            print('They are different.')
    print('They are same.')
    
    opened by KimMeen 3
  • the maximum batch size can only be set to 32


    When I run the code with a 2080ti GPU with 10G memory, the maximum batch size can only be set to 32. Is there any place in the code that takes up a lot of video memory?

    opened by cuixianheng 3
  • Pretrained network


    Hi, thanks for sharing the code and making it so easy to use. I see in the example you set resnet = models.resnet50(pretrained=True). Is this what is done in the paper? Shouldn't self-supervised-learned networks be trained from scratch?

    Thanks again, P.

    opened by pmorerio 3
  • Singleton Class Members


    Forgive me for my unfamiliarity with software design, but I'm wondering why it is necessary to write a singleton wrapper for projector and target_encoder. Is there any disadvantage of initializing them in __init__?

    opened by wentaoyuan 3
  • Increase EMA-parameter during training


    Hi, I noticed that the EMA-parameter (called beta in the code, τ in the paper) is not updated during training. In the paper they describe that they increase τ from the start value to 1 during training: "Specifically, we set τ = 1 − (1 − τbase) · (cos(πk/K) + 1)/2 with k the current training step and K the maximum number of training steps." This makes a huge difference to the validation loss at the end of the training.

    [Figures: without_tau_update, with_tau_update]

    opened by Benjamin-Hansson 1
  • Why the loss is different from BYOL authors'

    I found that the loss differs from the one described in the BYOL paper, which should be an L2 loss, and I didn't find an explanation. The loss in this repo is a cosine loss, and I just want to know why. BTW, thanks for this great repo!

    opened by Jing-XING 2
  • How to cluster/predict images?

    Hi, I have trained using the examples given with pytorch-lightning. I couldn't find code to cluster images after training. How can I find which image falls in which cluster? Is there a predictor API? I want to do something like this: [image]

    opened by laxmimerit 1
  • BN layer weights and biases are not updated


    Thanks for sharing this repo, great work!

    I trained BYOL on my data and noticed that the weights and biases for BN layers are not updated on the saved model. I used resnet18 without pretrained weights resnet = models.resnet50(pretrained=False). After training for multiple epochs, the saved model has bn1.weight all equal to 1.0 and bn1.bias all equal to 0.0 .

    Is this the expected behavior or am I missing something? Appreciate your response!

    opened by kregmi 1
  • Warning: grad and param do not obey the gradient layout contract.

    Has anybody gotten a similar warning when using it?

    Warning: grad and param do not obey the gradient layout contract. This is not an error, but may impair performance. grad.sizes() = [512, 256, 1, 1], strides() = [256, 1, 1, 1] param.sizes() = [512, 256, 1, 1], strides() = [256, 1, 256, 256] (function operator())

    opened by mohaEs 3
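Two of the questions above concern the loss: its sign, and whether a cosine loss differs from the L2 loss in the paper. For unit-normalized vectors the two are the same quantity, since ||p − z||² = 2 − 2·cos(p, z), which also means this form cannot go below zero. The sketch below checks that identity numerically in plain Python; the function names are illustrative, and it makes no claim about how any particular version of the library computes its loss.

```python
import math

def normalize(v):
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def cosine_form(p, z):
    """2 - 2 * cosine similarity of p and z."""
    p, z = normalize(p), normalize(z)
    return 2.0 - 2.0 * sum(a * b for a, b in zip(p, z))

def l2_form(p, z):
    """Squared L2 distance between the unit-normalized p and z."""
    p, z = normalize(p), normalize(z)
    return sum((a - b) ** 2 for a, b in zip(p, z))

p = [1.0, 2.0, 3.0]
z = [-2.0, 0.5, 1.0]
print(cosine_form(p, z), l2_form(p, z))
```

Both forms range from 0 (identical directions) to 4 (opposite directions).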
Releases(0.6.0)
Owner
Phil Wang
Working with Attention. It's all we need