ocaml-torch

ocaml-torch provides OCaml bindings for the PyTorch tensor library. This brings NumPy-like tensor computations with GPU acceleration and tape-based automatic differentiation to OCaml.

These bindings use the PyTorch C++ API and are mostly automatically generated. The current GitHub tip and the opam package v0.7 correspond to PyTorch v1.8.0.

On Linux, note that you will need the PyTorch build that uses the cxx11 ABI: either the CPU version or the CUDA 10.2 version.

Opam Installation

The opam package can be installed using the following command. This automatically installs the CPU version of libtorch.

opam install torch

You can then compile some sample code by following the instructions below. ocaml-torch can also be used in interactive mode via utop or ocaml-jupyter.

Here is a sample utop session.
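
A short session might look like the following sketch. It assumes the opam package is installed and that the torch.toplevel library loads in your environment; the tensor values are arbitrary and only serve as an illustration.

```ocaml
(* Typed at the utop prompt. `#require` is a toplevel directive,
   so this does not compile as a regular .ml file. *)
#require "torch.toplevel";;
open Torch;;

(* Build a small 1-D tensor from an OCaml float array and print it. *)
let t = Tensor.of_float1 [| 1.; 2.; 3. |];;
Tensor.print t;;

(* Element-wise arithmetic uses the operators from the Tensor module;
   `f` turns an OCaml float into a scalar tensor. *)
let doubled = Tensor.(t * f 2.);;

(* Reduce to a single element and read it back as an OCaml float. *)
Tensor.float_value (Tensor.sum doubled);;
```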

Build a Simple Example

To build a first torch program, create a file example.ml with the following content.

open Torch

let () =
  let tensor = Tensor.randn [ 4; 2 ] in
  Tensor.print tensor

Then create a dune file with the following content:

(executables
  (names example)
  (libraries torch))

Run dune exec example.exe to compile the program and run it!

Alternatively, you can first compile the code via dune build example.exe and then run the executable _build/default/example.exe (note that building the bytecode target example.bc may not work on macOS).

Tutorials

Examples

Below is an example of a linear model trained on the MNIST dataset (full code).

  (* Create two tensors to store model weights. *)
  let ws = Tensor.zeros [image_dim; label_count] ~requires_grad:true in
  let bs = Tensor.zeros [label_count] ~requires_grad:true in

  let model xs = Tensor.(mm xs ws + bs) in
  for index = 1 to 100 do
    (* Compute the cross-entropy loss. *)
    let loss =
      Tensor.cross_entropy_for_logits (model train_images) ~targets:train_labels
    in

    Tensor.backward loss;

    (* Apply gradient descent, disable gradient tracking for these. *)
    Tensor.(no_grad (fun () ->
        ws -= grad ws * f learning_rate;
        bs -= grad bs * f learning_rate));

    (* Compute the validation error. *)
    let test_accuracy =
      Tensor.(argmax (model test_images) = test_labels)
      |> Tensor.to_kind ~kind:(T Float)
      |> Tensor.sum
      |> Tensor.float_value
      |> fun sum -> sum /. test_samples
    in
    printf "%d %f %.2f%%\n%!" index (Tensor.float_value loss) (100. *. test_accuracy);
  done
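
The excerpt above relies on the MNIST data and helper definitions from the full code. As a self-contained illustration of the same pattern (backward pass, in-place update under no_grad), here is a minimal sketch on synthetic data. The function name fit, the toy problem, the step count, and the learning rate are assumptions made for this example and are not part of ocaml-torch. The explicit Tensor.zero_grad call is there because gradients accumulate across backward calls.

```ocaml
open Torch

(* Hypothetical toy problem: recover w_true from noiseless synthetic data. *)
let fit ?(steps = 200) ?(learning_rate = 0.1) () =
  let xs = Tensor.randn [ 64; 3 ] in
  let w_true = Tensor.of_float2 [| [| 1. |]; [| -2. |]; [| 0.5 |] |] in
  let ys = Tensor.mm xs w_true in
  (* Trainable weights, tracked by autograd. *)
  let ws = Tensor.zeros [ 3; 1 ] ~requires_grad:true in
  let loss_value = ref infinity in
  for _step = 1 to steps do
    let residual = Tensor.(mm xs ws - ys) in
    (* Mean squared error over the 64 samples, using basic operators only. *)
    let loss = Tensor.(sum (residual * residual) / f 64.) in
    Tensor.backward loss;
    (* Update in place without recording the update in the autograd graph. *)
    Tensor.(no_grad (fun () -> ws -= grad ws * f learning_rate));
    (* Reset the accumulated gradient before the next iteration. *)
    Tensor.zero_grad ws;
    loss_value := Tensor.float_value loss
  done;
  !loss_value

let () = Printf.printf "final loss: %g\n" (fit ())
```

For real models, the Optimizer module (see backward_step in optimizer.mli) is likely a more convenient way to handle the update step than doing it by hand.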

  • Some ResNet examples on CIFAR-10.
  • A simplified version of char-rnn illustrating character-level language modeling using recurrent neural networks.
  • Neural Style Transfer applies the style of an image to the content of another image, using a deep convolutional neural network.

Models and Weights

Various pre-trained computer vision models are implemented in the vision library. The weight files can be downloaded at the following links:

Running the pre-trained models on some sample images can then easily be done via the following command.

dune exec examples/pretrained/predict.exe path/to/resnet18.ot tiger.jpg

Natural Language Processing models based on BERT can be found in the ocaml-torch repo.

Alternative Installation Option

This alternative way to install ocaml-torch could be useful to run with GPU acceleration enabled.

The libtorch library can be downloaded from the PyTorch website (1.8.0 cpu version).

Download and extract the libtorch library. Then, to build all the examples, run:

export LIBTORCH=/path/to/libtorch
git clone https://github.com/LaurentMazare/ocaml-torch.git
cd ocaml-torch
make all
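
Once the build succeeds, a small program can help confirm whether the CUDA build of libtorch is actually picked up. The following is a sketch, not part of the project: the tensor size is arbitrary, and on a CPU-only libtorch it should simply fall back to the CPU.

```ocaml
open Torch

let () =
  (* Picks a CUDA device when libtorch reports one, and the CPU otherwise. *)
  let device = Device.cuda_if_available () in
  Printf.printf "cuda available: %b (%d device(s))\n"
    (Cuda.is_available ())
    (Cuda.device_count ());
  (* Tensors have to be created on (or moved to) the chosen device explicitly. *)
  let xs = Tensor.randn [ 1024; 1024 ] ~device in
  let ys = Tensor.mm xs xs in
  (* float_value copies the single-element result back to the host. *)
  Printf.printf "sum: %f\n" (Tensor.float_value (Tensor.sum ys))
```

If this reports that CUDA is unavailable, the program is most likely linked against a CPU-only libtorch, for example the one installed automatically by the opam package.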

Comments
  • Confused about types

    I think this should work:

    open Torch
    Tensor.(arange1 ~start:(f 0.) ~end:(f 1.))
    

    But I get

    Error: This expression has type Tensor.t but an expression was expected of type
             Torch_core.Wrapper.Scalar.t
    
    opened by bluddy 15
  • installation does not seem to properly install libtorch?

    according to the README libtorch should be installed automatically. however, i get

    utop # #require "torch.toplevel";;
    Cannot load required shared library dlltorch_core_stubs.
    Reason: /Users/nbecker/.opam/4.07.1+flambda/lib/stublibs/dlltorch_core_stubs.so: dlopen(/Users/nbecker/.opam/4.07.1+flambda/lib/stublibs/dlltorch_core_stubs.so, 10): Library not loaded: @rpath/libc10.dylib
      Referenced from: /Users/nbecker/.opam/4.07.1+flambda/lib/stublibs/dlltorch_core_stubs.so
      Reason: image not found.
    Error: Reference to undefined global `Torch_core__Wrapper'
    

    after installing opam reinstall torch

    opened by nilsbecker 10
  • Installing for GPU acceleration

    Hello,

    When I first installed ocaml-torch, I used opam install torch and was running the CPU version. Now I want to accelerate this with a GPU, and to do that I did the following:

    cd ~
    wget https://download.pytorch.org/libtorch/cu102/libtorch-cxx11-abi-shared-with-deps-1.7.0.zip
    unzip libtorch-cx11.....zip   // then now I have ~/libtorch
    sudo rm -r ocaml-torch
    
    export LIBTORCH=~/libtorch
    git clone https://github.com/LaurentMazare/ocaml-torch.git
    cd ocaml-torch
    make all
    

    In my code, I used

    let device = T.Device.cuda_if_available () in
    let vs = T.Var_store.create ~name:"my-project" ~device () in ...
    

    But while running my code, when I typed nvidia-smi, the result was below:

    +-----------------------------------------------------------------------------+
    | NVIDIA-SMI 460.27.04    Driver Version: 460.27.04    CUDA Version: 11.2     |
    |-------------------------------+----------------------+----------------------+
    | GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
    | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
    |                               |                      |               MIG M. |
    |===============================+======================+======================|
    |   0  GeForce RTX 208...  On   | 00000000:18:00.0 Off |                  N/A |
    | 31%   41C    P8    17W / 250W |      1MiB / 11019MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   1  GeForce RTX 208...  On   | 00000000:3B:00.0 Off |                  N/A |
    | 31%   41C    P8    21W / 250W |      1MiB / 11019MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    |   2  GeForce RTX 208...  On   | 00000000:86:00.0 Off |                  N/A |
    | 32%   40C    P8    16W / 250W |      1MiB / 11019MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
                                                                                   
    +-----------------------------------------------------------------------------+
    | Processes:                                                                  |
    |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
    |        ID   ID                                                   Usage      |
    |=============================================================================|
    |  No running processes found                                                 |
    +-----------------------------------------------------------------------------+
    

    Maybe my code is still relying on the CPU version. What else should I do to use GPU acceleration?

    Thanks, Gwonsoo

    opened by Kwonsoo 9
  • Stack overflow from the compiler.

    When trying to build the current master (corresponding to pytorch 1.7), I get a stack overflow from the compiler for src/wrapper/torch_bindings_generated.ml. It is indeed a very long file, but still I'm surprised. Here is the backtrace:

    Fatal error: exception Stack overflow
    Raised by primitive operation at Mach.instr_cons_debug in file "asmcomp/mach.ml", line 137, characters 2-185
    Re-raised at Misc.try_finally in file "utils/misc.ml", line 45, characters 10-56
    Called from Asmgen.(++) in file "asmcomp/asmgen.ml" (inlined), line 79, characters 15-18
    Called from Asmgen.compile_fundecl in file "asmcomp/asmgen.ml", line 84, characters 2-624
    Called from Stdlib__list.iter in file "list.ml", line 110, characters 12-15
    Called from Misc.try_finally in file "utils/misc.ml", line 31, characters 8-15
    Re-raised at Misc.try_finally in file "utils/misc.ml", line 45, characters 10-56
    Called from Asmgen.(++) in file "asmcomp/asmgen.ml" (inlined), line 79, characters 15-18
    Called from Asmgen.end_gen_implementation in file "asmcomp/asmgen.ml", line 153, characters 2-128
    Called from Misc.try_finally in file "utils/misc.ml", line 31, characters 8-15
    Re-raised at Misc.try_finally in file "utils/misc.ml", line 45, characters 10-56
    Called from Asmgen.compile_unit.(fun) in file "asmcomp/asmgen.ml", line 134, characters 7-231
    Called from Misc.try_finally in file "utils/misc.ml", line 31, characters 8-15
    Re-raised at Misc.try_finally in file "utils/misc.ml", line 45, characters 10-56
    Called from Optcompile.clambda.(fun) in file "driver/optcompile.ml", line 78, characters 7-336
    Called from Misc.try_finally in file "utils/misc.ml", line 31, characters 8-15
    Re-raised at Misc.try_finally in file "utils/misc.ml", line 45, characters 10-56
    Called from Compile_common.implementation.(fun) in file "driver/compile_common.ml", line 121, characters 71-113
    Called from Misc.try_finally in file "utils/misc.ml", line 31, characters 8-15
    Re-raised at Misc.try_finally in file "utils/misc.ml", line 45, characters 10-56
    Called from Misc.try_finally in file "utils/misc.ml", line 31, characters 8-15
    Re-raised at Misc.try_finally in file "utils/misc.ml", line 45, characters 10-56
    Called from Misc.try_finally in file "utils/misc.ml", line 31, characters 8-15
    Re-raised at Misc.try_finally in file "utils/misc.ml", line 45, characters 10-56
    Called from Compenv.process_action in file "driver/compenv.ml", line 596, characters 6-59
    Called from Stdlib__list.iter in file "list.ml", line 110, characters 12-15
    Called from Compenv.process_deferred_actions in file "driver/compenv.ml", line 672, characters 2-61
    Called from Optmain.main in file "driver/optmain.ml", line 55, characters 6-163
    Re-raised at Location.report_exception.loop in file "parsing/location.ml", line 926, characters 14-25
    Called from Optmain.main in file "driver/optmain.ml", line 133, characters 6-37
    Called from Optmain in file "driver/optmain.ml", line 137, characters 2-9
    

    I'm a bit puzzled because it is short, while I'd expected it to be very long since there's a stack overflow. Have you also met this problem at some point? There might be something fishy with my environment I don't know, but at least if I comment a part of the file it compiles fine.

    opened by pveber 9
  • backward pass with manually specified gradient

    https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autograd-tutorial-py has an example showing:

    x = torch.randn(3, requires_grad=True)
    
    y = x * 2
    while y.data.norm() < 1000:
        y = y * 2
    
    v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
    y.backward(v)
    

    The key line here is y.backward(v)

    The OCaml bindings for backward only take a single tensor:

    grep "val backward" . -R
    ./torch/optimizer.mli:val backward_step : ?clip_grad:Clip_grad.t -> t -> loss:Tensor.t -> unit
    ./wrapper/wrapper.mli:  val backward : ?keep_graph:bool -> ?create_graph:bool -> t -> unit
    

    If we search for "t -> t -> unit", nothing useful shows up:

    grep "t -> t -> unit" . -R
    ./torch/tensor.mli:val ( += ) : t -> t -> unit
    ./torch/tensor.mli:val ( -= ) : t -> t -> unit
    ./torch/tensor.mli:val ( *= ) : t -> t -> unit
    ./torch/tensor.mli:val ( /= ) : t -> t -> unit
    ./torch/optimizer.mli:val step : ?clip_grad:Clip_grad.t -> t -> unit
    
    

    What is the OCaml way of running a backward pass with a custom gradient?

    opened by zeroexcuses 8
  • Setting slices of a tensor: what is the equivalent of (python) x[i,:,:] = y

    I spent some time grepping through the (autogenerated) source, and think it's some sort of _input_put* function, but can't seem to figure out exactly what or how to do the slicing/indexing. In Python it would be:

    x = torch.zeros(1000, 30, 30)
    y = torch.randn(30, 30)
    x[5,:,:] = y
    

    or perhaps

    x = torch.zeros(1000, 60, 60)
    y = torch.randn(30, 30)
    i = 5
    x[i,30:,25:-5] = y
    

    Posting here in case others have the same simple question.

    Anyway, this is a great library, thank you very much for all the hard work!!

    opened by tlh24 7
  • Interoperability with npy-ocaml, ocaml-rs or tch-rs

    Hey, I have some Rust code that creates an ndarray::Array2, converts it with into_pyarray, and returns it to Python code, where it can be used as a torch tensor.

    I am wondering, what would be a way I'd do the same in OCaml?

    opened by crackcomm 7
  • Directly constructing a Tensor

    In PyTorch, we can do:

    x = torch.tensor([5.5, 3])
    print(x)
    Out:
    
    tensor([5.5000, 3.0000])
    

    Is there a corresponding function in OCaml-Torch? I don't see one in Torch or Torch.Tensor.

    Thanks!

    opened by zeroexcuses 7
  • Newbie question (building example) (Library not loaded: @rpath/libc10.dylib)

    Sorry I'm new to most of this (besides torch) but I wanted to give OCaml a try.

    On my mac I get this while building the example.

    Any clue what I'm doing wrong?

    Thanks

    ❯ dune build example.bc
    
          ocamlc example.bc (exit 2)
    (cd _build/default && /Users/pmanning/.opam/4.06.0/bin/ocamlc.opt -w @a-4-29-40-41-42-44-45-48-58-59-60-40 -strict-sequence -strict-formats -short-paths -keep-locs -g -o example.bc -I /Users/pmanning/.opam/4.06.0/lib/base -I /Users/pmanning/.opam/4.06.0/lib/base/caml -I /Users/pmanning/.opam/4.06.0/lib/base/shadow_stdlib -I /Users/pmanning/.opam/4.06.0/lib/bytes -I /Users/pmanning/.opam/4.06.0/lib/ctypes -I /Users/pmanning/.opam/4.06.0/lib/integers -I /Users/pmanning/.opam/4.06.0/lib/ocaml/threads -I /Users/pmanning/.opam/4.06.0/lib/sexplib0 -I /Users/pmanning/.opam/4.06.0/lib/stdio -I /Users/pmanning/.opam/4.06.0/lib/torch -I /Users/pmanning/.opam/4.06.0/lib/torch/core /Users/pmanning/.opam/4.06.0/lib/base/caml/caml.cma /Users/pmanning/.opam/4.06.0/lib/base/shadow_stdlib/shadow_stdlib.cma /Users/pmanning/.opam/4.06.0/lib/sexplib0/sexplib0.cma /Users/pmanning/.opam/4.06.0/lib/base/base.cma /Users/pmanning/.opam/4.06.0/lib/stdio/stdio.cma /Users/pmanning/.opam/4.06.0/lib/ocaml/unix.cma /Users/pmanning/.opam/4.06.0/lib/ocaml/bigarray.cma /Users/pmanning/.opam/4.06.0/lib/integers/integers.cma /Users/pmanning/.opam/4.06.0/lib/ctypes/ctypes.cma /Users/pmanning/.opam/4.06.0/lib/ocaml/threads/threads.cma /Users/pmanning/.opam/4.06.0/lib/ctypes/ctypes-foreign-base.cma /Users/pmanning/.opam/4.06.0/lib/ctypes/ctypes-foreign-threaded.cma /Users/pmanning/.opam/4.06.0/lib/ocaml/str.cma /Users/pmanning/.opam/4.06.0/lib/ctypes/cstubs.cma /Users/pmanning/.opam/4.06.0/lib/torch/core/torch_core.cma /Users/pmanning/.opam/4.06.0/lib/torch/torch.cma .example.eobjs/example.cmo)
    File "_none_", line 1:
    Error: Error on dynamically loaded library: /Users/pmanning/.opam/4.06.0/lib/stublibs/dlltorch_core_stubs.so: dlopen(/Users/pmanning/.opam/4.06.0/lib/stublibs/dlltorch_core_stubs.so, 10): Library not loaded: @rpath/libc10.dylib
      Referenced from: /Users/pmanning/.opam/4.06.0/lib/stublibs/dlltorch_core_stubs.so
      Reason: image not found
    
    opened by mfirry 7
  • Reading a .pt file gives an error

    Hi! I have been trying to use ocaml-torch to load a PyTorch model that has already been trained. I have installed PyTorch 1.9.0 which is what is compatible with ocaml-torch. Upon trying to load the file, it throws the following error

    Fatal error: exception (Failure "version_number <= kMaxSupportedFileFormatVersion ASSERT FAILED at /pytorch/caffe2/serialize/inline_container.cc:131, please report a bug to PyTorch. Attempted to read a PyTorch file with version 4, but the maximum supported version for reading is 1. Your PyTorch installation may be too old. (init at /pytorch/caffe2/serialize/inline_container.cc:131)\

    Can you suggest what can be done in this case? Thank you!

    Edit: Source code python file -

    net = CNN()
    net.eval()
    example = torch.rand(16384)
    traced = torch.jit.trace(net, example)
    traced.save("./models/feat.pt")

    ocaml file -

    open Base
    open Torch

    let model = Module.load "../models/feat.pt" ;;

    opened by Het-Shah 5
  • freeing up temporary CUDA tensors w/o Caml.Gc.full_major() ?

    Here's the issue I'm running into:

    1. Rust has RAII. Thus, when a Tensor no longer has references, its Drop is called, which I suspect triggers calling cuda_free.

    2. OCaml has GC. Each Tensor has a very small footprint (pointer?) on the CPU but a possibly huge one on the GPU (the entire contents of the tensor).

    3. The only way I know to tell OCaml to do cuda_free is to call "Caml.Gc.full_major()" which seems to slow down training quite a bit (doing a full GC on every training step.)

    4. Is there a way in OCaml to do "do a GC of all tensors, but without GCing the entire OCaml VM" ?

    opened by zeroexcuses 5
  • Custom training

    Hi Laurent, Thank you so much for providing such a good repo.

    I have used the MNIST dataset so far, but I would like to expand it to be able to recognise more math symbols: https://www.kaggle.com/datasets/sagyamthapa/handwritten-math-symbols

    What would be the easiest way to go about this? I do not have much expertise with gz files, so I'd rather not go into that. Can your dataset helper handle CSV files?

    I was thinking of CSV files like this:

    training set image CSV, one row example: pathname

    training set label CSV, one row example: label

    or

    training set image CSV, one row example: the whole image expressed in bytes

    training set label CSV, one row example: label

    opened by ChrisRawstone 1
  • Installing the GPU accelerated version of ocaml-torch

    Hello! Earlier I had installed ocaml-torch for a CPU version. I am trying to install ocaml-torch for GPU acceleration with the following commands. (Downloaded and unzipped libtorch 1.10)

    export LIBTORCH=./libtorch
    git clone https://github.com/LaurentMazare/ocaml-torch.git
    cd ocaml-torch
    make clean
    make all
    

    make fails with the following error trace.

    File "src/stubs/torch_bindings_generated.ml", line 8309, characters 2-16:
    Error: Multiple definition of the type name t.
           Names must be unique in a given structure or signature.
    File "src/wrapper/torch_bindings_generated.ml", line 8309, characters 2-16:
    Error: Multiple definition of the type name t.
           Names must be unique in a given structure or signature.
    Makefile:34: recipe for target 'all' failed
    make: *** [all] Error 1
    

    PS: I have deactivated the conda environment before running make.

    Can you kindly help me with this?

    Thank you!

    opened by Het-Shah 1
  • RL examples no longer compile (again)

    E.g. running ocaml-torch/_build/default/examples/reinforcement-learning/dqn_atari.exe produces

    actions: NOOP,FIRE,RIGHT,LEFT,RIGHTFIRE,LEFTFIRE
    0 0 (0/0 frames)
    zsh: segmentation fault

    EDIT: Tried with char_rnn and it works fine, so possibly unique to RL?

    opened by kyliedunn526 1
  • Programs that use ocaml-torch with GPU acceleration segfault right before terminating

    Programs I write using ocaml-torch that use GPU acceleration segfault right before terminating:

    Segmentation fault (core dumped)
    

    This is not a huge deal as it happens when the program is about to terminate anyway but I was wondering if you had observed the same phenomenon.

    In particular, I replicated the problem on your mnist/conv and char_rnn examples.

    opened by jonathan-laurent 5
  • Support for probability distributions

    Hi,

    It seems to me that ocaml-torch does not provide an API for torch.distributions. Specifically, it does not implement the various probability distributions with common functions such as sample and log_prob. I think they are important enough that they should be implemented here.

    Best, Gwonsoo

    opened by Kwonsoo 5
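
Two of the questions above, constructing a tensor directly and assigning to a slice, can be approached with functions from the Tensor module. The sketch below is one possible way to do it, not an official answer from the maintainers. The helper names assign_row and assign_block are hypothetical and introduced only for this example; of_float1, select, narrow, copy_, and shape are the library calls it assumes.

```ocaml
open Torch

(* Similar in spirit to torch.tensor([5.5, 3]): build a 1-D float tensor. *)
let direct = Tensor.of_float1 [| 5.5; 3. |]

(* Similar in spirit to x[index,:,:] = y: `select` returns a view that shares
   storage with x, and `copy_` writes through that view. *)
let assign_row x ~index y = Tensor.copy_ (Tensor.select x ~dim:0 ~index) ~src:y

(* Similar in spirit to x[index, r0:r0+rows, c0:c0+cols] = y for a 2-D y. *)
let assign_block x ~index ~row_start ~col_start y =
  match Tensor.shape y with
  | [ rows; cols ] ->
    let view =
      Tensor.select x ~dim:0 ~index
      |> Tensor.narrow ~dim:0 ~start:row_start ~length:rows
      |> Tensor.narrow ~dim:1 ~start:col_start ~length:cols
    in
    Tensor.copy_ view ~src:y
  | _ -> invalid_arg "assign_block: expected a 2-D source tensor"

let () =
  Tensor.print direct;
  (* Mirrors the second Python snippet: a 30x30 block inside a 60x60 slice. *)
  let x = Tensor.zeros [ 1000; 60; 60 ] in
  assign_block x ~index:5 ~row_start:30 ~col_start:25 (Tensor.randn [ 30; 30 ])
```

Unlike Python slices, narrow takes an explicit length rather than a negative end index, so 25:-5 on a dimension of size 60 becomes a start of 25 and a length of 30.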
Releases(0.17)
Owner
Laurent Mazare