Implementation of a Transformer, but completely in Triton

Overview

Transformer in Triton (wip)

Implementation of a Transformer, but completely in Triton. I'm completely new to lower-level neural net code, so this repository will mostly be a learning experience, with the end-goal being a vanilla transformer that is faster and more efficient to train.

Install

$ pip install triton-transformer

Usage

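import torch
from triton_transformer import Transformer

model = Transformer(
    num_tokens = 256,    # vocabulary size
    max_seq_len = 1024,  # maximum sequence length
    dim = 512,           # model (embedding) dimension
    depth = 6,           # number of transformer layers
    heads = 8,           # attention heads per layer
    dim_head = 64        # dimension of each attention head
)

x = torch.randint(0, 256, (1, 1024))  # token ids in [0, num_tokens)
mask = torch.ones(1, 1024).bool()     # boolean mask, one entry per token

logits = model(x, mask = mask) # (1, 1024, 256)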
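The usage example passes a boolean `mask` with one entry per token; the all-`True` mask above marks every position as real. The README does not spell out how the mask is applied, so the following is only a pure-Python sketch of the usual key padding mask inside a single attention head, assuming `True` means "may be attended to". The function `masked_attention` is hypothetical and is not code from `triton_transformer`.

```python
import math
from typing import List

Matrix = List[List[float]]

def masked_attention(q: Matrix, k: Matrix, v: Matrix, mask: List[bool]) -> Matrix:
    """Hypothetical single-head scaled dot-product attention with a key padding mask.

    q, k, v: lists of rows, each row of length dim_head.
    mask: one bool per key position; True means the key may be attended to
    (an assumed convention). At least one entry must be True.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        # similarity of this query to every key, scaled by 1 / sqrt(dim_head)
        scores = [sum(a * b for a, b in zip(qi, kj)) * scale for kj in k]
        # masked-out keys get -inf so they receive zero weight after the softmax
        scores = [s if m else -math.inf for s, m in zip(scores, mask)]
        top = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # output row is the attention-weighted sum of the value rows
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) for d in range(len(v[0]))])
    return out
```

With `heads = 8` and `dim_head = 64`, a layer would run eight such heads, each on 64-dimensional queries, keys and values. If every key in a row were masked out, the softmax would divide zero by zero, which is why the sketch requires at least one `True` entry.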

Citations

@article{Tillet2019TritonAI,
    title   = {Triton: an intermediate language and compiler for tiled neural network computations},
    author  = {Philippe Tillet and H. Kung and D. Cox},
    journal = {Proceedings of the 3rd ACM SIGPLAN International Workshop on Machine Learning and Programming Languages},
    year    = {2019}
}
@misc{vaswani2017attention,
    title   = {Attention Is All You Need}, 
    author  = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
    year    = {2017},
    eprint  = {1706.03762},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

Comments
  • Question concerning PyTorch build

    Hello. I find your project very interesting and I have seen your comparison between PyTorch and Triton implementations.

    However, I am curious whether your PyTorch environment is a source build optimized for your machine or a pip/conda install.

    Source builds can run faster, and if a conda install is being used for the comparison, the speed difference may simply come from Triton optimizing CUDA for the runtime environment.

    Thank you again for your interesting project.

    opened by veritas9872 13
  • _layernorm implementation forward result not equal F.layer_norm

    I tried your triton-transformer and tested the layernorm module on its own. Oddly, the forward results differ while the backward results are equal.

    code:

        from triton_transformer.layernorm import layernorm
        import torch
        import torch.nn as nn

        torch.manual_seed(0)
        x = torch.randn(2, 5).cuda()
        x.requires_grad_(True)
        dy = .1 * torch.randn_like(x).cuda()
        dim = 5
        norm = nn.LayerNorm(dim).cuda()

        y1 = layernorm(x, norm.weight, norm.bias, use_triton = True)
        y2 = layernorm(x, norm.weight, norm.bias, use_triton = False)
        print(y1, y2)
        print(torch.allclose(y1, y2))

        y1.backward(dy, retain_graph=True)
        dx_y1 = x.grad.clone()

        x.grad = None

        y2.backward(dy, retain_graph=True)
        dx_y2 = x.grad.clone()
        print(dx_y1, dx_y2)
        print(torch.allclose(dx_y1, dx_y2))

    result:

        tensor([[ 0.9492, -0.0021, -0.9797,  0.4449, -0.4123],
                [-0.7624,  0.4399,  0.7299, -0.3091, -0.0983]], device='cuda:0', grad_fn=<_layernormBackward>)
        tensor([[ 1.4217, -0.0031, -1.4674,  0.6663, -0.6175],
                [-1.4342,  0.8276,  1.3732, -0.5815, -0.1850]], device='cuda:0', grad_fn=)
        False

        tensor([[-0.0706,  0.0288, -0.0813,  0.0446,  0.0785],
                [ 0.0218, -0.0152,  0.0141, -0.0522,  0.0315]], device='cuda:0')
        tensor([[-0.0706,  0.0288, -0.0813,  0.0446,  0.0785],
                [ 0.0218, -0.0152,  0.0141, -0.0522,  0.0315]], device='cuda:0')
        True

    opened by Tengxu-Sun 1
  • Current state of benchmarking & contributing?

    Hey @lucidrains - hope you're doing well! I have some time to hack the next couple weeks, just wanted to get a sense of:

    • Current state of benchmarking (which Triton kernels provide how much lift, and the aggregate lift over a "vanilla" Transformer implementation)
    • If there's anything I could help with, especially as I learn Triton!
    opened by siddk 0
  • Official layer norm added

    Hi @lucidrains, a layer norm was just added to the Triton examples (https://github.com/openai/triton/commit/d4baad426db72b83c5222e1c83c929c1860cae54). I tested it: it's twice as fast as Torch and often faster than Apex.

    I'm looking forward to your implementation of attention. So far the Torch implementation is the fastest, at 12.3 / 14.5 (forward / backward), versus 17.3 / 23.0 for the other Triton implementation (in DeepSpeed) on my data.

    opened by olegklimov 2
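The `_layernorm` issue above compares the Triton forward pass against the PyTorch path. A dependency-free way to cross-check such outputs is a reference layer norm written directly from the formula documented for `nn.LayerNorm`: subtract the row mean, divide by the square root of the biased variance plus `eps`, then scale by `weight` and shift by `bias`. This is a plain-Python sketch, not the `triton_transformer` code; the helper names `row_stats`, `layer_norm_ref` and `allclose` are hypothetical.

```python
import math
from typing import List, Tuple

def row_stats(y: List[float]) -> Tuple[float, float]:
    """Mean and biased variance (divide by n, not n - 1) of one row."""
    n = len(y)
    mean = sum(y) / n
    var = sum((yi - mean) ** 2 for yi in y) / n
    return mean, var

def layer_norm_ref(x: List[float], weight: List[float], bias: List[float],
                   eps: float = 1e-5) -> List[float]:
    """Hypothetical reference layer norm over one row (the last dimension):
    y = (x - mean) / sqrt(var + eps) * weight + bias, with the biased variance."""
    mean, var = row_stats(x)
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(xi - mean) * inv_std * w + b for xi, w, b in zip(x, weight, bias)]

def allclose(a: List[float], b: List[float], rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Element-wise |a - b| <= atol + rtol * |b|, the same rule torch.allclose uses."""
    return len(a) == len(b) and all(
        abs(ai - bi) <= atol + rtol * abs(bi) for ai, bi in zip(a, b)
    )
```

A freshly constructed `nn.LayerNorm` has weight 1 and bias 0, so each output row should have mean close to 0 and variance close to 1. Applying `row_stats` to the rows printed in the issue gives a variance of about 1.0 for both `use_triton = False` rows but only about 0.45 and 0.28 for the two Triton rows, so the Triton forward output is not normalized to unit variance. That narrows down where to look; it does not by itself identify the bug.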
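Two of the comments above ask about benchmarking and report forward/backward timings, and such numbers are only comparable when they are measured the same way. Below is a minimal timing helper in plain Python, assuming you supply the function to time and, for GPU code, a synchronization hook such as `torch.cuda.synchronize` (CUDA kernels launch asynchronously, so the clock should not be read before they finish). The names `benchmark_ms` and `speedup` are hypothetical.

```python
import statistics
import time
from typing import Callable

def benchmark_ms(fn: Callable[[], object], warmup: int = 3, reps: int = 10,
                 sync: Callable[[], None] = lambda: None) -> float:
    """Hypothetical helper: median wall-clock time of fn() in milliseconds.

    `sync` is called after the warm-up runs and after every timed run; on a GPU
    one would pass something like torch.cuda.synchronize (an assumption about
    how it is wired up). The default does nothing, which suits CPU-only code.
    """
    for _ in range(warmup):  # warm-up runs absorb one-off costs such as JIT compilation
        fn()
    sync()
    times = []
    for _ in range(reps):
        start = time.perf_counter()
        fn()
        sync()  # make sure queued work has finished before reading the clock
        times.append((time.perf_counter() - start) * 1e3)
    return statistics.median(times)

def speedup(baseline_ms: float, candidate_ms: float) -> float:
    """Ratio of the two timings; a value above 1.0 means the candidate is faster."""
    return baseline_ms / candidate_ms
```

Plugging in the timings from the layer norm comment, `speedup(17.3, 12.3)` is about 1.41 and `speedup(23.0, 14.5)` about 1.59, so on that commenter's data the Torch attention was roughly 1.4x faster forward and 1.6x faster backward than the DeepSpeed Triton implementation. Reporting the median rather than the mean keeps a single slow run from skewing the result.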

Releases: 0.1.1

Owner: Phil Wang
Working with Attention. It's all we need