Segmentation models with pretrained backbones. PyTorch.

Overview

Python library with Neural Networks for Image Segmentation based on PyTorch.


The main features of this library are:

  • High level API (just two lines to create a neural network)
  • 9 model architectures for binary and multi-class segmentation (including the legendary Unet)
  • 113 available encoders
  • All encoders have pre-trained weights for faster and better convergence

📚 Project Documentation 📚

Visit the Read The Docs project page or read the following README to learn more about the Segmentation Models PyTorch (SMP for short) library.

📋 Table of contents

  1. Quick start
  2. Examples
  3. Models
    1. Architectures
    2. Encoders
    3. Timm Encoders
  4. Models API
    1. Input channels
    2. Auxiliary classification output
    3. Depth
  5. Installation
  6. Competitions won with the library
  7. Contributing
  8. Citing
  9. License

Quick start

1. Create your first Segmentation model with SMP

A segmentation model is just a PyTorch nn.Module and can be created as easily as:

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=3,                      # model output channels (number of classes in your dataset)
)

  • see the table of available model architectures
  • see the table of available encoders and their corresponding weights

2. Configure data preprocessing

All encoders have pretrained weights. Preparing your data the same way as during the weights' pre-training may give you better results (a higher metric score and faster convergence). This is relevant only for 1-, 2-, and 3-channel images, and it is not strictly necessary if you train the whole model rather than only the decoder.

from segmentation_models_pytorch.encoders import get_preprocessing_fn

preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
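
The returned function can then be applied to an image before converting it to a tensor. A minimal sketch (the random array is a stand-in for a real image; the names here are illustrative):

import numpy as np
import torch

# hypothetical HWC uint8 image standing in for real data
image = np.random.randint(0, 255, size=(256, 256, 3), dtype=np.uint8)
image = preprocess_input(image)                       # normalize as during pre-training
x = torch.from_numpy(image).permute(2, 0, 1).float()  # to CHW tensor for the model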

Congratulations! You are done! Now you can train your model with your favorite framework!

💡 Examples

  • Training a model for car segmentation on the CamVid dataset here.
  • Training an SMP model with Catalyst (a high-level framework for PyTorch), TTAch (a TTA library for PyTorch) and Albumentations (a fast image augmentation library) - here Open In Colab
  • Training an SMP model with the PyTorch Lightning framework - here (clothes binary segmentation by @teranus).

📦 Models

Architectures

  • Unet
  • Unet++
  • MANet
  • Linknet
  • FPN
  • PSPNet
  • PAN
  • DeepLabV3
  • DeepLabV3+

Encoders

The following is a list of the encoders supported in SMP. Select the appropriate family of encoders, then pick a specific encoder and its pre-trained weights from the table (the encoder_name and encoder_weights parameters).

ResNet
Encoder Weights Params, M
resnet18 imagenet / ssl / swsl 11M
resnet34 imagenet 21M
resnet50 imagenet / ssl / swsl 23M
resnet101 imagenet 42M
resnet152 imagenet 58M
ResNeXt
Encoder Weights Params, M
resnext50_32x4d imagenet / ssl / swsl 22M
resnext101_32x4d ssl / swsl 42M
resnext101_32x8d imagenet / instagram / ssl / swsl 86M
resnext101_32x16d instagram / ssl / swsl 191M
resnext101_32x32d instagram 466M
resnext101_32x48d instagram 826M
ResNeSt
Encoder Weights Params, M
timm-resnest14d imagenet 8M
timm-resnest26d imagenet 15M
timm-resnest50d imagenet 25M
timm-resnest101e imagenet 46M
timm-resnest200e imagenet 68M
timm-resnest269e imagenet 108M
timm-resnest50d_4s2x40d imagenet 28M
timm-resnest50d_1s4x24d imagenet 23M
Res2Ne(X)t
Encoder Weights Params, M
timm-res2net50_26w_4s imagenet 23M
timm-res2net101_26w_4s imagenet 43M
timm-res2net50_26w_6s imagenet 35M
timm-res2net50_26w_8s imagenet 46M
timm-res2net50_48w_2s imagenet 23M
timm-res2net50_14w_8s imagenet 23M
timm-res2next50 imagenet 22M
RegNet(x/y)
Encoder Weights Params, M
timm-regnetx_002 imagenet 2M
timm-regnetx_004 imagenet 4M
timm-regnetx_006 imagenet 5M
timm-regnetx_008 imagenet 6M
timm-regnetx_016 imagenet 8M
timm-regnetx_032 imagenet 14M
timm-regnetx_040 imagenet 20M
timm-regnetx_064 imagenet 24M
timm-regnetx_080 imagenet 37M
timm-regnetx_120 imagenet 43M
timm-regnetx_160 imagenet 52M
timm-regnetx_320 imagenet 105M
timm-regnety_002 imagenet 2M
timm-regnety_004 imagenet 3M
timm-regnety_006 imagenet 5M
timm-regnety_008 imagenet 5M
timm-regnety_016 imagenet 10M
timm-regnety_032 imagenet 17M
timm-regnety_040 imagenet 19M
timm-regnety_064 imagenet 29M
timm-regnety_080 imagenet 37M
timm-regnety_120 imagenet 49M
timm-regnety_160 imagenet 80M
timm-regnety_320 imagenet 141M
GERNet
Encoder Weights Params, M
timm-gernet_s imagenet 6M
timm-gernet_m imagenet 18M
timm-gernet_l imagenet 28M
SE-Net
Encoder Weights Params, M
senet154 imagenet 113M
se_resnet50 imagenet 26M
se_resnet101 imagenet 47M
se_resnet152 imagenet 64M
se_resnext50_32x4d imagenet 25M
se_resnext101_32x4d imagenet 46M
SK-ResNe(X)t
Encoder Weights Params, M
timm-skresnet18 imagenet 11M
timm-skresnet34 imagenet 21M
timm-skresnext50_32x4d imagenet 25M
DenseNet
Encoder Weights Params, M
densenet121 imagenet 6M
densenet169 imagenet 12M
densenet201 imagenet 18M
densenet161 imagenet 26M
Inception
Encoder Weights Params, M
inceptionresnetv2 imagenet / imagenet+background 54M
inceptionv4 imagenet / imagenet+background 41M
xception imagenet 22M
EfficientNet
Encoder Weights Params, M
efficientnet-b0 imagenet 4M
efficientnet-b1 imagenet 6M
efficientnet-b2 imagenet 7M
efficientnet-b3 imagenet 10M
efficientnet-b4 imagenet 17M
efficientnet-b5 imagenet 28M
efficientnet-b6 imagenet 40M
efficientnet-b7 imagenet 63M
timm-efficientnet-b0 imagenet / advprop / noisy-student 4M
timm-efficientnet-b1 imagenet / advprop / noisy-student 6M
timm-efficientnet-b2 imagenet / advprop / noisy-student 7M
timm-efficientnet-b3 imagenet / advprop / noisy-student 10M
timm-efficientnet-b4 imagenet / advprop / noisy-student 17M
timm-efficientnet-b5 imagenet / advprop / noisy-student 28M
timm-efficientnet-b6 imagenet / advprop / noisy-student 40M
timm-efficientnet-b7 imagenet / advprop / noisy-student 63M
timm-efficientnet-b8 imagenet / advprop 84M
timm-efficientnet-l2 noisy-student 474M
timm-efficientnet-lite0 imagenet 4M
timm-efficientnet-lite1 imagenet 5M
timm-efficientnet-lite2 imagenet 6M
timm-efficientnet-lite3 imagenet 8M
timm-efficientnet-lite4 imagenet 13M
MobileNet
Encoder Weights Params, M
mobilenet_v2 imagenet 2M
timm-mobilenetv3_large_075 imagenet 1.78M
timm-mobilenetv3_large_100 imagenet 2.97M
timm-mobilenetv3_large_minimal_100 imagenet 1.41M
timm-mobilenetv3_small_075 imagenet 0.57M
timm-mobilenetv3_small_100 imagenet 0.93M
timm-mobilenetv3_small_minimal_100 imagenet 0.43M
DPN
Encoder Weights Params, M
dpn68 imagenet 11M
dpn68b imagenet+5k 11M
dpn92 imagenet+5k 34M
dpn98 imagenet 58M
dpn107 imagenet+5k 84M
dpn131 imagenet 76M
VGG
Encoder Weights Params, M
vgg11 imagenet 9M
vgg11_bn imagenet 9M
vgg13 imagenet 9M
vgg13_bn imagenet 9M
vgg16 imagenet 14M
vgg16_bn imagenet 14M
vgg19 imagenet 20M
vgg19_bn imagenet 20M

* ssl, swsl - semi-supervised and weakly-supervised learning on ImageNet (repo).

Timm Encoders


PyTorch Image Models (a.k.a. timm) provides many pretrained models and an interface that allows these models to be used as encoders in SMP. However, not all models are supported:

  • transformer models do not have features_only functionality implemented
  • some models do not have appropriate strides

Total number of supported encoders: 467
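
A minimal sketch of creating a model with a timm encoder (this assumes the "tu-" prefix convention for timm encoders; check the docs for the exact names your version supports):

import segmentation_models_pytorch as smp

# timm encoders are selected by prefixing the timm model name with "tu-" (assumption)
model = smp.Unet("tu-resnet34", encoder_weights="imagenet", in_channels=3, classes=2)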

🔁 Models API

  • model.encoder - pretrained backbone to extract features of different spatial resolutions
  • model.decoder - depends on the model's architecture (Unet/Linknet/PSPNet/FPN)
  • model.segmentation_head - last block to produce the required number of mask channels (also includes optional upsampling and activation)
  • model.classification_head - optional block which creates a classification head on top of the encoder
  • model.forward(x) - sequentially pass x through the model's encoder, decoder and segmentation head (and classification head if specified); see the sketch below

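A minimal sketch of this forward path (shapes are illustrative):

import torch
import segmentation_models_pytorch as smp

model = smp.Unet("resnet34", classes=2)
x = torch.rand(1, 3, 64, 64)

features = model.encoder(x)  # list of feature maps at decreasing spatial resolutions
mask = model(x)              # encoder -> decoder -> segmentation_head; shape (1, 2, 64, 64)
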
Input channels

The in_channels parameter allows you to create models that process tensors with an arbitrary number of channels. If you use pretrained weights from ImageNet, the weights of the first convolution are reused. In the 1-channel case, the new weight is the sum of the first convolution's weights over the channel dimension; otherwise channels are populated with weights like new_weight[:, i] = pretrained_weight[:, i % 3] and then scaled by new_weight * 3 / new_in_channels.

import torch

model = smp.FPN('resnet34', in_channels=1)
mask = model(torch.ones([1, 1, 64, 64]))

Auxiliary classification output

All models support the aux_params parameter, which defaults to None. If aux_params = None, the classification auxiliary output is not created; otherwise the model produces not only a mask but also a label output of shape (N, C). The classification head consists of GlobalPooling->Dropout(optional)->Linear->Activation(optional) layers, which can be configured via aux_params as follows:

aux_params=dict(
    pooling='avg',             # one of 'avg', 'max'
    dropout=0.5,               # dropout ratio, default is None
    activation='sigmoid',      # activation function, default is None
    classes=4,                 # define number of output labels
)
model = smp.Unet('resnet34', classes=4, aux_params=aux_params)
mask, label = model(x)

Depth

The encoder_depth parameter specifies the number of downsampling operations in the encoder, so you can make your model lighter by specifying a smaller depth.

model = smp.Unet('resnet34', encoder_depth=4)
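
For Unet-like models, the decoder configuration should match the reduced depth; a sketch under the assumption that decoder_channels needs one entry per decoder stage:

model = smp.Unet('resnet34', encoder_depth=4, decoder_channels=(256, 128, 64, 32))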

🛠 Installation

PyPI version:

$ pip install segmentation-models-pytorch

Latest version from source:

$ pip install git+https://github.com/qubvel/segmentation_models.pytorch

🏆 Competitions won with the library

The Segmentation Models package is widely used in image segmentation competitions. Here you can find competitions, the names of the winners, and links to their solutions.

🤝 Contributing

Run tests

$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev pytest -p no:cacheprovider

Generate table

$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev python misc/generate_table.py

📝 Citing

@misc{Yakubovskiy:2019,
  Author = {Pavel Yakubovskiy},
  Title = {Segmentation Models Pytorch},
  Year = {2020},
  Publisher = {GitHub},
  Journal = {GitHub repository},
  Howpublished = {\url{https://github.com/qubvel/segmentation_models.pytorch}}
}

🛡️ License

The project is distributed under the MIT License.

Comments
  • RuntimeError: Error(s) in loading state_dict for Unet

    I get the error below when I try to use the trained weight file of my Unet with se_resnext50 model for ensembling:

    RuntimeError: Error(s) in loading state_dict for Unet: Missing key(s) in state_dict: "decoder.blocks.0.conv1.0.weight", "decoder.blocks.0.conv1.1.weight", "decoder.blocks.0.conv1.1.bias", "decoder.blocks.0.conv1.1.running_mean", "decoder.blocks.0.conv1.1.running_var", "decoder.blocks.0.conv2.0.weight", "decoder.blocks.0.conv2.1.weight", "decoder.blocks.0.conv2.1.bias", "decoder.blocks.0.conv2.1.running_mean", "decoder.blocks.0.conv2.1.running_var", "decoder.blocks.1.conv1.0.weight", "decoder.blocks.1.conv1.1.weight", "decoder.blocks.1.conv1.1.bias", "decoder.blocks.1.conv1.1.running_mean", "decoder.blocks.1.conv1.1.running_var", "decoder.blocks.1.conv2.0.weight", "decoder.blocks.1.conv2.1.weight", "decoder.blocks.1.conv2.1.bias", "decoder.blocks.1.conv2.1.running_mean", "decoder.blocks.1.conv2.1.running_var", "decoder.blocks.2.conv1.0.weight", "decoder.blocks.2.conv1.1.weight", "decoder.blocks.2.conv1.1.bias", "decoder.blocks.2.conv1.1.running_mean", "decoder.blocks.2.conv1.1.running_var", "decoder.blocks.2.conv2.0.weight", "decoder.blocks.2.conv2.1.weight", "decoder.blocks.2.conv2.1.bias", "decoder.blocks.2.conv2.1.running_mean", "decoder.blocks.2.conv2.1.running_var", "decoder.blocks.3.conv1.0.weight", "decoder.blocks.3.conv1.1.weight", "decoder.blocks.3.conv1.1.bias", "decoder.blocks.3.conv1.1.running_mean", "decoder.blocks.3.conv1.1.running_var", "decoder.blocks.3.conv2.0.weight", "decoder.blocks.3.conv2.1.weight", "decoder.blocks.3.conv2.1.bias", "decoder.blocks.3.conv2.1.running_mean", "decoder.blocks.3.conv2.1.running_var", "decoder.blocks.4.conv1.0.weight", "decoder.blocks.4.conv1.1.weight", "decoder.blocks.4.conv1.1.bias", "decoder.blocks.4.conv1.1.running_mean", "decoder.blocks.4.conv1.1.running_var", "decoder.blocks.4.conv2.0.weight", "decoder.blocks.4.conv2.1.weight", "decoder.blocks.4.conv2.1.bias", "decoder.blocks.4.conv2.1.running_mean", "decoder.blocks.4.conv2.1.running_var", "segmentation_head.0.weight", "segmentation_head.0.bias". 
Unexpected key(s) in state_dict: "decoder.layer1.block.0.block.0.weight", "decoder.layer1.block.0.block.1.weight", "decoder.layer1.block.0.block.1.bias", "decoder.layer1.block.0.block.1.running_mean", "decoder.layer1.block.0.block.1.running_var", "decoder.layer1.block.0.block.1.num_batches_tracked", "decoder.layer1.block.1.block.0.weight", "decoder.layer1.block.1.block.1.weight", "decoder.layer1.block.1.block.1.bias", "decoder.layer1.block.1.block.1.running_mean", "decoder.layer1.block.1.block.1.running_var", "decoder.layer1.block.1.block.1.num_batches_tracked", "decoder.layer2.block.0.block.0.weight", "decoder.layer2.block.0.block.1.weight", "decoder.layer2.block.0.block.1.bias", "decoder.layer2.block.0.block.1.running_mean", "decoder.layer2.block.0.block.1.running_var", "decoder.layer2.block.0.block.1.num_batches_tracked", "decoder.layer2.block.1.block.0.weight", "decoder.layer2.block.1.block.1.weight", "decoder.layer2.block.1.block.1.bias", "decoder.layer2.block.1.block.1.running_mean", "decoder.layer2.block.1.block.1.running_var", "decoder.layer2.block.1.block.1.num_batches_tracked", "decoder.layer3.block.0.block.0.weight", "decoder.layer3.block.0.block.1.weight", "decoder.layer3.block.0.block.1.bias", "decoder.layer3.block.0.block.1.running_mean", "decoder.layer3.block.0.block.1.running_var", "decoder.layer3.block.0.block.1.num_batches_tracked", "decoder.layer3.block.1.block.0.weight", "decoder.layer3.block.1.block.1.weight", "decoder.layer3.block.1.block.1.bias", "decoder.layer3.block.1.block.1.running_mean", "decoder.layer3.block.1.block.1.running_var", "decoder.layer3.block.1.block.1.num_batches_tracked", "decoder.layer4.block.0.block.0.weight", "decoder.layer4.block.0.block.1.weight", "decoder.layer4.block.0.block.1.bias", "decoder.layer4.block.0.block.1.running_mean", "decoder.layer4.block.0.block.1.running_var", "decoder.layer4.block.0.block.1.num_batches_tracked", "decoder.layer4.block.1.block.0.weight", "decoder.layer4.block.1.block.1.weight", "decoder.layer4.block.1.block.1.bias", "decoder.layer4.block.1.block.1.running_mean", "decoder.layer4.block.1.block.1.running_var", "decoder.layer4.block.1.block.1.num_batches_tracked", "decoder.layer5.block.0.block.0.weight", "decoder.layer5.block.0.block.1.weight", "decoder.layer5.block.0.block.1.bias", "decoder.layer5.block.0.block.1.running_mean", "decoder.layer5.block.0.block.1.running_var", "decoder.layer5.block.0.block.1.num_batches_tracked", "decoder.layer5.block.1.block.0.weight", "decoder.layer5.block.1.block.1.weight", "decoder.layer5.block.1.block.1.bias", "decoder.layer5.block.1.block.1.running_mean", "decoder.layer5.block.1.block.1.running_var", "decoder.layer5.block.1.block.1.num_batches_tracked", "decoder.final_conv.weight", "decoder.final_conv.bias".

    Stale 
    opened by mobassir94 28
  • cannot import name 'container_abcs' from 'torch._six'

    Encountering this from today: 19-Jun-2021

    ImportError                               Traceback (most recent call last)
    <ipython-input> in <module>()
          4 get_ipython().system('pip install -U segmentation-models-pytorch')
          5
    ----> 6 import segmentation_models_pytorch as smp

    11 frames
    /usr/local/lib/python3.7/dist-packages/timm/models/layers/helpers.py in <module>()
          4 """
          5 from itertools import repeat
    ----> 6 from torch._six import container_abcs

    ImportError: cannot import name 'container_abcs' from 'torch._six' (/usr/local/lib/python3.7/dist-packages/torch/_six.py)

    opened by lifeischaotic 27
  • How to use metrics for multi-class binary mask target and multi-class multi-channel output?

    I saw in the documentation that the metrics for multilabel prediction require (batch, num_class, height, width). But then, I have a multi-class mask of one channel as target, where each pixel is labeled with its class.

    How do I use this for that scenario?

    Also, this seems to be computing it per batch. How do I do it per epoch?

    import segmentation_models_pytorch as smp
    
    # lets assume we have multilabel prediction for 3 classes
    output = torch.rand([10, 3, 256, 256])
    target = torch.rand([10, 3, 256, 256]).round().long()
    
    # first compute statistics for true positives, false positives, false negative and
    # true negative "pixels"
    tp, fp, fn, tn = smp.metrics.get_stats(output, target, mode='multilabel', threshold=0.5)
    
    # then compute metrics with required reduction (see metric docs)
    iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
    f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
    f2_score = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro")
    accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")
    recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
    
    Stale 
    opened by sarmientoj24 21
  • Add Apple's MobileOne encoder

    Hello,

    I added support for Apple's MobileOne encoder.

    Paper: Link

    There were very few changes I had to make to their official github repo: Link

    It works with all decoders and has impressive inference times for 256x256 images:

    Encoder-Decoder Inference time in vanilla torch (s)
    mobileone_s1_pspnet_256 0.0313718318939209
    mobileone_s0_pan_256 0.03421592712402344
    mobileone_s2_pspnet_256 0.036206960678100586
    mobileone_s3_pspnet_256 0.04711484909057617
    mobileone_s1_pan_256 0.05329489707946777
    mobileone_s0_linknet_256 0.05789995193481445
    mobileone_s0_deeplabv3plus_256 0.058853864669799805
    mobileone_s0_fpn_256 0.07664108276367188
    mobileone_s4_pspnet_256 0.0768282413482666
    mobileone_s1_deeplabv3plus_256 0.07886672019958496
    mobileone_s2_pan_256 0.07946181297302246
    mobileone_s3_pan_256 0.09101414680480957
    mobileone_s1_fpn_256 0.09615683555603027
    mobileone_s1_linknet_256 0.09956574440002441
    mobileone_s2_fpn_256 0.11291790008544922
    mobileone_s0_unet_256 0.11676502227783203
    mobileone_s2_linknet_256 0.12518310546875
    mobileone_s3_deeplabv3plus_256 0.12642478942871094
    mobileone_s2_deeplabv3plus_256 0.1289658546447754
    mobileone_s3_fpn_256 0.1370537281036377
    mobileone_s4_pan_256 0.14015984535217285
    mobileone_s1_unet_256 0.15249204635620117
    mobileone_s3_linknet_256 0.15824413299560547
    mobileone_s4_deeplabv3plus_256 0.16476082801818848
    mobileone_s0_manet_256 0.17203474044799805
    mobileone_s2_unet_256 0.17334604263305664
    mobileone_s4_fpn_256 0.182358980178833
    mobileone_s3_unet_256 0.20330286026000977
    mobileone_s4_linknet_256 0.21462082862854004
    mobileone_s0_deeplabv3_256 0.22992897033691406
    mobileone_s4_unet_256 0.24337363243103027
    mobileone_s0_unetplusplus_256 0.29451799392700195
    mobileone_s1_deeplabv3_256 0.31217503547668457
    mobileone_s1_manet_256 0.3140380382537842
    mobileone_s1_unetplusplus_256 0.5090749263763428
    mobileone_s2_deeplabv3_256 0.5372707843780518
    mobileone_s3_deeplabv3_256 0.5489542484283447
    mobileone_s2_unetplusplus_256 0.5728631019592285
    mobileone_s4_deeplabv3_256 0.638185977935791
    mobileone_s2_manet_256 0.6446411609649658
    mobileone_s3_manet_256 0.6838269233703613
    mobileone_s3_unetplusplus_256 0.6991360187530518
    mobileone_s4_manet_256 0.748121976852417
    mobileone_s4_unetplusplus_256 0.9898359775543213

    opened by kevinpl07 19
  • Class weights for Losses

    Hi, love using this library.

    I have encountered a problem: my datasets are very imbalanced. They have multiple classes, but the classes take up less than 2% of the image space; they are mainly small objects, the rest is background, and it seems that Unet fails to predict them accurately.

    Using your segmentation_models library for TensorFlow, I was able to use class weights for losses and it increased model prediction accuracy.

    Is it possible to use class weights with this library? Is there a code snippet?

    Best Regards, Augustas

    opened by augasur 19
  • Feature: support `timm` features_only functionality

    I've noticed more and more timm backbones being added here, which is great, but a lot of the effort currently duplicates some features of timm, i.e. tracking channel numbers, modifying the networks, etc.

    timm has a features_only arg in the model factory that will return a model set up as a backbone to produce pyramid features. It has a .feature_info attribute you can query to understand the channels of each output, the approximate reduction factor, etc.

    I've adapted the unet and deeplab impl here in the past to use this successfully, although it was quick hack-and-train work, nothing to serve as a clean example.

    If this was supported, any timm model (vit excluded right now) could be used as a backbone in a generic fashion, just by the model name string passed to the creation fn, possibly with a small config mapping of model types to index specifications (some models have slightly different out_indices alignment to strides if they happen to be a stride-64 model, or don't have a stride-2 feature, etc.). All tap points are the latest possible point for a given feature map stride. Some, but not all, of the timm backbones also support an output_stride= arg that will dilate the blocks appropriately for network strides of 8 or 16.
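
    For reference, a minimal sketch of the features_only interface described above (assuming a recent timm version):

    import timm
    import torch

    backbone = timm.create_model("resnet34", features_only=True, pretrained=False)
    feats = backbone(torch.rand(1, 3, 224, 224))
    print([f.shape for f in feats])           # pyramid feature maps at strides 2/4/8/16/32
    print(backbone.feature_info.channels())   # channel count of each feature map
    print(backbone.feature_info.reduction())  # reduction factor of each feature map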

    Some references:

    • https://rwightman.github.io/pytorch-image-models/feature_extraction/#multi-scale-feature-maps-feature-pyramid
    • https://github.com/rwightman/efficientdet-pytorch/blob/92bb66fd0cf91d0e23fe8b10cba97e2f0bb9884f/effdet/efficientdet.py#L554-L569

    For most of the models, the features are extracted by flattening part of the backbone model via a wrapper. A few models where the feature taps are embedded deep within the model use hooks, which causes some issues with torchscript, but that will likely be fixed soon in PyTorch.

    opened by rwightman 18
  • How to modify the sample code for multiple classifications. I have modified it according to the readme file, but the result after training is a single classification, and the masks of other categories are empty.

    Thank you very much for any help, your code is so cool! Hi, I am using the segmentation code from the example. With my own dataset I can perfectly segment individual categories, but when I try to segment images with multiple categories, I still get only one category. The images I am segmenting include three categories plus background. What I have done: ACTIVATION = 'softmax2d', and I changed the number of classes to 4 (including background). The current result is that training outputs 4 masks, but only one class and the background class are populated; the masks of the other two classes are empty. Thank you again!

    Stale 
    opened by siyangbing 17
  • update diceloss

    In the master branch, DiceLoss is implemented in such a way that the loss is computed over all class masks jointly, while it should be the mean of the Dice loss computed for each class. As a result, multiclass segmentation does not work well. This update should correct the problem; a sketch of the intended per-class computation is shown below.
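
    An illustrative sketch of the per-class Dice loss described above (not the library's exact code):

    import torch

    def dice_loss_mean_over_classes(pred, target, eps=1e-7):
        # pred, target: (N, C, H, W) probabilities and one-hot masks
        dims = (0, 2, 3)  # reduce over batch and spatial dims, keep the class dim
        intersection = (pred * target).sum(dims)
        cardinality = pred.sum(dims) + target.sum(dims)
        dice = (2 * intersection + eps) / (cardinality + eps)  # one Dice score per class
        return 1 - dice.mean()  # mean of per-class Dice losses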

    Stale 
    opened by julienguegan 13
  • Conversion PyTorch => ONNX => TensorRT

    Hi,

    I'm trying to convert a segmentation model (ENCODER = efficientnet-b2, DECODER = FPN) to ONNX and afterwards to TensorRT (TRT). Converting to ONNX seems to work but I can't get the conversion to TRT right. I tried the 'torch2trt' library but couldn't succeed...

    Does anyone have experience with this?

    Running: efficientnet-pytorch==0.6.3 onnx==1.7.0 segmentation-models-pytorch==0.1.0 torch==1.6.0 torch2trt==0.1.0 torchvision==0.7.0

    Thanks in advance,

    Michiel

    Stale 
    opened by michieljanssen97 13
  • Size mismatch occurs in UNet model at 5th stage

    I used the SMP library to create a UNet model with the following configurations: model = smp.Unet(encoder_name='resnet50', encoder_weights='imagenet', in_channels=3, classes=30)

    However, I have also tried other encoders (including the default resnet34) and the error seems to appear for every encoder that I choose. I am training it on a custom dataset where the image dimensions are w=320, h=192.

    My code runs fine until one of the final steps in the decoder block. The error traces back to smp/unet/decoder.py. When I'm running a training epoch, the error occurs in def forward(self, x, skip=None) of decoder.py

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
            x = self.attention1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attention2(x)
        return x
    

    For the first steps, everything runs fine and the dimensions of 'x' match with 'skip'. Below you can find a list of the dimensions of both x and skip as I go through the decoder:

    STEP 1
    x.shape
    Out[1]: torch.Size([1, 2048, 14, 20])
    skip.shape
    Out[2]: torch.Size([1, 1024, 14, 20])
    STEP 2
    x.shape
    Out[3]: torch.Size([1, 256, 28, 40])
    skip.shape
    Out[4]: torch.Size([1, 512, 28, 40])
    STEP 3
    x.shape
    Out[5]: torch.Size([1, 128, 56, 80])
    skip.shape
    Out[6]: torch.Size([1, 256, 55, 80])
    STEP 4
    x.shape
    Out[7]: torch.Size([1, 128, 56, 80])
    skip.shape
    Out[8]: torch.Size([1, 256, 55, 80])
    STEP 5
    x.shape
    Out[9]: torch.Size([1, 3, 192, 320])
    skip.shape
    Out[10]: torch.Size([1, 256, 55, 80])
    

    Around step 3, a mismatch between the tensors starts occurring, which causes the error. What I find weird about this is that I have used the exact same codebase with a different dataset that consisted of only 6 classes, and in that case there was no issue. I am also unsure where this is happening, as I cannot seem to find the root cause.

    Traceback (most recent call last):
      File "/Users/fc/Desktop/ct/segmentation_code/main.py", line 141, in <module>
        trainer.train()
      File "/Users/fc/Desktop/ct/segmentation_code/ops/trainer.py", line 44, in train
        self.train_logs = self.train_epoch.run(self.trainloader)
      File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/utils/train.py", line 47, in run
        loss, y_pred = self.batch_update(x, y)
      File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/utils/train.py", line 87, in batch_update
        prediction = self.model.forward(x)
      File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/base/model.py", line 16, in forward
        decoder_output = self.decoder(*features)
      File "/Users/fc/miniconda3/envs/ct/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/unet/decoder.py", line 119, in forward
        x = decoder_block(x, skip)
      File "/Users/fc/miniconda3/envs/ct/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/unet/decoder.py", line 38, in forward
        x = torch.cat([x, skip], dim=1)
    RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 56 but got size 55 for tensor number 1 in the list.

    Stale 
    opened by Fritskee 12
  • Unet decoder upsampling

    Hi

    I am using a Unet model with the encoder set to 'resnet34' and the pretrained weights are imagenet.

    When I look at the model I do not see where the upsampling is occurring. The convolutions on the encoder side are occurring (although the downsampling seemingly occurs after the intended layer, e.g. downsampling from layer 1 to layer 2 only occurs after layer 2); however, I do not see where the upsampling takes place on the decoder side.

    I also do not see the centre block convolutions occurring.

    Could someone please explain where the upsampling occurs?

    My model for reference:

    resnet34 Unet model.txt

    opened by DamienLopez1 12
  • Recommended way to load pretrained weights for encoder from checkpoint file

    I have a pretrained model checkpoint that I would like to use as the encoder weights for a segmentation model, and then train this segmentation model on a new task. It looks like the only options for the encoder_weights argument are strings naming certain pretrained weights within the smp library that are listed in the table. Is there a workaround to, for example, load some other pretrained resnet50 backbone from a checkpoint file as the encoder weights of an smp model?
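
    One possible workaround (a sketch assuming the checkpoint holds plain backbone weights; the file name and strict flag are illustrative):

    import torch
    import segmentation_models_pytorch as smp

    model = smp.Unet("resnet50", encoder_weights=None, classes=2)
    state_dict = torch.load("my_resnet50_checkpoint.pth", map_location="cpu")  # hypothetical file
    model.encoder.load_state_dict(state_dict, strict=False)  # strict=False tolerates mismatched head keys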

    opened by nilsleh 0
  • MixVisionTransformer in combination with PAN fails with "encoder does not support dilated mode"

    import segmentation_models_pytorch as smp
    
    smp.PAN(encoder_name="mit_b0")
    

    raises the exception:

    ValueError: MixVisionTransformer encoder does not support dilated mode
    

    Since the default PAN uses dilation, this config is incompatible at the moment?

    If we use a configuration of PAN that does not use dilation, the error, of course, does not appear:

    smp.PAN(encoder_name="mit_b0", encoder_output_stride=32)
    

    I have not yet tested whether an output stride of 32 still delivers comparable results. My guess would be that the default stride of 16 encodes a lot more information, which might be beneficial for performance.

    Is there any way to get it to work with dilation?

    opened by Daniel451 2
  • How to compute metrics for each class in multi-class segmentation

    I would like to compute the metrics individually for each class, so the output should be a (1xC) vector where C is the number of classes. I was trying it like this, but it throws an error:

    output = torch.rand([10, 3, 256, 256])
    target = torch.rand([10, 1, 256, 256]).round().long()
    
    # first compute statistics for true positives, false positives, false negative and
    # true negative "pixels"
    tp, fp, fn, tn = smp.metrics.get_stats(output, target, mode='multiclass', num_classes=3)
    
    # then compute metrics with required reduction (see metric docs)
    iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="macro-imagewise")
    f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="macro-imagewise")
    false_negatives = smp.metrics.false_negative_rate(tp, fp, fn, tn, reduction=None)
    recall = smp.metrics.recall(tp, fp, fn, tn, reduction=None)
    

    The error:

    ValueError: For ``multiclass`` mode ``target`` should be one of the integer types, got torch.float32.
    
    opened by santurini 1
  • Softmax activation function throws deprecation warning

    When defining a smp model in __init__() as:

    self.base = smp.Unet(encoder_name='resnet50', pretrained='imagenet', 
                                      in_channels=3, classes=7,
                                      activation='softmax') 
    

    This will throw the following warning upon initialisation:

    ~/anaconda3/envs/geo/lib/python3.7/site-packages/segmentation_models_pytorch/base/modules.py:116: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument. return self.activation(x)

    As the softmax is defined by passing 'softmax' as an arg, I'm not sure where/how to include the dim as the warning suggests. Many thanks!

    See also this closed (but without stating resolution) issue: Originally posted by @vdplasthijs in https://github.com/qubvel/segmentation_models.pytorch/issues/169#issuecomment-1334066128

    opened by vdplasthijs 0
  • AttributeError: module 'segmentation_models_pytorch' has no attribute 'utils'

    I was following the cars example but I keep getting this error. I installed the module with pip in both suggested ways, and the error persists even though I checked that utils is present in the repo. What am I doing wrong?

    This is the code:

    !pip install git+https://github.com/qubvel/segmentation_models.pytorch
    # !pip install -q segmentation-models-pytorch

    import segmentation_models_pytorch as smp  # import added; missing from the original snippet

    train_epoch = smp.utils.train.TrainEpoch(
        model, 
        loss=loss, 
        metrics=metrics, 
        optimizer=optimizer,
        device=DEVICE,
        verbose=True,
    )
    
    opened by santurini 2
Releases(v0.3.1)
  • v0.3.1(Nov 30, 2022)

  • v0.3.0(Jul 29, 2022)

    Updates

    • Added smp.metrics module with different metrics based on confusion matrix, see docs
    • Added new notebook with training example using pytorch-lightning Open In Colab
    • Improved handling of incorrect input image size error (checking image size is 2^n)
    • Codebase refactoring and style checks (black, flake8)
    • Minor typo fixes and bug fixes

    Breaking changes

    • utils module is going to be deprecated; if you still need it, import it manually: from segmentation_models_pytorch import utils

    Thanks a lot to all contributors!

  • v0.2.1(Nov 18, 2021)

  • v0.2.0(Jul 5, 2021)

    Updates

    • New architecture: MANet (#310)
    • New encoders from timm: mobilenetv3 (#355) and gernet (#344)
    • New loss functions in smp.losses module (smp.utils.losses would be deprecated in future versions)
    • New pretrained weight initialization for first convolution if in_channels > 3
    • Updated timm version (0.4.12)
    • Bug fixes and docs improvement

    Thanks to @azkalot1 @JulienMaille @originlake @Kupchanski @loopdigga96 @zurk @nmerty @ludics @Vozf @markson14 and others!

  • v0.1.3(Dec 13, 2020)

    Updates

    • New architecture Unet++ (#279)
    • New encoders RegNet, ResNest, SK-Net, Res2Net (#286)
    • Updated timm version (0.3.2)
    • Improved docstrings and typehints for models
    • Project documentation on https://smp.readthedocs.io

    Thanks to @azkalot1 for the new encoders and architecture!

  • v0.1.2(Sep 28, 2020)

  • v0.1.1(Sep 26, 2020)

    Updates

    • New decoders DeepLabV3, DeepLabV3+, PAN
    • New backbones (encoders) timm-efficientnet*
    • New pretrained weights (ssl, swsl) for resnets
    • New pretrained weights (advprop) for efficientnets

    And some small fixes.

    Thanks @IlyaDobrynin @gavrin-s @lizmisha @suitre77 @thisisiron @phamquiluan and all other contributors!

  • V0.1.0(Dec 9, 2019)

    Updates

    1. New backbones (mobilenet, efficientnet, inception)
    2. depth and in_channels options for all models
    3. Auxiliary classification output

    Note!

    Model architectures have been changed, use previous versions for weights compatibility!

  • v0.0.3(Sep 28, 2019)

  • v0.0.2(Sep 19, 2019)

Owner
Pavel Yakubovskiy