A library to inspect itermediate layers of PyTorch models.

Overview

A library to inspect itermediate layers of PyTorch models.

Why?

It's often the case that we want to inspect intermediate layers of a model without modifying the code e.g. visualize attention matrices of language models, get values from an intermediate layer to feed to another layer, or applying a loss function to intermediate layers.

Install

$ pip install surgeon-pytorch

PyPI - Python Version

Usage

Inspect

Given a PyTorch model we can display all layers using get_layers:

import torch
import torch.nn as nn

from surgeon_pytorch import Inspect, get_layers

class SomeModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(5, 3)
        self.layer2 = nn.Linear(3, 2)
        self.layer3 = nn.Linear(2, 1)

    def forward(self, x):
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        y = self.layer3(x2)
        return y


model = SomeModel()
print(get_layers(model)) # ['layer1', 'layer2', 'layer3']

Then we can wrap our model to be inspected using Inspect and in every forward call the new model we will also output the provided layer outputs (in second return value):

model_wrapped = Inspect(model, layer='layer2')
x = torch.rand(1, 5)
y, x2 = model_wrapped(x)
print(x2) # tensor([[-0.2726,  0.0910]], grad_fn=<AddmmBackward0>)

We can also provide a list of layers:

model_wrapped = Inspect(model, layer=['layer1', 'layer2'])
x = torch.rand(1, 5)
y, [x1, x2] = model_wrapped(x)
print(x1) # tensor([[ 0.1739,  0.3844, -0.4724]], grad_fn=<AddmmBackward0>)
print(x2) # tensor([[-0.2238,  0.0107]], grad_fn=<AddmmBackward0>)

Or a dictionary to get named outputs:

model_wrapped = Inspect(model, layer={'x1': 'layer1', 'x2': 'layer2'})
x = torch.rand(1, 5)
y, layers = model_wrapped(x)
print(layers)
"""
{
    'x1': tensor([[ 0.3707,  0.6584, -0.2970]], grad_fn=<AddmmBackward0>),
    'x2': tensor([[-0.1953, -0.3408]], grad_fn=<AddmmBackward0>)
}
"""

TODO

  • add extract function to get intermediate block
You might also like...
Ever felt tired after preprocessing the dataset, and not wanting to write any code further to train your model? Ever encountered a situation where you wanted to record the hyperparameters of the trained model and able to retrieve it afterward? Models Playground is here to help you do that. Models playground allows you to train your models right from the browser. pyhsmm - library for approximate unsupervised inference in Bayesian Hidden Markov Models (HMMs) and explicit-duration Hidden semi-Markov Models (HSMMs), focusing on the Bayesian Nonparametric extensions, the HDP-HMM and HDP-HSMM, mostly with weak-limit approximations. PyTorch implementation and pretrained models for XCiT models. See XCiT: Cross-Covariance Image Transformer
TorchGeo is a PyTorch domain library, similar to torchvision, that provides datasets, transforms, samplers, and pre-trained models specific to geospatial data.

TorchGeo is a PyTorch domain library, similar to torchvision, that provides datasets, transforms, samplers, and pre-trained models specific to geospatial data.

Pytorch library for end-to-end transformer models training and serving

Pytorch library for end-to-end transformer models training and serving

This repository provides an efficient PyTorch-based library for training deep models.

An Efficient Library for Training Deep Models This repository provides an efficient PyTorch-based library for training deep models. Installation Make

TorchMultimodal is a PyTorch library for training state-of-the-art multimodal multi-task models at scale.

TorchMultimodal (Alpha Release) Introduction TorchMultimodal is a PyTorch library for training state-of-the-art multimodal multi-task models at scale.

Objective of the repository is to learn and build machine learning models using Pytorch. 30DaysofML Using Pytorch
Objective of the repository is to learn and build machine learning models using Pytorch. 30DaysofML Using Pytorch

30 Days Of Machine Learning Using Pytorch Objective of the repository is to learn and build machine learning models using Pytorch. List of Algorithms

Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch
Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Comments
  • Use one backbone with different heads

    Use one backbone with different heads

    Is it possible to save the results from the backbone and apply them on the heads of the all the other models. My goal was to try to save time by avoiding repeating the backbone part. Instead of running the 3 complete models (left), only run the backbone 1 time and switch only the heads for the 3 models (right), therefore not repeating executing the backbone every time in yolov5 model.

    Thank you for the help!

    question 
    opened by brunopatricio2012 4
  • Support for DataParallel?

    Support for DataParallel?

    Hi, I noticed that the current version does not support parallel models (at least those created using torch.nn.DataParallel) since the forward hook does not differentiate between the different copies of the model and a model wrapped with Inspect will just return the intermediate features of the last copy of the parallelized model to run.

    Are you planning on fixing this issue/supporting this use case?

    opened by zimmerrol 1
Releases(0.0.4)
Owner
archinet.ai
AI Research Group
archinet.ai
Official Pytorch Implementation of Relational Self-Attention: What's Missing in Attention for Video Understanding

Relational Self-Attention: What's Missing in Attention for Video Understanding This repository is the official implementation of "Relational Self-Atte

