Semi-Supervised Semantic Segmentation with Cross-Consistency Training (CCT)

Overview

Paper, Project Page

This repo contains the official implementation of the CVPR 2020 paper Semi-Supervised Semantic Segmentation with Cross-Consistency Training, which adapts the traditional consistency training framework of semi-supervised learning to semantic segmentation, with extensions to weakly-supervised learning and learning on multiple domains.

Highlights

(1) Consistency Training for semantic segmentation.
We observe that for semantic segmentation, due to the dense nature of the task, the cluster assumption is more easily enforced over the hidden representations rather than the inputs.

(2) Cross-Consistency Training.
We propose CCT (Cross-Consistency Training) for semi-supervised semantic segmentation, where we define a number of novel perturbations and show the effectiveness of enforcing consistency over the encoder's outputs rather than the inputs.

(3) Using weak labels and pixel-level labels from multiple domains.
The proposed method is simple and flexible, and can easily be extended to use image-level labels and pixel-level labels from multiple domains.

Requirements

This repo was tested with Ubuntu 18.04.3 LTS, Python 3.7, PyTorch 1.1.0, and CUDA 10.0, but it should run with more recent PyTorch versions (>=1.1.0).

The required packages are pytorch and torchvision, together with PIL and opencv for data preprocessing and tqdm for showing the training progress. Some additional modules, such as dominate, are used to save the results as HTML files. To set up the necessary modules, simply run:

pip install -r requirements.txt
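
As a quick sanity check after installation, the version requirement above can be verified with a few lines of Python. This is a minimal sketch and not part of the repository: the helper name `meets_minimum` is hypothetical, and it only compares the numeric major.minor.patch part of a version string.

```python
import re


def meets_minimum(version, minimum="1.1.0"):
    """Return True if `version` is at least `minimum` (numeric parts only)."""
    def parse(v):
        # Drop any local suffix such as "+cu101", keep up to three numbers,
        # and pad with zeros so that "1.1" compares equal to "1.1.0".
        parts = [int(p) for p in re.findall(r"\d+", v.split("+")[0])[:3]]
        return tuple(parts + [0] * (3 - len(parts)))
    return parse(version) >= parse(minimum)
```

With PyTorch installed, `meets_minimum(torch.__version__)` would then report whether the installed version satisfies the stated minimum.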

Dataset

In this repo, we use Pascal VOC. To obtain it, first download the original dataset. After extracting the files, we end up with VOCtrainval_11-May-2012/VOCdevkit/VOC2012, containing the image sets, the XML annotations for both object detection and segmentation, and the JPEG images.
The second step is to augment the dataset using the additional annotations provided by Semantic Contours from Inverse Detectors. Download the rest of the annotations, SegmentationClassAug, and add them to the path VOCtrainval_11-May-2012/VOCdevkit/VOC2012. For training, use the path to VOCtrainval_11-May-2012.
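
Before training, it can help to confirm that the folder layout described above is in place. The following is a minimal sketch, not part of the repository: the function name `missing_voc_dirs` is hypothetical, and it only checks for the two folders named in this README (JPEGImages and SegmentationClassAug).

```python
from pathlib import Path


def missing_voc_dirs(root):
    """List the expected sub-folders that are absent under the dataset root.

    `root` is assumed to be the path to VOCtrainval_11-May-2012.
    """
    voc = Path(root) / "VOCdevkit" / "VOC2012"
    expected = ["JPEGImages", "SegmentationClassAug"]
    return [name for name in expected if not (voc / name).is_dir()]
```

An empty list means both folders were found; otherwise the returned names indicate which step of the dataset preparation is still missing.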

Training

To train a model, first download PASCAL VOC as detailed above, then set data_dir to the dataset path in the config file configs/config.json and set the remaining parameters, such as the number of GPUs, crop size, and data augmentation. You can also change the CCT hyperparameters if you wish; see the config file details below. Then simply run:

python train.py --config configs/config.json
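
Setting data_dir by hand works, but the same change can be scripted. The sketch below is an illustration only: the helper names are hypothetical, the section names `train_supervised`, `train_unsupervised`, and `val_loader` are taken from a full config posted in the comments further down this page, and the file on disk is assumed to be plain JSON without `//` comments.

```python
import json

# Loader sections assumed to carry a "data_dir" entry.
LOADER_SECTIONS = ("train_supervised", "train_unsupervised", "val_loader")


def set_data_dir(config, data_dir):
    """Return a copy of `config` with `data_dir` set in each loader section present."""
    updated = json.loads(json.dumps(config))  # simple deep copy for JSON-like data
    for section in LOADER_SECTIONS:
        if section in updated:
            updated[section]["data_dir"] = data_dir
    return updated


def rewrite_config(path, data_dir):
    """Load the JSON config at `path`, update data_dir, and write it back."""
    with open(path) as f:
        config = json.load(f)
    with open(path, "w") as f:
        json.dump(set_data_dir(config, data_dir), f, indent=4)
```

For example, `rewrite_config("configs/config.json", "VOCtrainval_11-May-2012")` would point every loader section at the same dataset root.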

The log files and the .pth checkpoints will be saved in saved/EXP_NAME. To monitor the training using tensorboard, run:

tensorboard --logdir saved

To resume training using a saved .pth model:

python train.py --config configs/config.json --resume saved/CCT/checkpoint.pth

Results: The validation results will be saved in saved/ as an HTML file named after the experim_name specified in configs/config.json.

Pseudo-labels

If you want to use image-level labels to train the auxiliary decoders, as explained in section 3.3 of the paper, first generate the pseudo-labels using the code in pseudo_labels:

cd pseudo_labels
python run.py --voc12_root DATA_PATH

DATA_PATH must point to the folder containing JPEGImages in the Pascal VOC dataset. The results will be saved in pseudo_labels/result/pseudo_labels as PNG files. The flag use_weak_labels then needs to be set to true in the config file, after which we can train the model as detailed above.

Inference

For inference, we need a pretrained model, the jpg images we'd like to segment, and the config used in training (to load the correct model and other parameters):

python inference.py --config config.json --model best_model.pth --images images_folder

The predictions will be saved as .png images; outputs/ is used as the destination folder. For Pascal VOC, the default color palette is applied.

Here are the flags available for inference:

--images       Folder containing the jpg images to segment.
--model        Path to the trained pth model.
--config       The config file used for training the model.
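
To make the three flags concrete, here is a minimal argparse sketch that accepts the same options. It is an illustration only, not the repository's inference.py; in particular, marking every flag as required is an assumption.

```python
import argparse


def build_parser():
    """Build a parser mirroring the inference flags listed above."""
    parser = argparse.ArgumentParser(description="Inference flags (sketch)")
    parser.add_argument("--config", required=True,
                        help="The config file used for training the model.")
    parser.add_argument("--model", required=True,
                        help="Path to the trained pth model.")
    parser.add_argument("--images", required=True,
                        help="Folder containing the jpg images to segment.")
    return parser


# Parse the same arguments as in the example command above.
args = build_parser().parse_args(
    ["--config", "config.json", "--model", "best_model.pth", "--images", "images_folder"]
)
print(args.config, args.model, args.images)
```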

Pre-trained models

Pre-trained models can be downloaded here.

Citation ✏️ 📄

If you find this repo useful for your research, please consider citing the paper as follows:

@InProceedings{Ouali_2020_CVPR,
  author = {Ouali, Yassine and Hudelot, Celine and Tami, Myriam},
  title = {Semi-Supervised Semantic Segmentation With Cross-Consistency Training},
  booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  month = {June},
  year = {2020}
}

For any questions, please contact Yassine Ouali.

Config file details ⚙️

Below we detail the CCT parameters that can be controlled in the config file configs/config.json. The rest of the parameters are self-explanatory.

{
    "name": "CCT",                              
    "experim_name": "CCT",                             // The name the results will take (html and the folder in /saved)
    "n_gpu": 1,                                             // Number of GPUs
    "n_labeled_examples": 1000,                             // Number of labeled examples (choices are 60, 100, 200, 
                                                            // 300, 500, 800, 1000, 1464, and the splits are in dataloaders/voc_splits)
    "diff_lrs": true,
    "ramp_up": 0.1,                                         // The unsupervised loss will be slowly scaled up in the first 10% of Training time
    "unsupervised_w": 30,                                   // Weighting of the unsupervised loss
    "ignore_index": 255,
    "lr_scheduler": "Poly",
    "use_weak_labels": false,                               // If the pseudo-labels were generated, we can use them to train the aux. decoders
    "weakly_loss_w": 0.4,                                   // Weighting of the weakly-supervised loss
    "pretrained": true,

    "model":{
        "supervised": true,                                  // Supervised setting (training only on the labeled examples)
        "semi": false,                                       // Semi-supervised setting
        "supervised_w": 1,                                   // Weighting of the supervised loss

        "sup_loss": "CE",                                    // supervised loss, choices are CE and ab-CE = ["CE", "ABCE"]
        "un_loss": "MSE",                                    // unsupervised loss, choices are MSE and KL-divergence = ["MSE", "KL"]

        "softmax_temp": 1,
        "aux_constraint": false,                             // Pair-wise loss (sup. mat.)
        "aux_constraint_w": 1,
        "confidence_masking": false,                         // Confidence masking (sup. mat.)
        "confidence_th": 0.5,

        "drop": 6,                                           // Number of DropOut decoders
        "drop_rate": 0.5,                                    // Dropout probability
        "spatial": true,
    
        "cutout": 6,                                         // Number of G-Cutout decoders
        "erase": 0.4,                                        // We drop 40% of the area
    
        "vat": 2,                                            // Number of I-VAT decoders
        "xi": 1e-6,                                          // VAT parameters
        "eps": 2.0,

        "context_masking": 2,                               // Number of Con-Msk decoders
        "object_masking": 2,                                // Number of Obj-Msk decoders
        "feature_drop": 6,                                  // Number of F-Drop decoders

        "feature_noise": 6,                                 // Number of F-Noise decoders
        "uniform_range": 0.3                                // The range of the noise
    },
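
To illustrate how `ramp_up` and `unsupervised_w` might interact, the sketch below scales the unsupervised loss weight from near zero up to its final value over the first 10% of training. It assumes a sigmoid-shaped schedule, exp(-5(1 - t)^2), which is common in consistency training; the schedule actually used in this repo may differ, and the function name `unsup_weight` is hypothetical.

```python
import math


def unsup_weight(cur_iter, total_iters, final_w=30.0, ramp_up=0.1):
    """Weight of the unsupervised loss at iteration `cur_iter`.

    The weight rises during the first `ramp_up` fraction of training and
    stays at `final_w` afterwards. The sigmoid shape is an assumption.
    """
    ramp_len = ramp_up * total_iters
    if ramp_len <= 0 or cur_iter >= ramp_len:
        return final_w
    phase = 1.0 - max(cur_iter, 0) / ramp_len  # 1 at the start, 0 at the end
    return final_w * math.exp(-5.0 * phase * phase)
```

With the defaults shown in the config (`unsupervised_w` of 30, `ramp_up` of 0.1) and 1000 total iterations, the weight starts at a small fraction of 30 and reaches 30 at iteration 100.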

Acknowledgements

  • Pseudo-labels generation is based on Jiwoon Ahn's implementation irn.
  • Code structure was based on Pytorch-Template
  • ResNet backbone was downloaded from torchcv

Comments
  • custom dataset with 4 classes

    Thank you so far for all your great help. I have an issue that I also found in the closed issues, but for me it isn't solved. I have my own custom dataset with 4 classes (background and 3 objects, labeled 0-3), so I changed num_classes = 4 in voc.py. When training fully supervised, one class gets an IoU of 0.0. I ran multiple tests using semi-supervised and weakly-supervised settings; the results are unpredictable and often show 0.0 for the object classes. Only the background has good results. Is there something I need to adjust in the code?

    opened by SuzannaLin 22
  • Training error!

    I want to train VOC2012, but get the error below:

    Traceback (most recent call last):
      File "train.py", line 98, in <module>
        main(config, args.resume)
      File "train.py", line 82, in main
        trainer.train()
      File "/home/byronnar/bigfile/projects/CCT/base/base_trainer.py", line 91, in train
        results = self._train_epoch(epoch)
      File "/home/byronnar/bigfile/projects/CCT/trainer.py", line 76, in _train_epoch
        total_loss, cur_losses, outputs = self.model(x_l=input_l, target_l=target_l, x_ul=input_ul, curr_iter=batch_idx, target_ul=target_ul, epoch=epoch-1)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 150, in forward
        return self.module(*inputs[0], **kwargs[0])
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/byronnar/bigfile/projects/CCT/models/model.py", line 93, in forward
        output_l = self.main_decoder(self.encoder(x_l))
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/byronnar/bigfile/projects/CCT/models/encoder.py", line 61, in forward
        x = self.psp(x)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/byronnar/bigfile/projects/CCT/models/encoder.py", line 36, in forward
        align_corners=False) for stage in self.stages])
      File "/home/byronnar/bigfile/projects/CCT/models/encoder.py", line 36, in <listcomp>
        align_corners=False) for stage in self.stages])
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
        input = module(input)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 547, in __call__
        result = self.forward(*input, **kwargs)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py", line 81, in forward
        exponential_average_factor, self.eps)
      File "/opt/conda/lib/python3.6/site-packages/torch/nn/functional.py", line 1652, in batch_norm
        raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
    ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 512, 1, 1])
      0%|                                                                                                         | 0/9118 [00:02<?, ?it/s]
    

    What should I do? Thank you.

    opened by Byronnar 12
  • Poor mIoU when training on 1464 images in a supervised manner

    Hi, I trained the model on the 1464 images in a supervised manner. The highest mIoU on Val is 67.00%, but the reported number in your paper is 69.4%. Here is my config.json file. Can you have a look at which part is wrong?

     {  "name": "CCT",
        "experim_name": "CCT",
        "n_gpu": 1,
        "n_labeled_examples": 1464,
        "diff_lrs": true,
        "ramp_up": 0.1,
        "unsupervised_w": 30,
        "ignore_index": 255,
        "lr_scheduler": "Poly",
        "use_weak_lables":false,
        "weakly_loss_w": 0.4,
        "pretrained": true,
        "model":{
            "supervised": true,
            "semi": false,
            "supervised_w": 1,
    
            "sup_loss": "CE",
            "un_loss": "MSE",
    
            "softmax_temp": 1,
            "aux_constraint": false,
            "aux_constraint_w": 1,
            "confidence_masking": false,
            "confidence_th": 0.5,
    
            "drop": 6,
            "drop_rate": 0.5,
            "spatial": true,
        
            "cutout": 6,
            "erase": 0.4,
        
            "vat": 2,
            "xi": 1e-6,
            "eps": 2.0,
    
            "context_masking": 2,
            "object_masking": 2,
            "feature_drop": 6,
    
            "feature_noise": 6,
            "uniform_range": 0.3
        },
    
    
        "optimizer": {
            "type": "SGD",
            "args":{
                "lr": 1e-2,
                "weight_decay": 1e-4,
                "momentum": 0.9
            }
        },
    
    
        "train_supervised": {
            "data_dir": "../data/VOC2012",
            "batch_size": 10,
            "crop_size": 320,
            "shuffle": true,
            "base_size": 400,
            "scale": true,
            "augment": true,
            "flip": true,
            "rotate": false,
            "blur": false,
            "split": "train_supervised",
            "num_workers": 8
        },
    
        "train_unsupervised": {
            "data_dir": "VOCtrainval_11-May-2012",
            "weak_labels_output": "pseudo_labels/result/pseudo_labels",
            "batch_size": 10,
            "crop_size": 320,
            "shuffle": true,
            "base_size": 400,
            "scale": true,
            "augment": true,
            "flip": true,
            "rotate": false,
            "blur": false,
            "split": "train_unsupervised",
            "num_workers": 8
        },
    
        "val_loader": {
            "data_dir": "../data/VOC2012",
            "batch_size": 1,
            "val": true,
            "split": "val",
            "shuffle": false,
            "num_workers": 4
        },
    
        "trainer": {
            "epochs": 80,
            "save_dir": "saved/",
            "save_period": 5,
      
            "monitor": "max Mean_IoU",
            "early_stop": 10,
            
            "tensorboardX": true,
            "log_dir": "saved/",
            "log_per_iter": 20,
    
            "val": true,
            "val_per_epochs": 5
        }
    }
    
    opened by xiaomengyc 11
  • Fail to reimplement your paper's result for semi-supervised.

    I used the default config file to run the experiments, but I only got 68.9 mIoU without weak labels and 70.09 mIoU with weak labels, following your README. These results are far lower than yours. My environment is PyTorch 1.7.0 and Python 3.8.5. Could you provide some advice?

    opened by TyroneLi 8
  • checkerboard

    Hi Yassine, I am using the CCT model to train on a satellite dataset. The images are of size 128x128. For some reason, the predictions show a clear checkerboard pattern, as in the attached example (left: prediction, right: ground truth). Do you have any idea what causes this and how to avoid it?

    opened by SuzannaLin 7
  • inference with 4-channel model

    Hi Yassine! I have managed to train a model with 4 channels, but the inference is not working. I get this error message:

    !python inference.py --config configs/config_70_30_sup_alti.json --model './saved/ABCE_70_30_sup_alti/best_model.pth' --output 'CCT_output/ABCE_70_30_sup_alti/Angers/' --images 'val/Angers/BDORTHO'

    Loading pretrained model:models/backbones/pretrained/3x3resnet50-imagenet.pth
    Traceback (most recent call last):
      File "inference.py", line 155, in <module>
        main()
      File "inference.py", line 102, in main
        conf=config['model'], testing=True, pretrained = True)
      File "/home/scuypers/CCT_4/models/model.py", line 55, in __init__
        self.encoder = Encoder(pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/encoder.py", line 49, in __init__
        model = ResNetBackbone(backbone='deepbase_resnet50_dilated8', pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/backbones/resnet_backbone.py", line 145, in ResNetBackbone
        orig_resnet = deepbase_resnet50(pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/backbones/resnet_models.py", line 227, in deepbase_resnet50
        model = ModuleHelper.load_model(model, pretrained=pretrained)
      File "/home/scuypers/CCT_4/models/backbones/module_helper.py", line 109, in load_model
        model.load_state_dict(load_dict)
      File "/home/scuypers/.conda/envs/envCCT/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1483, in load_state_dict
        self.__class__.__name__, "\n\t".join(error_msgs)))
    RuntimeError: Error(s) in loading state_dict for ResNet:
        size mismatch for prefix.conv1.weight: copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 4, 3, 3]).

    opened by SuzannaLin 7
  • Loss function

    Thank you for your contribution. I want to know what the |Dl| and |Du| in your cross-entropy loss function formula and semi-supervised loss function formula represent, thank you for your answer.

    opened by yrcrcy 7
  • Reproducing Cross-domain Experiments

    @yassouali @SuzannaLin Firstly, thanks a lot for your great work! But I still have problems in reproducing the cross-domain experiments.

    I have seen the implementation from https://github.com/tarun005/USSS_ICCV19, however, I notice that there are still many alternatives for the training schemes.

    During training, in each iteration (i.e. each execution of 'optimizer.step()'), do you train the network by forwarding inputs from both datasets? Or do you forward inputs from only one dataset in the current iteration, execute 'optimizer.step()', then forward inputs from the other dataset in the next iteration, and so on?

    Also, are there any tricks to deal with the data imbalance situation, e.g. the CamVid dataset only contains 367 images while the Cityscapes dataset has 2975 training images? (Just like constructing a training batch with different ratios for two datasets or other sorts of things)

    Besides, can you give some hints on hyperparameters, e.g. the number of training iterations, batch size, learning rate, weight decay?

    Looking forward to your reply! Thanks a lot!

    opened by X-Lai 6
  • low performance for full supervised setting

    I modified the config file to set the code to 'supervised' mode, but the result seems to be very low: Epoch : 40 | Mean_IoU : 0.699999988079071 | PixelAcc : 0.933 | Val Loss : 0.26163 compared with 'semi' mode:

    Epoch : 40 | Mean_IoU : 0.7120000123977661 | PixelAcc : 0.931 | Val Loss : 0.31637

    Note that I have changed the supervised list to the 10k+ augmented list in the 'supervised' setting. Did I miss something here?

    opened by zhangyuygss 6
  • How to obtain figure 2(d)

    Hi, thank you for your nice work!

    I want to know how to produce figure2(d)? There are 2048 channels for hidden representation, how to visualize? Thanks for your help!

    opened by reluuu 5
  • low performance in semi-supervised mode when employing weakly_loss with 2 gpus

    Thank you for your nice work!

    I tried training the model with 1464 labeled samples in semi-supervised mode, using 2 GPUs. I set the number of epochs to 80 and stopped after 50 epochs. But the performance is poor, e.g., the mIoU at epoch 5 is 34.70% while at epoch 10 it is 11.40%.

    I set the 'use_weak_labels' as true, the 'drop_last' as false, and the rest are default.

    Have you ever met this situation?

    opened by wqhIris 5