PyTorch implementation of "A Full-Band and Sub-Band Fusion Model for Real-Time Single-Channel Speech Enhancement."

Overview

FullSubNet

This Git repository hosts the official PyTorch implementation of "FullSubNet: A Full-Band and Sub-Band Fusion Model for Real-Time Single-Channel Speech Enhancement", submitted to ICASSP 2021.

🌼 See the demo page at this link.

(Figure: model workflow)

(Figure: FullSubNet results)

You can use all of the following:

  • Available models
    • FullSubNet
    • Delayed Sub-Band LSTM
    • Fullband LSTM Baseline
  • Available Datasets
    • Deep Noise Suppression Challenge - INTERSPEECH 2020
    • Demand + CSTR VCTK Corpus

Documentation

Citation

If you use this code for your research, please consider citing:

@misc{hao2020fullsubnet,
      title={FullSubNet: A Full-Band and Sub-Band Fusion Model for Real-Time Single-Channel Speech Enhancement}, 
      author={Xiang Hao and Xiangdong Su and Radu Horaud and Xiaofei Li},
      year={2020},
      eprint={2010.15508},
      archivePrefix={arXiv},
      primaryClass={eess.AS}
}

License

License: MIT

Comments
  • Will there be a 44.1 or 48kHz pre-trained model released?

    Hi guys! Your work is absolutely amazing and inspiring. Trying your model on data that was not in your datasets, it performed very well.

    I did have to convert to 16 kHz first since, as I understand it, the model was trained on 16 kHz audio?

    My question is: will you release a 44.1 kHz or 48 kHz pre-trained model in the near future?

    While I could try my hand at training that model, I'm nowhere near as experienced as you are at this, and I feel I'd miss so many things that I would not be able to create a model that generalizes as well as yours. Or maybe you can prove me wrong.
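
    On the 16 kHz conversion mentioned above: a minimal resampling sketch with torchaudio (file names here are placeholders, not from this repo):

        import torchaudio

        # Load the (e.g. 48 kHz) source and resample to the 16 kHz the model expects.
        waveform, sr = torchaudio.load("input_48k.wav")
        if sr != 16000:
            waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
        torchaudio.save("input_16k.wav", waveform, 16000)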

    opened by youssefavx 8
  • How to Use Pretrained, pickled model in Releases with No Documentation?

    @haoxiangsnr Hi, it's already April and there still isn't any documentation for the pretrained model in releases. How do we go about using the pickled file data.pkl for inference? Thanks!
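
    For what it's worth: data.pkl is the internal pickle of PyTorch's zip-based checkpoint format, so the released .tar is not meant to be unpacked by hand. A minimal sketch (constructing the model itself must follow the recipes in this repo):

        import torch

        # torch.load reads the archive (which contains data.pkl) directly.
        checkpoint = torch.load("fullsubnet_best_model_58epochs.tar", map_location="cpu")
        print(checkpoint.keys())  # includes "model_state_dict", per the release notes
        # model.load_state_dict(checkpoint["model_state_dict"])  # model: a FullSubNet instance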

    opened by uwstudent123 3
  • Any plans about releasing the pretrained models?

    First, thanks for the open-source implementation. I saw that the pretrained model is on the TO-DO list in the baseline readme. Do you have a release schedule in mind? Thanks a lot!

    Type: Documentation Priority: Critical Status: In Progress 
    opened by 121898 3
  • [Question] Real-time speech enhancement quality is poor, need help!

    At present, I have finished modifying cumulative_laplace_norm, and I then feed STFT frames to the network batch by batch for streaming inference, carrying the network's hidden state and cell state. But the results are poor. In a previous issue you said the LSTM needs to be replaced with an LSTMCell: what is the difference between the two, and why is that conversion needed? (See the state-carrying sketch after the metrics below.) Pictures are as follows:

    • Fullwav_load: (spectrogram image)

    • Stream_load: (spectrogram image)

    Metrics (my experiment):

    | Model | NB_PESQ | WB_PESQ | SI_SDR | STOI |
    | -- | -- | -- | -- | -- |
    | FullSubNet-cum (Epoch 130) | 3.364 | 2.861 | 17.65 | 96.25 |
    | FullSubNet-cum-stream | 3.155 | 2.466 | 14.77 | 94.30 |
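
    On the LSTM vs. LSTMCell point: nn.LSTM consumes a whole sequence at once, while nn.LSTMCell advances one time step per call, which maps naturally onto frame-in, frame-out streaming. Either way, carrying the (h, c) state across calls should reproduce the offline result; a minimal sketch (sizes are hypothetical, not the FullSubNet configuration):

        import torch
        import torch.nn as nn

        lstm = nn.LSTM(input_size=257, hidden_size=384, num_layers=2, batch_first=True)
        x = torch.randn(1, 100, 257)                  # (batch, frames, features)

        full_out, _ = lstm(x)                         # offline: whole utterance at once

        state, chunks = None, []                      # streaming: one frame per call,
        for t in range(x.shape[1]):                   # carrying (h, c) between calls
            out, state = lstm(x[:, t:t + 1, :], state)
            chunks.append(out)
        stream_out = torch.cat(chunks, dim=1)

        print(torch.allclose(full_out, stream_out, atol=1e-6))  # True: outputs match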

    opened by Kayden-Wang 2
  • Training and Validation cRM Mismatch

    During training, with batch size 10, we observe the following shapes:

    cRM torch.Size([10, 128, 193, 2])
    noisy_real torch.Size([10, 257, 193])
    noisy_imag torch.Size([10, 257, 193])
    

    However, during validation, we see:

    cRM torch.Size([1, 257, 626, 2])
    noisy_real torch.Size([1, 257, 626])
    noisy_imag torch.Size([1, 257, 626])
    

    Why are dimensions 1 and 2 of the cRM different from noisy_real/noisy_imag during training but not during validation?

    Without matching shapes, I am unable to compute the enhanced waveform during training, since this calculation fails:

    cRM = decompress_cIRM(cRM)  # undo the cIRM compression applied to the training target

    enhanced_real = cRM[..., 0] * noisy_real - cRM[..., 1] * noisy_imag  # real part of complex product
    enhanced_imag = cRM[..., 1] * noisy_real + cRM[..., 0] * noisy_imag  # imaginary part of complex product
    
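    Once the cRM and the noisy spectrogram agree along the frequency axis, the waveform follows from an inverse STFT. A sketch with assumed analysis settings (n_fft=512 matches the 257 bins above; the hop length is a guess and must match the forward STFT):

        import torch

        # Continuing from enhanced_real / enhanced_imag above, shape (batch, 257, frames):
        enhanced_complex = torch.complex(enhanced_real, enhanced_imag)
        enhanced_wav = torch.istft(
            enhanced_complex,
            n_fft=512,                        # 512-point FFT -> 257 frequency bins
            hop_length=256,                   # assumed; must match the analysis STFT
            window=torch.hann_window(512),
        )
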
    opened by jhkonan 2
  • Sub-band model

    Hi, thanks for sharing this excellent project. I took great interest in the Delayed Sub-Band LSTM model described in the paper. I have tried hard to reproduce it but still cannot match its performance. Could you release the code for your sub-band models? Thanks a lot!

    opened by HWLhsu 2
  • error!

    hi @haoxiangsnr, I ran the pre-trained model on Google Colab following this guide: https://github.com/haoxiangsnr/FullSubNet/blob/main/docs/getting_started.md but I got an issue like this:

    command:

        !python inference.py -C /content/FullSubNet/recipes/dns_interspeech_2020/fullband_baseline/inference.toml -M /content/drive/MyDrive/Colab_Notebooks/FullSubNet/fullsubnet_best_model_58epochs.tar -O /content/drive/MyDrive/Colab_Notebooks/FullSubNet/output_dir

    result:

        Loading inference dataset...
        Loading model...
        Traceback (most recent call last):
          File "inference.py", line 32, in <module>
            main(configuration, checkpoint_path, output_dir)
          File "inference.py", line 16, in main
            output_dir
          File "/content/FullSubNet/recipes/dns_interspeech_2020/inferencer.py", line 50, in __init__
            super().__init__(config, checkpoint_path, output_dir)
          File "/content/FullSubNet/audio_zen/inferencer/base_inferencer.py", line 27, in __init__
            self.model, epoch = self._load_model(config["model"], checkpoint_path, self.device)
          File "/content/FullSubNet/audio_zen/inferencer/base_inferencer.py", line 91, in _load_model
            model = initialize_module(model_config["path"], args=model_config["args"], initialize=True)
          File "/content/FullSubNet/audio_zen/utils.py", line 87, in initialize_module
            module = importlib.import_module(module_path)
          File "/usr/lib/python3.7/importlib/__init__.py", line 127, in import_module
            return _bootstrap._gcd_import(name[level:], package, level)
          File "<frozen importlib._bootstrap>", line 1006, in _gcd_import
          File "<frozen importlib._bootstrap>", line 983, in _find_and_load
          File "<frozen importlib._bootstrap>", line 965, in _find_and_load_unlocked
        ModuleNotFoundError: No module named 'model'

    please lend me a hand, thanks a lot

    opened by vinh1988 2
  • # of Epochs for training a FullBand baseline model

    Hello,

    My question is about training. I am trying to replicate the results with the DNS challenge dataset. The number of epochs in the fullband_baseline.toml file is set to 9999, which seems a "little" high :) Could you please shed some light on this? Is this the intended default value?

    Thank you for sharing your work.

    B.R.

    opened by kadir-gunel 2
  • Not an issue, but I wanted to show you some impressive results...

    Greetings! I have been helping restore an old (1980) recording of an interview with an elderly person relaying stories of the early history of the Baha'i Faith in the U.S. I have surveyed and tried a number of machine learning methods to denoise and enhance the recording. I just finished processing with your FullSubNet today, and it far surpassed the others I tried in removing the recording noise and making the voice easier to understand.

    Enclosed is a comparison of the frequency spectrograms of the 3 files: the top is the original recording, the middle is the result of another method (which was dozens of times slower than yours and principally dealt with white noise), and the bottom is the result of FullSubNet using your pretrained checkpoint. The reduction in noise from the original recording to what your method produced was astonishing! I can check whether the archivist would allow me to provide the recordings (so if you are interested in getting them, please let me know). Thanks so much for making the code available here! Regards -Steve

    (Figure: spectrogram comparison of the original, the other method, and FullSubNet)

    opened by sjscotti 2
  • The batch size for the validation stage must be one

    Hi Hao Xiang,

    I can't run the demo due to the limit that GPU usage must be over 20 percent. I then found that the batch size for the validation set must be one. Can I change the batch size for validation?

    Hoping for your reply!

    opened by zc1616 2
  • Questions about the training process

    Very interesting project. Thank you for sharing.

    I have a question: what are the text files noise.txt, rir.txt, and clean_0.6.txt? Are they part of the original dataset, or dedicated files that you created for training?

    Another question: is it possible to run on Windows without the "dist" feature (using a single GPU), i.e., after commenting out all the parts related to 'dist'?

    opened by ahikaml 2
  • A question about look-ahead

    Hi, my understanding of look-ahead is the number of future frames used. But reading your code, I found that two zero frames are padded at the end, noisy_mag = F.pad(noisy_mag, [0, self.look_ahead]), and only the frames from the second one onward are kept at the output, output = sb_mask[:, :, :, self.look_ahead:].

    During inference, does that mean no zero padding is needed? Instead, we directly process 3 frames and output one frame (output = sb_mask[:, :, :, self.look_ahead:]), then stream one frame in, one frame out?
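
    For concreteness, the padding scheme in question can be isolated as below (torch.nn.Identity stands in for the sub-band network; shapes are hypothetical):

        import torch
        import torch.nn.functional as F

        look_ahead = 2
        model = torch.nn.Identity()                    # stand-in for the sub-band network
        noisy_mag = torch.randn(1, 1, 257, 100)        # (batch, channel, freq, frames)
        noisy_mag = F.pad(noisy_mag, [0, look_ahead])  # append 2 zero frames in time
        sb_mask = model(noisy_mag)
        output = sb_mask[:, :, :, look_ahead:]         # frame t of output sees input up to t + 2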

    opened by LXP-Never 0
  • Normalization

    Dear authors,

    I notice that in snr_mix the signal dBFS is drawn from [-35, -15], meaning the intensity changes randomly. However, in inference.py, normalization is applied, which seems inconsistent. From my understanding, we should either normalize all data or none of it, so why normalize during inference while leaving it out during training? Maybe I have some misunderstanding; please correct me if so.

    Best

    opened by lixinghe1999 2
  • Unable to fine-tune pre-trained model (fullsubnet_best_model_58epochs.tar)

    I am trying to continue training the pre-trained FullSubNet model provided by this repo:

    fullsubnet_best_model_58epochs.tar

    I can confirm the model works for inference. However, I run into issues loading the state dictionary for training based on how the model was saved.

    Here is the error in full:

    (FullSubNet) $ torchrun --standalone --nnodes=1 --nproc_per_node=1 train.py -C fullsubnet/train.toml -R
    1 process initialized.
    Traceback (most recent call last):
      File "/home/github/FullSubNet/recipes/dns_interspeech_2020/train.py", line 99, in <module>
        entry(local_rank, configuration, args.resume, args.only_validation)
      File "/home/github/FullSubNet/recipes/dns_interspeech_2020/train.py", line 59, in entry
        trainer = trainer_class(
      File "/home/github/FullSubNet/recipes/dns_interspeech_2020/fullsubnet/trainer.py", line 17, in __init__
        super().__init__(dist, rank, config, resume, only_validation, model, loss_function, optimizer)
      File "/home/github/FullSubNet/audio_zen/trainer/base_trainer.py", line 84, in __init__
        self._resume_checkpoint()
      File "/home/github/FullSubNet/audio_zen/trainer/base_trainer.py", line 153, in _resume_checkpoint
        self.scaler.load_state_dict(checkpoint["scaler"])
      File "/home/anaconda3/envs/FullSubNet/lib/python3.9/site-packages/torch/cuda/amp/grad_scaler.py", line 502, in load_state_dict
        raise RuntimeError("The source state dict is empty, possibly because it was saved "
    RuntimeError: The source state dict is empty, possibly because it was saved from a disabled instance of GradScaler.
    ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 1822537) of binary: /home/anaconda3/envs/FullSubNet/bin/python
    Traceback (most recent call last):
      File "/home/anaconda3/envs/FullSubNet/bin/torchrun", line 33, in <module>
        sys.exit(load_entry_point('torch==1.11.0', 'console_scripts', 'torchrun')())
      File "/home/anaconda3/envs/FullSubNet/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
        return f(*args, **kwargs)
      File "/home/anaconda3/envs/FullSubNet/lib/python3.9/site-packages/torch/distributed/run.py", line 724, in main
        run(args)
      File "/home/anaconda3/envs/FullSubNet/lib/python3.9/site-packages/torch/distributed/run.py", line 715, in run
        elastic_launch(
      File "/home/anaconda3/envs/FullSubNet/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
        return launch_agent(self._config, self._entrypoint, list(args))
      File "/home/anaconda3/envs/FullSubNet/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
        raise ChildFailedError(
    torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
    ============================================================
    train.py FAILED
    ------------------------------------------------------------
    Failures:
      <NO_OTHER_FAILURES>
    ------------------------------------------------------------
    Root Cause (first observed failure):
    [0]:
      time      : 2022-06-19_21:25:20
      host      : host-server
      rank      : 0 (local_rank: 0)
      exitcode  : 1 (pid: 1822537)
      error_file: <N/A>
      traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
    ============================================================
    
    

    Are there specific modifications that need to be made to continue training?

    Thank you for your help.
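
    A possible workaround, offered as a sketch rather than an official fix: the error suggests the checkpoint was saved from a disabled GradScaler (i.e., trained without AMP), whose state dict is empty, so the restore in audio_zen/trainer/base_trainer.py could be guarded:

        # Hypothetical patch inside _resume_checkpoint, replacing the unconditional load:
        scaler_state = checkpoint.get("scaler")
        if scaler_state:  # a disabled GradScaler saves an empty state dict
            self.scaler.load_state_dict(scaler_state)
        else:
            print("No usable GradScaler state in checkpoint; starting with a fresh scaler.")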

    opened by jhkonan 1
  • The model in this repository is much better than the result submitted to the challenge; what is different?

    Hello, I noticed that FullSubNet's DNS2021 score on the dev test set was MOS 3.06: https://www.microsoft.com/en-us/research/uploads/prod/2020/12/Challenge_Results.pdf

    But when I test the model from this repository, the metrics are:

    DNSMOS_SIG: 3.791
    DNSMOS_BAK: 4.130
    DNSMOS_OVR: 3.442

    Was the model submitted at that time different from the one released here?

    Thanks!

    opened by lhwcv 0
  • error

    Can somebody help me fix this?

    command:

        python inference.py -C C:\Users\punnp\Desktop\FullSubNet\recipes\dns_interspeech_2020\fullsubnet/inference.toml -M C:\Users\punnp\Desktop\FullSubNet\model\fullsubnet_best_model_58epochs.tar -O C:\Users\punnp\Desktop\FullSubNet\output

    result:

        Traceback (most recent call last):
          File "C:\Users\punnp\anaconda3\envs\FullSubNet\lib\site-packages\toml\decoder.py", line 512, in loads
            multibackslash)
          File "C:\Users\punnp\anaconda3\envs\FullSubNet\lib\site-packages\toml\decoder.py", line 778, in load_line
            value, vtype = self.load_value(pair[1], strictly_valid)
          File "C:\Users\punnp\anaconda3\envs\FullSubNet\lib\site-packages\toml\decoder.py", line 880, in load_value
            return (self.load_array(v), "array")
          File "C:\Users\punnp\anaconda3\envs\FullSubNet\lib\site-packages\toml\decoder.py", line 1026, in load_array
            nval, ntype = self.load_value(a[i])
          File "C:\Users\punnp\anaconda3\envs\FullSubNet\lib\site-packages\toml\decoder.py", line 866, in load_value
            raise ValueError("Reserved escape sequence used")
        ValueError: Reserved escape sequence used

        During handling of the above exception, another exception occurred:

        Traceback (most recent call last):
          File "inference.py", line 30, in <module>
            configuration = toml.load(config_path.as_posix())
          File "C:\Users\punnp\anaconda3\envs\FullSubNet\lib\site-packages\toml\decoder.py", line 134, in load
            return loads(ffile.read(), _dict, decoder)
          File "C:\Users\punnp\anaconda3\envs\FullSubNet\lib\site-packages\toml\decoder.py", line 514, in loads
            raise TomlDecodeError(str(err), original, pos)
        toml.decoder.TomlDecodeError: Reserved escape sequence used (line 20 column 1 char 241)
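
    A likely cause (my reading of the error, not confirmed by the maintainers): in a TOML basic (double-quoted) string, a backslash starts an escape sequence, so raw Windows paths inside inference.toml trigger "Reserved escape sequence used". Forward slashes, or single-quoted TOML literal strings, parse fine:

        import toml

        toml.loads('path = "C:/Users/punnp/Desktop/FullSubNet"')      # forward slashes
        toml.loads("path = 'C:\\Users\\punnp\\Desktop\\FullSubNet'")  # TOML literal string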

    opened by MisuyaXZ 0
Releases (v0.2)
  • v0.2 (Jan 16, 2021)

    Checkpoints

    This page has two released model checkpoints. All checkpoints include "model_state_dict", "optimizer_state_dict", and some other meta information.

    The first checkpoint is the original model at the 58th epoch. Its performance is shown in this table:

    |            | With Reverb |         |        |       | No Reverb |         |        |       |
    |:----------:|:-----------:|:-------:|:------:|:-----:|:---------:|:-------:|:------:|:-----:|
    |   Method   |   WB-PESQ   | NB-PESQ | SI-SDR | STOI  |  WB-PESQ  | NB-PESQ | SI-SDR | STOI  |
    | FullSubNet |    2.987    |  3.496  | 15.756 | 0.926 |   2.889   |  3.385  | 17.635 | 0.964 |

    In addition, some people are interested in the performance when using cumulative normalization. The following is a pre-trained FullSubNet that uses cumulative normalization:

    |                              | With Reverb |         |        |       | No Reverb |         |        |       |
    |:----------------------------:|:-----------:|:-------:|:------:|:-----:|:---------:|:-------:|:------:|:-----:|
    |            Method            |   WB-PESQ   | NB-PESQ | SI-SDR | STOI  |  WB-PESQ  | NB-PESQ | SI-SDR | STOI  |
    | FullSubNet (Cumulative Norm) |    2.978    |  3.503  | 15.820 | 0.928 |   2.863   |  3.376  | 17.913 | 0.964 |

    If you want to run inference or fine-tune from these checkpoints, please check the usage in the documentation.
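
    As a rough sketch (constructing the model and optimizer is assumed to follow the recipes in this repository), restoring a checkpoint looks like:

        import torch

        checkpoint = torch.load("fullsubnet_best_model_58epochs.tar", map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])          # for inference
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])  # additionally, for fine-tuning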

    Room Impulse Responses

    As mentioned in the paper, the room impulse responses (RIRs) come from the Multichannel Impulse Response Database and the Reverb Challenge dataset. Please download the zip package "RIR (Multichannel Impulse Response Database + The REVERB challenge).zip" if you would like to retrain the FullSubNet.

    Note that the zip package includes a folder "rir" and a file "rir.txt." The folder "rir" contains all separated single-channel RIRs extracted from the above two datasets. The suffix (e.g., "m_") of the filename is the index of a microphone. The file "rir.txt" is just a path list of all RIRs. Please modify it to fit your case before you use it.
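
    Since rir.txt is a plain list of absolute paths, adapting it can be scripted; a hypothetical helper (the /data/rir location is a placeholder for wherever you unpacked the "rir" folder):

        from pathlib import Path

        local_root = Path("/data/rir")
        paths = Path("rir.txt").read_text().splitlines()
        fixed = [str(local_root / Path(p).name) for p in paths if p.strip()]
        Path("rir.txt").write_text("\n".join(fixed) + "\n")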

    If you would like to extract the channels yourself, you can download the original RIRs from these pages:

    1. Multichannel Impulse Response Database: https://www.eng.biu.ac.il/~gannot/RIR_DATABASE/
    2. The REVERB challenge data: https://reverb2014.dereverberation.com/tools/reverb_tools_for_Generate_mcTrainData.tgz and https://reverb2014.dereverberation.com/tools/reverb_tools_for_Generate_SimData.tgz

    Enjoy ~

    Source code(tar.gz)
    Source code(zip)
    cum_fullsubnet_best_model_218epochs.tar(64.53 MB)
    fullsubnet_best_model_58epochs.tar(64.53 MB)
    RIR.Multichannel.Impulse.Response.Database.+.The.REVERB.challenge.zip(10.68 MB)
Owner
郝翔
Audio/Speech Signal Processing