torchbearer: A model fitting library for PyTorch

Overview

Note: We're moving to PyTorch Lightning! Read about the move here. From the end of February, torchbearer will no longer be actively maintained. We'll continue to fix bugs when they are found and ensure that torchbearer runs on new versions of PyTorch. However, we won't plan or implement any new functionality (if there's something you'd like to see in a training library, consider creating an issue on PyTorch Lightning).


Supports Python 2.7, 3.5, 3.6 and 3.7, and PyTorch 1.0.0 through 1.4.0.

Website | Docs | Examples | Install | Citing | Related

A PyTorch model fitting library designed for use by researchers (or anyone really) working in deep learning or differentiable programming. Specifically, we aim to dramatically reduce the amount of boilerplate code you need to write without limiting the functionality and openness of PyTorch.
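
For a flavour of the API, here is a minimal sketch of fitting a toy model with a Trial (the model, data and hyperparameters are placeholders, not taken from the examples below):

    import torch
    from torchbearer import Trial

    model = torch.nn.Linear(10, 2)
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss = torch.nn.CrossEntropyLoss()
    x, y = torch.rand(256, 10), torch.randint(0, 2, (256,))

    trial = Trial(model, optimiser, loss, metrics=['acc', 'loss'])
    trial.with_train_data(x, y, batch_size=32)
    history = trial.run(epochs=5)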

Examples

General

Quickstart: Get up and running with torchbearer, training a simple CNN on CIFAR-10.
Callbacks: A detailed exploration of callbacks in torchbearer, with some useful visualisations.
Imaging: A detailed exploration of the imaging sub-package in torchbearer, useful for showing visualisations during training.
Serialization: This guide gives an introduction to serializing and restarting training in torchbearer.
History and Replay: This guide gives an introduction to the history returned by a trial and the ability to replay training.
Custom Data Loaders: This guide gives an introduction on how to run custom data loaders in torchbearer.
Data Parallel: This guide gives an introduction to using torchbearer with DataParallel.
LiveLossPlot: A demonstration of the LiveLossPlot callback included in torchbearer.
PyCM: A demonstration of the PyCM callback included in torchbearer for generating confusion matrices.
NVIDIA Apex: A guide showing how to perform half and mixed precision training in torchbearer with NVIDIA Apex.

Deep Learning

Training a VAE: A demonstration of how to train (and do a simple visualisation of) a Variational Auto-Encoder (VAE) on MNIST with torchbearer.
Training a GAN: A demonstration of how to train (and do a simple visualisation of) a Generative Adversarial Network (GAN) on MNIST with torchbearer.
Generating Adversarial Examples: A demonstration of how to perform a simple adversarial attack with torchbearer.
Transfer Learning with Torchbearer: A demonstration of how to perform transfer learning on STL10 with torchbearer.
Regularisers in Torchbearer: A demonstration of how to use all of the built-in regularisers in torchbearer (Mixup, CutOut, CutMix, Random Erase, Label Smoothing and Sample Pairing).
Manifold Mixup: A demonstration of how to use the Manifold Mixup callback in Torchbearer.
Class Appearance Model: A demonstration of the Class Appearance Model (CAM) callback in torchbearer.

Differentiable Programming

Optimising Functions: An example (and some fun visualisations) showing how torchbearer can be used to optimise functions with respect to their parameters using gradient descent (a rough sketch of the pattern follows this list).
Linear SVM: Train a linear support vector machine (SVM) using torchbearer, with an interactive visualisation!
Breaking Adam: The Adam optimiser doesn't always converge; in this example we reimplement some of the function optimisations from the AMSGrad paper, showing this empirically.
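
As a taste of the pattern these examples use, here is a rough sketch of minimising a simple quadratic with a Trial and no data (the exact behaviour of data-free runs is assumed here; the full, tested version is in the Optimising Functions example above):

    import torch
    from torchbearer import Trial

    class Quadratic(torch.nn.Module):
        # Hold the point being optimised as a parameter; forward returns the function value
        def __init__(self):
            super(Quadratic, self).__init__()
            self.p = torch.nn.Parameter(torch.zeros(2))

        def forward(self, _):
            target = torch.tensor([3.0, -2.0])
            return ((self.p - target) ** 2).sum()

    def value_loss(y_pred, y_true):
        return y_pred  # the 'loss' is just the function value itself

    model = Quadratic()
    optimiser = torch.optim.SGD(model.parameters(), lr=0.1)
    trial = Trial(model, optimiser, value_loss, metrics=['loss']).for_train_steps(200)
    trial.run()
    print(model.p)  # should approach [3.0, -2.0]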

Install

The easiest way to install torchbearer is with pip:

pip install torchbearer

Alternatively, build from source with:

pip install git+https://github.com/pytorchbearer/torchbearer
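
You can then sanity-check the install with (recent releases expose a version string):

    python -c "import torchbearer; print(torchbearer.__version__)"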

Citing Torchbearer

If you find that torchbearer is useful to your research then please consider citing our preprint: Torchbearer: A Model Fitting Library for PyTorch, with the following BibTeX entry:

@article{torchbearer2018,
  author = {Ethan Harris and Matthew Painter and Jonathon Hare},
  title = {Torchbearer: A Model Fitting Library for PyTorch},
  journal  = {arXiv preprint arXiv:1809.03363},
  year = {2018}
}

Related

Torchbearer isn't the only library for training PyTorch models. Here are a few others that might better suit your needs (this is by no means a complete list, see the awesome pytorch list or the incredible pytorch for more):

  • skorch, a model wrapper that enables use with scikit-learn (cross-validation etc. can be very useful)
  • PyToune, a simple Keras-style API
  • ignite, advanced model training from the makers of PyTorch; can need a lot of code for advanced functions (e.g. TensorBoard)
  • TorchNetTwo (TNT), can be complex to use but well established; somewhat replaced by ignite
  • Inferno, training utilities and convenience classes for PyTorch
  • PyTorch Lightning, a lightweight wrapper on top of PyTorch with advanced multi-GPU and cluster support
  • Pywick, a high-level training framework based on torchsample, with support for various segmentation models
Comments
  • RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`

    RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`

    Dear all,

    It seems that torchbearer does not want to work for me. I am trying to simply classify images using resnet. You can find my code here (https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Template/tree/feature/cuda-error), the main training logic is:

    import time
    from comet_ml import Experiment
    import torchbearer
    import torch.optim as optim
    import torch.nn as nn
    from torchsummary import summary
    from Project import Project
    from data import get_dataloaders
    from data.transformation import train_transform, val_transform
    from models import MyCNN, resnet18
    from utils import device, show_dl
    from torchbearer import Trial
    from torchbearer.callbacks import CSVLogger, ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
    from callbacks import CometCallback
    from logger import logging
    
    if __name__ == '__main__':
        project = Project()
        # our hyperparameters
        params = {
            'lr': 0.001,
            'batch_size': 64,
            'epochs': 1,
            'model': 'resnet18-finetune',
            'id': time.time()
        }
    
        logging.info(f'Using device={device} 🚀')
        # everything starts with the data
        train_dl, val_dl, test_dl = get_dataloaders(
            project.data_dir,
            val_transform=val_transform,
            train_transform=train_transform,
            batch_size=params['batch_size'],
            num_workers=4,
        )
        # is always good practice to visualise some of the train and val images to be sure data-aug
        # is applied properly
        # show_dl(train_dl)
        # show_dl(test_dl)
        # define our comet experiment
        experiment = Experiment(api_key='8THqoAxomFyzBgzkStlY95MOf',
                                project_name="dl-pytorch-template", workspace="francescosaveriozuppichini")
        experiment.log_parameters(params)
        # create our special resnet18
        cnn = resnet18(n_classes=2).to(device)
        loss = nn.CrossEntropyLoss()
        # print the model summary to show useful information
        logging.info(summary(cnn, (3, 224, 244)))
        # define custom optimizer and instantiace the trainer `Model`
        optimizer = optim.Adam(cnn.parameters(), lr=params['lr'])
        # create our Trial object to train and evaluate the model
        trial = Trial(cnn, optimizer, loss, metrics=['acc', 'loss'],
                      callbacks=[
                          CometCallback(experiment),
                          ReduceLROnPlateau(monitor='val_loss',
                                            factor=0.1, patience=5),
                          EarlyStopping(monitor='val_acc', patience=5, mode='max'),
                          CSVLogger(str(project.checkpoint_dir / 'history.csv')),
                          ModelCheckpoint(str(project.checkpoint_dir / f'{params["id"]}-best.pt'), monitor='val_acc', mode='max')
        ]).to(device)
        trial.with_generators(train_generator=train_dl,
                              val_generator=val_dl, test_generator=test_dl)
        history = trial.run(epochs=params['epochs'], verbose=1)
        logging.info(history)
        preds = trial.evaluate(data_key=torchbearer.TEST_DATA)
        logging.info(f'test preds=({preds})')
        # experiment.log_metric('test_acc', test_acc)
    
    

    I am running the same logic (same model) with poutyne and I have no problems. I really would like to switch to torchbearer.

    Error is:

    2020-02-03 13:32:03,386 - [INFO] - None
      0%|                                                                                                                                                             | 0/1 [00:00<?, ?it/s]C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [2,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [13,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [17,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [20,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [21,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [22,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [23,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [25,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [29,0,0] Assertion `t >= 0 && t < n_classes` failed.
    C:/w/1/s/tmp_conda_3.7_100118/conda/conda-bld/pytorch_1579082551706/work/aten/src/THCUNN/ClassNLLCriterion.cu:106: block: [0,0,0], thread: [30,0,0] Assertion `t >= 0 && t < n_classes` failed.
    Traceback (most recent call last):
      File "c:/Users/Francesco/Documents/PyTorch-Deep-Learning-Template/main.py", line 64, in <module>
        history = trial.run(epochs=params['epochs'], verbose=1)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 133, in wrapper
        res = func(self, *args, **kwargs)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 988, in run
        final_metrics = self._fit_pass(state)[torchbearer.METRICS]
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 298, in wrapper
        res = func(self, *args, **kwargs)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 1033, in _fit_pass
        state[torchbearer.OPTIMIZER].step(lambda: self.closure(state))
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torch\optim\adam.py", line 58, in step
        loss = closure()
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 1033, in <lambda>
        state[torchbearer.OPTIMIZER].step(lambda: self.closure(state))
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\bases.py", line 382, in closure
        state[loss].backward(**state[torchbearer.BACKWARD_ARGS])
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\comet_ml\monkey_patching.py", line 246, in wrapper
        return_value = original(*args, **kwargs)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torch\tensor.py", line 195, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torch\autograd\__init__.py", line 99, in backward
        allow_unreachable=True)  # allow_unreachable flag
    RuntimeError: CUDA error: CUBLAS_STATUS_ALLOC_FAILED when calling `cublasCreate(handle)`
    

    Does your library work for you? Do you use it in your daily workflow?

    Thank you.

    Cheers,

    Francesco Saverio
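
    A note for anyone hitting the same thing: the repeated device-side assertion `t >= 0 && t < n_classes` in the log above usually means the targets handed to nn.CrossEntropyLoss fall outside [0, n_classes) (here n_classes=2), and the CUBLAS error is a downstream symptom. A quick, hedged sanity check over the loaders from the snippet above:

        import torch

        n_classes = 2
        for name, dl in [('train', train_dl), ('val', val_dl), ('test', test_dl)]:
            labels = torch.cat([y for _, y in dl])
            assert labels.min() >= 0 and labels.max() < n_classes, \
                f'{name} labels out of range: [{labels.min()}, {labels.max()}]'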

    opened by FrancescoSaverioZuppichini 7
  • Accuracy computation for seq2seq model

    Accuracy computation for seq2seq model

    0/10(t): 100%|██████████| 1000/1000 [02:30<00:00,  6.75it/s, running_loss=0.0338, running_acc=0.326, loss=0.689, loss_std=1.17, acc=35.9, acc_std=0]
    0/10(v): 100%|██████████| 20/20 [00:01<00:00, 19.90it/s, val_loss=0.0341, val_loss_std=0.0122, val_acc=42, val_acc_std=0]
    1/10(t): 100%|██████████| 1000/1000 [02:30<00:00,  6.76it/s, running_loss=0.00997, running_acc=0.327, loss=0.019, loss_std=0.0166, acc=41.8, acc_std=0]
    1/10(v): 100%|██████████| 20/20 [00:01<00:00, 19.98it/s, val_loss=0.0126, val_loss_std=0.00798, val_acc=42.1, val_acc_std=0]
    2/10(t): 100%|██████████| 1000/1000 [02:30<00:00,  6.75it/s, running_loss=0.00493, running_acc=0.328, loss=0.00837, loss_std=0.00938, acc=41.8, acc_std=0]
    2/10(v): 100%|██████████| 20/20 [00:01<00:00, 19.89it/s, val_loss=0.00783, val_loss_std=0.00716, val_acc=42.2, val_acc_std=0]
    3/10(t):  45%|████▌     | 454/1000 [01:08<01:21,  6.73it/s, running_loss=0.00458, running_acc=0.316]
    

    Are the accuracies correct? (running_acc=.326, acc=35.9?)

    I may be misunderstanding something, but shouldn't running_acc and acc be the same at the end of each epoch?

    bug 
    opened by kl2792 6
  • Tqdm for Jupyter Notebook

    Tqdm for Jupyter Notebook

    Each iteration of TQDM starts a new line in Jupyter Notebook -- is there any way to integrate one of the suggested fixes into torchbearer?

    (ref: https://github.com/tqdm/tqdm/issues/375, https://stackoverflow.com/a/47200571)
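
    One workaround that may help in the meantime (a sketch; it relies on the tqdm_module option added to the Tqdm callback in release 0.1.7 below, and the argument position is assumed - note that since 0.3.0 torchbearer also tries to detect notebooks automatically):

        import torch
        from tqdm import tqdm_notebook
        from torchbearer import Trial
        from torchbearer.callbacks import Tqdm

        model = torch.nn.Linear(10, 2)
        optimiser = torch.optim.Adam(model.parameters())
        loss = torch.nn.CrossEntropyLoss()
        x, y = torch.rand(64, 10), torch.randint(0, 2, (64,))

        # use the notebook-aware bar and silence the default printer with verbose=0
        trial = Trial(model, optimiser, loss, metrics=['acc'], callbacks=[Tqdm(tqdm_notebook)])
        trial.with_train_data(x, y, batch_size=16)
        trial.run(epochs=2, verbose=0)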

    bug 
    opened by kl2792 6
  • ReduceLROnPlateau

    ReduceLROnPlateau

    Dear all,

    first of all, I love this library.

    The ReduceLROnPlateau callback is not working when I call trial.evaluate.

    ...
        trial = Trial(cnn, optimizer, loss, metrics=['acc', 'loss'],
                      callbacks=[
                        #   CometCallback(experiment),
                          ReduceLROnPlateau(monitor='val_loss',
                                            factor=0.1, patience=5),
                        #   EarlyStopping(monitor='val_acc', patience=5, mode='max'),
                        #   CSVLogger('history.csv'),
                        #   ModelCheckpoint('best.pt', monitor='val_acc', mode='max')
        ]).to(device)
        trial.with_generators(train_generator=train_dl,
                              val_generator=val_dl, test_generator=test_dl)
        # history = trial.run(params['epochs'], verbose=1)
        preds = trial.evaluate(data_key=torchbearer.TEST_DATA)
    

    error:

    0/1(e): 100%|███████████████████████████████████| 1/1 [00:00<00:00,  2.18it/s, test_acc=0.4667, test_loss=0.6516]
    Traceback (most recent call last):
      File "c:/Users/Francesco/Documents/PyTorch-Deep-Learning-Template/main.py", line 62, in <module>
        preds = trial.evaluate(data_key=torchbearer.TEST_DATA)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 298, in wrapper
        res = func(self, *args, **kwargs)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 133, in wrapper
        res = func(self, *args, **kwargs)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 1131, in evaluate
        state[torchbearer.CALLBACK_LIST].on_end_epoch(state)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\callbacks\callbacks.py", line 221, in on_end_epoch
        self._for_list(lambda callback: callback.on_end_epoch(state))
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\trial.py", line 105, in _for_list
        function(self.callback_list)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\callbacks\callbacks.py", line 221, in <lambda>
        self._for_list(lambda callback: callback.on_end_epoch(state))
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\callbacks\callbacks.py", line 221, in on_end_epoch
        self._for_list(lambda callback: callback.on_end_epoch(state))
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\callbacks\callbacks.py", line 66, in _for_list
        function(callback)
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\callbacks\callbacks.py", line 221, in <lambda>
        self._for_list(lambda callback: callback.on_end_epoch(state))
      File "C:\Users\Francesco\Anaconda3\envs\dl\lib\site-packages\torchbearer\callbacks\torch_scheduler.py", line 32, in on_end_epoch
        self._scheduler.step(state[torchbearer.METRICS][self._monitor], epoch=state[torchbearer.EPOCH])
    KeyError: 'val_loss'
    

    Probably you need to disable the callback when evaluating (or just check whether the monitored metric is in state['metrics']).
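
    In the meantime, a possible workaround (just a sketch, relying on the _monitor attribute visible in the traceback) is to subclass the callback so it only steps when the monitored metric is actually present:

        import torchbearer
        from torchbearer.callbacks import ReduceLROnPlateau

        class SafeReduceLROnPlateau(ReduceLROnPlateau):
            # Skip the scheduler step if the monitored metric was not computed (e.g. during evaluate)
            def on_end_epoch(self, state):
                if self._monitor in state[torchbearer.METRICS]:
                    super(SafeReduceLROnPlateau, self).on_end_epoch(state)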

    Thank you!

    Best Regards,

    Francesco Saverio

    opened by FrancescoSaverioZuppichini 5
  • Trial predict fails with the given example

    Trial predict fails with the given example

    After training the model, I want to get the predictions on the test set, not the accuracy. I know the accuracy comes from Trial.evaluate(), and that works well. So to get the predictions themselves I used Trial.predict(). Is that right?

    But the error says that AttributeError: 'dict' object has no attribute 'data'.

    I read the docstring of Trial, which provides an example:

    
    # Simple trial to predict on some validation and test data
    >>> from torchbearer import Trial
    >>> val_data = torch.rand(5, 5)
    >>> test_data = torch.rand(5, 5)
    >>> t = Trial(None).with_test_data(test_data)
    >>> test_predictions = t.predict(data_key=torchbearer.TEST_DATA)
    
    

    I ran it but got an error AttributeError: 'NoneType' object has no attribute 'eval'

    So, is there any problem in this method?
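
    For what it's worth, prediction does run when a real model is passed instead of None; a minimal sketch (toy model and random data, names are illustrative only):

        import torch
        import torchbearer
        from torchbearer import Trial

        model = torch.nn.Linear(5, 2)  # any real module rather than None
        test_data = torch.rand(10, 5)
        trial = Trial(model).with_test_data(test_data, batch_size=5)
        predictions = trial.predict(data_key=torchbearer.TEST_DATA)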

    bug 
    opened by danielhuoo 5
  • Lean Model Checkpointing

    Lean Model Checkpointing

    Hi,

    I ran a Trial and saved my model using torchbearer.callbacks.checkpointers.Best to a file model.pt.

    When I load the file with torch.load and try to make a forward pass with it, I get the following error:

    model = MyModule()
    state_dict = torch.load('vae.pt')
    model.load_state_dict(state_dict) # <== I get the error here
    

    AttributeError: 'StateKey' object has no attribute 'startswith'

    I get that the model is being saved so that it can be recovered ready for use with torchbearer, but how can we save the model lean (just the weights)?

    It seems like here, the model is only saved for reusability by torchbearer.
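
    One way to pull just the weights back out of such a checkpoint (a sketch: it assumes the file holds a torchbearer Trial state dict, with the model weights stored under the torchbearer.MODEL key):

        import torch
        import torchbearer

        trial_state = torch.load('vae.pt', map_location='cpu')
        model = MyModule()  # the module class from the snippet above
        model.load_state_dict(trial_state[torchbearer.MODEL])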

    Thanks a lot!

    docs 
    opened by dorukhansergin 5
  • loss_std resulting in complex number and breaking Tensorboard

    loss_std resulting in complex number and breaking Tensorboard

    I'm using torchbearer with PyTorch 0.4 and TensorboardX 1.2. Previously, I was using PyTorch 0.4.1, but I had to downgrade to use TensorboardX because of an incompatibility between them. After adding the Tensorboard callback, the following error is raised after training for some time:

    {TypeError}can't convert complex to float

    When debugging, I noticed that the add_scalar() of TensorboardX tried to convert the scalar to float and, somehow, the val_loss_std was a complex number. Is there an error in how the std is calculated that results in a complex number?
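
    For context, this looks like the usual one-pass variance pitfall: computing the std as sqrt(E[x^2] - E[x]^2) can give a tiny negative variance from floating-point rounding, and raising a negative Python float to the power 0.5 produces a complex number. A hedged illustration (not torchbearer's actual code; the 0.1.6 release notes below describe a fix along these lines):

        import math

        def naive_std(values):
            mean = sum(values) / len(values)
            mean_sq = sum(v * v for v in values) / len(values)
            var = mean_sq - mean ** 2  # may round to e.g. -1e-18 for near-identical values
            return var ** 0.5          # a negative float ** 0.5 is a complex number in Python 3

        def safe_std(values):
            mean = sum(values) / len(values)
            mean_sq = sum(v * v for v in values) / len(values)
            return math.sqrt(max(mean_sq - mean ** 2, 0.0))  # clamp before the square root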

    bug 
    opened by fernandocamargoai 5
  • Support multi input and output

    Support multi input and output

    Right now, it's not possible to:

    • Have a Module with multiple inputs (e.g. forward(x1, x2)).
    • Have a Module with multiple outputs (returning a tuple).

    I worked around the first problem by creating a Module with a single input and indexing each individual input. But the second problem makes it impossible to use the TripletMarginLoss, for example, since it expects 3 outputs from the Module.
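
    For reference, the single-input workaround mentioned above can look roughly like this (a sketch; InnerNet stands in for the real two-input module, and note that release 0.5.0 below adds native support for multiple inputs and outputs):

        import torch.nn as nn

        class TupleInputWrapper(nn.Module):
            # Accept one batched input that is a tuple (x1, x2) and unpack it for the wrapped module
            def __init__(self, inner):
                super(TupleInputWrapper, self).__init__()
                self.inner = inner

            def forward(self, x):
                x1, x2 = x  # the dataloader yields ((x1, x2), y)
                return self.inner(x1, x2)

        model = TupleInputWrapper(InnerNet())  # InnerNet is the hypothetical two-input module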

    opened by fernandocamargoai 3
  • Model checkpointers save_weights_only

    Model checkpointers save_weights_only

    As per the discussion in #504 it would be good if the checkpointers had an option to just save the model state dict, rather than the trial one. Not sure what the argument should be, something like save_model_only / save_weights_only? @dorukhansergin @MattPainter01 any thoughts on this?

    enhancement 
    opened by ethanwharris 3
  • Running indefinitely?

    Running indefinitely?

    Currently there is no way to ask torchbearer to run until stopped. This would be useful for reinforcement learning where we don't know how long an episode will be.

    enhancement 
    opened by ethanwharris 3
  • lr_scheduler order changed since PyTorch 1.1.0

    lr_scheduler order changed since PyTorch 1.1.0

    Many thanks for the wonderful library.

    A warning message emerged when a scheduler was used in the callbacks: scheduler = torchbearer.callbacks.torch_scheduler.StepLR(step_size=5, gamma=0.1)

    I hope this could be considered in a later update if it is not already handled.

    Python37\lib\site-packages\torch\optim\lr_scheduler.py:100: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule.See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
      "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
    
    bug 
    opened by yongduek 2
  • Question about training loop

    Question about training loop

    Hi! I'm trying to fork the repo and add some functionality for an experiment. But that requires an addition in the training loop. I've read the documentation and the code but I can't seem to understand where the training loop itself is defined. Can somebody point me in the right direction?

    Thanks in advance!

    opened by AnabetsyR 4
  • GradientNormClipping callback error

    GradientNormClipping callback error

    When I insert this callback in the trial I get the following error. Is this some kind of bug? It seems like the gradients are not passed in the callback.

    """ File "/home/dimitris/.local/lib/python3.6/site-packages/torch/nn/utils/clip_grad.py", line 30, in clip_grad_norm_ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type) RuntimeError: stack expects a non-empty TensorList """

    opened by dimimal 0
  • Native automatic mixed precision for torchbearer

    Native automatic mixed precision for torchbearer

    Native automatic mixed precision support (torch.cuda.amp) is now in master: https://pytorch.org/docs/master/amp.html https://pytorch.org/docs/master/notes/amp_examples.html

    Not sure if you ever tried Nvidia's (our) experimental Apex Amp, but I know it has many pain points (extension builds, forward/backward compatibility, DataParallel support, flaky checkpointing, I don't even know if it can be hacked to handle double backward/gradient penalty, others…). torch.cuda.amp fixes all these, the interface is more flexible and intuitive, and the tighter integration with PyTorch brings more future optimizations into scope.

    I think the torch.cuda.amp API is a good fit for a higher-level library because its style is more functional (as in, it doesn't statefully alter anything outside itself). The necessary torch.cuda.amp calls don't have silent/weird effects elsewhere.

    If you want to talk about adding torch.cuda.amp to torchbearer, with an eye towards it becoming the future-proof source of mixed precision, message me on Pytorch slack anytime (or ask me for invites if you're not signed up). I'll check this issue periodically but I'm on Pytorch slack a greater fraction of the time than I care to admit.
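
    For readers who have not seen the proposed API, the core torch.cuda.amp pattern is an autocast context for the forward pass plus a GradScaler around the backward/step (a plain PyTorch sketch, not a torchbearer integration; requires a CUDA device):

        import torch

        model = torch.nn.Linear(128, 10).cuda()
        optimiser = torch.optim.SGD(model.parameters(), lr=0.1)
        loss_fn = torch.nn.CrossEntropyLoss()
        scaler = torch.cuda.amp.GradScaler()

        x = torch.randn(32, 128).cuda()
        y = torch.randint(0, 10, (32,)).cuda()

        optimiser.zero_grad()
        with torch.cuda.amp.autocast():   # run the forward pass and loss in mixed precision
            loss = loss_fn(model(x), y)
        scaler.scale(loss).backward()     # scale the loss to avoid fp16 gradient underflow
        scaler.step(optimiser)            # unscales the gradients, then calls optimiser.step()
        scaler.update()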

    opened by mcarilli 0
  • Y_pred tuple behaviour changed

    Y_pred tuple behaviour changed

    When y_pred is a tuple (i.e. the model returns multiple outputs), the criterion now receives the tuple unpacked. This should either be reverted or more clearly documented.

    bug 
    opened by ethanwharris 0
Releases (0.5.3)
  • 0.5.3(Jan 31, 2020)

  • 0.5.2(Jan 28, 2020)

    [0.5.2] - 2020-01-28

    Added

    • Added option to use mixup loss with cutmix
    • Support for PyTorch 1.4.0

    Changed

    • Changed PyCM save methods to use *args and **kwargs

    Deprecated

    Removed

    Fixed

    • Fixed a bug where the PyCM callback would fail when saving
    Source code(tar.gz)
    Source code(zip)
  • 0.5.1(Nov 6, 2019)

    [0.5.1] - 2019-11-06

    Added

    • Added BCPlus callback for between-class learning
    • Added support for PyTorch 1.3
    • Added a show flag to the ImagingCallback.to_pyplot method, set to false to stop it from calling plt.show
    • Added manifold mixup

    Changed

    • Changed the default behaviour of ImagingCallback.to_pyplot to turn off the axis

    Deprecated

    Removed

    Fixed

    • Fixed a bug when resuming an old state dict with tqdm enabled
    • Fixed a bug in imaging where passing a title to to_pyplot was not possible
    Source code(tar.gz)
    Source code(zip)
  • 0.5.0(Sep 17, 2019)

    [0.5.0] - 2019-09-17

    Added

    • Added PyTorch CyclicLR scheduler

    Changed

    • Torchbearer now supports Modules with multiple inputs and multiple outputs

    Deprecated

    Removed

    • Cyclic LR callback in favour of torch cyclic lr scheduler
    • Removed support for PyTorch 0.4.x

    Fixed

    • Fixed bug where aggregate predictions couldn't handle empty list
    • Fixed a bug where Runtime Errors on forward weren't handled properly
    • Fixed a bug where exceptions on forward wouldn't print the traceback properly
    • Fixed a documentation mistake whereby ReduceLROnPlateau was said to increase learning rate
    Source code(tar.gz)
    Source code(zip)
  • 0.4.0(Sep 17, 2019)

    [0.4.0] - 2019-07-05

    Added

    • Added with_loader trial method that allows running of custom batch loaders
    • Added a Mock Model which is set when None is passed as the model to a Trial. Mock Model always returns None.
    • Added __call__(state) to StateKey so that they can now be used as losses
    • Added a callback to do cutout regularisation
    • Added a with_data trial method that allows passing of train, val and test data in one call
    • Added the missing on_init callback decorator
    • Added a step_on_batch flag to the early stopping callback
    • Added multi image support to imaging
    • Added a callback to unpack state into torchbearer.X at sample time for specified keys and update state after the forward pass based on model outputs. This is useful for using DataParallel which pass the main state dict directly.
    • Added callback for generating confusion matrices with PyCM
    • Added a mixup callback with associated loss
    • Added Label Smoothing Regularisation (LSR) callback
    • Added CutMix regularisation
    • Added default metric from paper for when Mixup loss is used

    Changed

    • Changed history to now just be a list of records
    • Categorical Accuracy metric now also accepts tensors of size (B, C) and gets the max over C for the target class

    Deprecated

    Removed

    • Removed the variational sub-package, this will now be packaged separately
    • Removed verbose argument from the early stopping callback

    Fixed

    • Fixed a bug where list or dictionary metrics would cause the tensorboard callback to error
    • Fixed a bug where running a trial without training steps would error
    • Fixed a bug where the caching imaging callback didn't reset data so couldn't be run in multiple trials
    • Fixed a bug in the ClassAppearanceModel callback
    • Fixed a bug where the state given to predict was not a State object
    • Fixed a bug with Cutout on gpu
    • Fixed a bug where MakeGrid callback wasn't passing all arguments correctly
    • Fixed a bug in ImagingCallback that would sometimes cause make_grid to throw an error
    • Fixed a bug where the verbose argument would not work unless given as a keyword argument
    • Fixed a bug where the data_key argument would sometimes not work as expected
    • Fixed a bug where cutout required a seed
    • Fixed a bug where cutmix wasn't sending the beta distribution sample to the device
    Source code(tar.gz)
    Source code(zip)
  • 0.3.2(May 28, 2019)

    [0.3.2] - 2019-05-28

    Added

    Changed

    Deprecated

    Removed

    Fixed

    • Fixed a bug where for_steps would sometimes not work as expected if called in the wrong order
    • Fixed a bug where torchbearer installed via pip would crash on import
    Source code(tar.gz)
    Source code(zip)
  • 0.3.1(May 24, 2019)

    [0.3.1] - 2019-05-24

    Added

    • Added cyclic learning rate finder
    • Added on_init callback hook to run at the end of trial init
    • Added callbacks for weight initialisation in torchbearer.callbacks.init
    • Added with_closure trial method that allows running of custom closures
    • Added base_closure function to bases that allows creation of standard training loop closures
    • Added ImagingCallback class for callbacks which produce images that can be sent to tensorboard, visdom or a file
    • Added CachingImagingCallback and MakeGrid callback to make a grid of images
    • Added the option to give the only_if callback decorator a function of self and state rather than just state
    • Added Layer-sequential unit-variance (LSUV) initialization
    • Added ClassAppearanceModel callback and example page for visualising CNNs
    • Added on_checkpoint callback decorator
    • Added support for PyTorch 1.1.0

    Changed

    • No_grad and enable_grad decorators are now also context managers

    Deprecated

    Removed

    • Removed the fluent decorator, just use return self
    • Removed install dependency on torchvision, still required for some functionality

    Fixed

    • Fixed bug where replay errored when train or val steps were None
    • Fixed a bug where the mock optimiser wouldn't call its closure
    • Fixed a bug where the notebook check raised ModuleNotFoundError when IPython not installed
    • Fixed a memory leak with metrics that causes issues with very long epochs
    • Fixed a bug with the once and once_per_epoch decorators
    • Fixed a bug where the test criterion wouldn't accept a function of state
    • Fixed a bug where type inference would not work correctly when chaining Trial methods
    • Fixed a bug where checkpointers would error when they couldn't find the old checkpoint to overwrite
    • Fixed a bug where the 'test' label would sometimes not populate correctly in the default accuracy metric
    Source code(tar.gz)
    Source code(zip)
  • 0.3.0(Feb 28, 2019)

    [0.3.0] - 2019-02-28

    Added

    • Added torchbearer.variational, a sub-package for implementations of state of the art variational auto-encoders
    • Added SimpleUniform and SimpleExponential distributions
    • Added a decorator which can be used to cite a research article as part of a doc string
    • Added an optional dimension argument to the mean, std and running_mean metric aggregators
    • Added a var metric and decorator which can be used to calculate the variance of a metric
    • Added an unbiased flag to the std and var metrics to optionally not apply Bessel's correction (consistent with torch.std / torch.var)
    • Added support for rounding 1D lists to the Tqdm callback
    • Added SimpleWeibull distribution
    • Added support for Python 2.7
    • Added SimpleWeibullSimpleWeibullKL
    • Added SimpleExponentialSimpleExponentialKL
    • Added the option to save only the model parameters to Checkpointers.
    • Added documentation about serialization.
    • Added support for indefinite data loading. Iterators can now be run until complete independent of epochs or iterators can be refreshed during an epoch if complete.
    • Added support for batch intervals in interval checkpointer
    • Added line magic %torchbearer notebook
    • Added 'accuracy' variants of 'acc' default metrics

    Changed

    • Changed the default behaviour of the std metric to compute the sample std, in line with torch.std
    • Tqdm precision argument now rounds to decimal places rather than significant figures
    • Trial will now simply infer if the model has an argument called 'state'
    • Torchbearer now infers if inside a notebook and will use the appropriate tqdm module if not set

    Deprecated

    Removed

    • Removed the old Model API (deprecated since version 0.2.0)
    • Removed the 'pass_state' argument from Trial, this will now be inferred
    • Removed the 'std' decorator from the default metrics

    Fixed

    • Fixed a bug in the weight decay callback which would result in potentially negative decay (now just uses torch.norm)
    • Fixed a bug in the cite decorator causing the citation to not show up correctly
    • Fixed a memory leak in the mse primitive metric
    Source code(tar.gz)
    Source code(zip)
  • 0.2.6.1(Feb 25, 2019)

  • 0.2.6(Dec 19, 2018)

    [0.2.6] - 2018-12-19

    Added

    Changed

    • Y_PRED, Y_TRUE and X can now equivalently be accessed as PREDICTION, TARGET and INPUT respectively

    Deprecated

    Removed

    Fixed

    • Fixed a bug where the LiveLossPlot callback would trigger an error if run and evaluate were called separately
    • Fixed a bug where state key errors would report to the wrong stack level
    • Fixed a bug where the user would wrongly get a state key error in some cases
    Source code(tar.gz)
    Source code(zip)
  • 0.2.5(Dec 19, 2018)

    [0.2.5] - 2018-12-19

    Added

    • Added flag to replay to replay only a single batch per epoch
    • Added support for PyTorch 1.0.0 and Python 3.7
    • MetricTree can now unpack dictionaries from root, this is useful if you want to get a mean of a metric. However, this should be used with caution as it extracts only the first value in the dict and ignores the rest.
    • Added a callback for the livelossplot visualisation tool for notebooks

    Changed

    • All error / accuracy metrics can now optionally take state keys for predictions and targets as arguments

    Deprecated

    Removed

    Fixed

    • Fixed a bug with the EpochLambda metric which required y_true / y_pred to have specific forms
    Source code(tar.gz)
    Source code(zip)
  • 0.2.4(Nov 16, 2018)

    [0.2.4] - 2018-11-16

    Added

    • Added metric functionality to state keys so that they can be used as metrics if desired
    • Added customizable precision to the printer callbacks
    • Added threshold to binary accuracy. Now it will appropriately handle any values in [0, 1]

    Changed

    • Changed the default printer precision to 4 s.f.
    • Tqdm on_epoch now shows metrics immediately when resuming

    Deprecated

    Removed

    Fixed

    • Fixed a bug which would incorrectly trigger version warnings when loading in models
    • Fixed bugs where the Trial would not fail gracefully if required objects were not in state
    • Fixed a bug where none criterion didn't work with the add_to_loss callback
    • Fixed a bug where tqdm on_epoch always started at 0
    Source code(tar.gz)
    Source code(zip)
  • 0.2.3(Oct 12, 2018)

    [0.2.3] - 2018-10-12

    Added

    • Added string representation of Trial to give summary
    • Added option to log Trial summary to TensorboardText
    • Added a callback point ('on_checkpoint') which can be used for model checkpointing after the history is updated

    Changed

    • When resuming training, checkpointers no longer delete the state file the trial was loaded from
    • Changed the metric eval to include a data_key which tells us what data we are evaluating on

    Deprecated

    Removed

    Fixed

    • Fixed a bug where callbacks weren't handled correctly in the predict and evaluate methods of Trial
    • Fixed a bug where the history wasn't updated when new metrics were calculated with the evaluate method of Trial
    • Fixed a bug where tensorboard writers couldn't be reused
    • Fixed a bug where the none criterion didn't require gradient
    • Fixed a bug where tqdm wouldn't get the correct iterator length when evaluating on the test generator
    • Fixed a bug where evaluating before training tried to update history before it existed
    • Fixed a bug where the metrics would output 'val_acc' even if evaluating on test or train data
    • Fixed a bug where roc metric didn't detach y_pred before sending to numpy
    • Fixed a bug where resuming from a checkpoint saved with one of the callbacks didn't populate the epoch number correctly
    Source code(tar.gz)
    Source code(zip)
  • 0.2.2(Sep 18, 2018)

    [0.2.2] - 2018-09-18

    Added

    • The default_for_key metric decorator can now be used to pass arguments to the init of the inner metric
    • The default metric for the key 'top_10_acc' is now the TopKCategoricalAccuracy metric with k set to 10
    • Added global verbose flag for trial that can be overridden by run, evaluate, predict
    • Added an LR metric which retrieves the current learning rate from the optimizer, default for key 'lr'

    Fixed

    • Fixed a bug where the DefaultAccuracy metric would not put the inner metric in eval mode if the first call to reset was after the call to eval
    • Fixed a bug where trying to load a state dict in a different session to where it was saved didn't work properly
    • Fixed a bug where the empty criterion would trigger an error if no Y_TRUE was put in state
    Source code(tar.gz)
    Source code(zip)
  • 0.2.1(Sep 11, 2018)

    [0.2.1] - 2018-09-11

    Added

    • Evaluation and prediction can now be done on any data using the data_key keyword arg
    • Text tensorboard/visdom logger that writes epoch/batch metrics to text

    Changed

    • TensorboardX, Numpy, Scikit-learn and Scipy are no longer dependencies and are only required if using the tensorboard callbacks or roc metric

    Deprecated

    Removed

    Fixed

    • Fixed the Model class setting the generator incorrectly, leading to StopIteration errors.
    • Argument ordering is consistent in Trial.with_generators and Trial.__init__
    • Added a state dict for the early stopping callback
    • Fixed visdom parameters not getting set in some cases
    Source code(tar.gz)
    Source code(zip)
  • 0.2.0(Aug 21, 2018)

    See [NEW!] in README.md for new key features

    [0.2.0] - 2018-08-21

    Added

    • Added the ability to pass custom arguments to the tqdm callback
    • Added an ignore_index flag to the categorical accuracy metric, similar to nn.CrossEntropyLoss. Usage: metrics=[CategoricalAccuracyFactory(ignore_index=0)]
    • Added TopKCategoricalAccuracy metric (default for key: top_5_acc)
    • Added BinaryAccuracy metric (default for key: binary_acc)
    • Added MeanSquaredError metric (default for key: mse)
    • Added DefaultAccuracy metric (use with 'acc' or 'accuracy') - infers accuracy from the criterion
    • New Trial api torchbearer.Trial to replace the Model api. Trial api is more atomic and uses the fluent pattern to allow chaining of methods.
    • torchbearer.Trial has with_x_generator and with_x_data methods to add training/validation/testing generators to the trial. There is a with_generators method to allow passing of all generators in one call.
    • torchbearer.Trial has for_x_steps and for_steps to allow running of trials without explicit generators or data tensors
    • torchbearer.Trial keeps a history of run calls which tracks the number of epochs run and the final metrics at each epoch. This allows seamless resuming of trial running.
    • torchbearer.Trial.state_dict now returns the trial history and callback list state allowing for full resuming of trials
    • torchbearer.Trial has a replay method that can replay training (with callbacks and display) from the history. This is useful when loading trials from state.
    • The backward call can now be passed args by setting state[torchbearer.BACKWARD_ARGS]
    • torchbearer.Trial implements the forward pass, loss calculation and backward call as an optimizer closure
    • Metrics are now explicitly calculated with no gradient

    Changed

    • Callback decorators can now be chained to allow construction with multiple methods filled
    • Callbacks can now implement `state_dict` and `load_state_dict` to allow callbacks to resume with state
    • The state dictionary now accepts StateKey objects which are unique and generated through torchbearer.state.get_state
    • State dictionary now warns when accessed with strings as this allows for collisions
    • Checkpointer callbacks will now resume from a state dict when resume=True in Trial

    Deprecated

    • torchbearer.Model has been deprecated in favour of the new torchbearer.Trial api

    Removed

    • Removed the MetricFactory class. Decorators still work in the same way but the Factory is no longer needed.

    Fixed

    Source code(tar.gz)
    Source code(zip)
  • 0.1.7(Aug 14, 2018)

    [0.1.7] - 2018-08-14

    Added

    • Added visdom logging support to tensorboard callbacks
    • Added option to choose tqdm module (tqdm, tqdm_notebook, ...) to Tqdm callback
    • Added some new decorators to simplify custom callbacks that must only run under certain conditions (or even just once).

    Changed

    • Instantiation of Model will now trigger a warning pending the new Trial API in the next version
    • TensorboardX dependency now version 1.4

    Deprecated

    Removed

    Fixed

    • Mean and standard deviation calculations now work correctly for network outputs with many dimensions
    • Callback list no longer shared between fit calls, now a new copy is made each fit
    Source code(tar.gz)
    Source code(zip)
  • 0.1.6(Aug 10, 2018)

    [0.1.6] - 2018-08-10

    Added

    • Added a verbose level (options are now 0,1,2) which will print progress for the entire fit call, updating every epoch. Useful when doing dynamic programming with little data.
    • Added support for dictionary outputs of dataloader
    • Added abstract superclass for building TensorBoardX based callbacks

    Changed

    • Timer callback can now also be used as a metric which allows display of specified timings to printers and has been moved to metrics.
    • The loss_criterion is renamed to criterion in torchbearer.Model arguments.
    • The criterion in torchbearer.Model is now optional and will provide a zero loss tensor if it is not given.
    • TensorBoard callbacks refactored to be based on a common super class
    • TensorBoard callbacks refactored to use a common SummaryWriter for each log directory

    Deprecated

    Removed

    Fixed

    • Standard deviation calculation now returns 0 instead of complex value when given very close samples
    Source code(tar.gz)
    Source code(zip)
  • 0.1.5(Jul 30, 2018)

    [0.1.5] - 2018-07-30

    Added

    • Added a on_validation_criterion callback hook
    • Added a DatasetValidationSplitter which can be used to create a validation split if required for datasets like Cifar10 or MNIST
    • Added simple timer callback

    Changed

    Deprecated

    Removed

    Fixed

    • Fixed a bug where checkpointers would not save the model in some cases
    • Fixed a bug with the ROC metric causing it to not work
    Source code(tar.gz)
    Source code(zip)
  • 0.1.4(Jul 23, 2018)

    [0.1.4] - 2018-07-23

    Added

    • Added a decorator API for metrics which allows decorators to be used for metric construction
    • Added a default_for_key decorator which can be used to associate a string with a given metric in metric lists
    • Added a decorator API for callbacks which allows decorators to be used for simple callback construction
    • Added a add_to_loss callback decorator which allows quicker constructions of callbacks that add values to the loss

    Changed

    • Changed the API for running metrics and aggregators to no longer wrap a metric but instead receive input

    Deprecated

    Removed

    Fixed

    Source code(tar.gz)
    Source code(zip)
  • 0.1.3(Jul 18, 2018)

    [0.1.3] - 2018-07-17

    Added

    • Added a flag (step_on_batch) to the LR Scheduler callbacks which allows for step() to be called on each iteration instead of each epoch
    • Added on_sample_validation and on_forward_validation calls for validation callbacks
    • Added GradientClipping callback which simply clips the absolute gradient of the model parameters

    Changed

    • Changed the order of the arguments to the lambda function in the EpochLambda metric for consistency with pytorch and other metrics
    • Checkpointers now create directory to savepath if it doesn't exist
    • Changed the 'on_forward_criterion' callback method to 'on_criterion'
    • Changed epoch number in printer callbacks to be consistent with the rest of torchbearer

    Deprecated

    Removed

    Fixed

    • Fixed tests which were failing as of version 0.1.2
    • Fixed validation_steps not being added to state
    • Fixed checkpointer bug when path contained only filename and no directory path
    • Fixed console printer bug not printing validation statistics
    • Fixed console printer bug calling final_metrics before they existed in state
    Source code(tar.gz)
    Source code(zip)
  • v0.1.2(Jun 8, 2018)

    [0.1.2] - 2018-06-08

    Added

    • Added support for tuple outputs from generators, bink expects output to be length 2. Specifically, x, y = next() is possible, where x and y can be tuples of arbitrary size or depth
    • Added support for torch dtypes in bink Model.to(...)
    • Added pickle_module and pickle_protocol to checkpointers for consistency with torch.save

    Changed

    • Changed the learning rate scheduler callbacks to no longer require an optimizer and to have the proper arguments

    Deprecated

    Removed

    Fixed

    • Fixed an issue in GradientNormClipping which raised a warning in PyTorch >= 0.4
    Source code(tar.gz)
    Source code(zip)
  • v0.1.1(May 30, 2018)

Owner
The torchbearer project, by @ecs-vlc