mandos 43 Dec 07, 2022
Official PyTorch(Geometric) implementation of DPGNN(DPGCN) in "Distance-wise Prototypical Graph Neural Network for Node Imbalance Classification"

DPGNN This repository is an official PyTorch(Geometric) implementation of DPGNN(DPGCN) in "Distance-wise Prototypical Graph Neural Network for Node Im

Yu Wang (Jack) 18 Oct 12, 2022
[AAAI 2022] Sparse Structure Learning via Graph Neural Networks for Inductive Document Classification

Sparse Structure Learning via Graph Neural Networks for inductive document classification Make graph dataset create co-occurrence graph for datasets.

16 Dec 22, 2022
Predicting Tweet Sentiment Maching Learning and streamlit

Predicting-Tweet-Sentiment-Maching-Learning-and-streamlit (I prefere using Visual Studio Code ) Open the folder in VS Code Run the first cell in requi

1 Nov 20, 2021
The PyTorch implementation of Directed Graph Contrastive Learning (DiGCL), NeurIPS-2021

Directed Graph Contrastive Learning The PyTorch implementation of Directed Graph Contrastive Learning (DiGCL). In this paper, we present the first con

Tong Zekun 28 Jan 08, 2023
In the case of your data having only 1 channel while want to use timm models

timm_custom Description In the case of your data having only 1 channel while want to use timm models (with or without pretrained weights), run the fol

2 Nov 26, 2021
An expansion for RDKit to read all types of files in one line

RDMolReader An expansion for RDKit to read all types of files in one line How to use? Add this single .py file to your project and import MolFromFile(

Ali Khodabandehlou 1 Dec 18, 2021
Simulator for FRC 2022 challenge: Rapid React

rrsim Simulator for FRC 2022 challenge: Rapid React out-1.mp4 Usage In order to run the simulator use the following: python3 rrsim.py [config_path] wh

1 Jan 18, 2022
discovering subdomains, hidden paths, extracting unique links

python-website-crawler discovering subdomains, hidden paths, extracting unique links pip install -r requirements.txt discover subdomain: You can give

merve 4 Sep 05, 2022
My personal Home Assistant configuration.

About This is my personal Home Assistant configuration. My guiding princile is to have full local control of all my devices. I intend everything to ru

Chris Turra 13 Jun 07, 2022
Türkiye Canlı Mobese Görüntülerinde Profesyonel Nesne Takip Sistemi

Türkiye Mobese Görüntü Takip Türkiye Mobese görüntülerinde OPENCV ve Yolo ile takip sistemi Multiple Object Tracking System in Turkish Mobese with OPE

15 Dec 22, 2022
Tooling for GANs in TensorFlow

TensorFlow-GAN (TF-GAN) TF-GAN is a lightweight library for training and evaluating Generative Adversarial Networks (GANs). Can be installed with pip

803 Dec 24, 2022
Improving Contrastive Learning by Visualizing Feature Transformation, ICCV 2021 Oral

Improving Contrastive Learning by Visualizing Feature Transformation This project hosts the codes, models and visualization tools for the paper: Impro

Bingchen Zhao 83 Dec 15, 2022
Yoga - Yoga asana classifier for python

Yoga Asana Classifier Description Hi welcome to my new deep learning project "Yo

Programminghut 35 Dec 12, 2022
Modified fork of Xuebin Qin's U-2-Net Repository. Used for demonstration purposes.

U^2-Net (U square net) Modified version of U2Net used for demonstation purposes. Paper: U^2-Net: Going Deeper with Nested U-Structure for Salient Obje

Shreyas Bhat Kera 13 Aug 28, 2022
functorch is a prototype of JAX-like composable function transforms for PyTorch.

functorch is a prototype of JAX-like composable function transforms for PyTorch.

Facebook Research 1.2k Jan 09, 2023
PyTorch implementation of 'Gen-LaneNet: a generalized and scalable approach for 3D lane detection'

(pytorch) Gen-LaneNet: a generalized and scalable approach for 3D lane detection Introduction This is a pytorch implementation of Gen-LaneNet, which p

Yuliang Guo 233 Jan 06, 2023
Pacman-AI - AI project designed by UC Berkeley. Designed reflex and minimax agents for the game Pacman.

Pacman AI Jussi Doherty CAP 4601 - Introduction to Artificial Intelligence - Fall 2020 Python version 3.0+ Source of this project This repo contains a

Jussi Doherty 1 Jan 03, 2022
Code for "Steerable Pyramid Transform Enables Robust Left Ventricle Quantification"

Code for "Steerable Pyramid Transform Enables Robust Left Ventricle Quantification" This is an end-to-end framework for accurate and robust left ventr

2 Jul 09, 2022
Density-aware Single Image De-raining using a Multi-stream Dense Network (CVPR 2018)

DID-MDN Density-aware Single Image De-raining using a Multi-stream Dense Network He Zhang, Vishal M. Patel [Paper Link] (CVPR'18) We present a novel d

He Zhang 224 Dec 12, 2022