Implementation of 'lightweight' GAN, proposed in ICLR 2021, in PyTorch. High-resolution image generation that can be trained within a day or two

Overview

512x512 flowers after 12 hours of training, 1 gpu

256x256 flowers after 12 hours of training, 1 gpu

Pizza

'Lightweight' GAN

Implementation of 'lightweight' GAN proposed in ICLR 2021, in PyTorch. The main contributions of the paper are a skip-layer excitation in the generator, paired with autoencoding self-supervised learning in the discriminator. Quoting the one-line summary: "converge on single gpu with few hours' training, on 1024 resolution sub-hundred images".

Install

$ pip install lightweight-gan

Use

One command

$ lightweight_gan --data ./path/to/images --image-size 512

The model is saved to ./models/{name} every 1000 iterations, and samples from the model are saved to ./results/{name}. If --name is not given, name is set to default.
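
As a rough illustration of the saving schedule, the sketch below maps a training step to a checkpoint number and builds the matching paths. The helper names, the step-to-checkpoint mapping, and the file name pattern `model_{num}.pt` are assumptions made for illustration; only the two directories and the 1000-iteration interval come from the text.

```python
from pathlib import Path

SAVE_EVERY = 1000  # from the text: a checkpoint every 1000 iterations


def checkpoint_number(step, save_every=SAVE_EVERY):
    """Number of the most recent checkpoint written at or before `step`.

    ASSUMPTION: checkpoints are numbered by integer division of the step.
    """
    return step // save_every


def run_paths(name="default", base=Path(".")):
    """Model and sample directories for a run called `name`."""
    return {"models": base / "models" / name, "results": base / "results" / name}


def checkpoint_file(name, num):
    # ASSUMPTION: the file name pattern 'model_{num}.pt' is hypothetical.
    return run_paths(name)["models"] / f"model_{num}.pt"


print(checkpoint_number(117_500))        # 117
print(checkpoint_file("default", 117))   # models/default/model_117.pt on POSIX
```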

Training settings

Pretty self explanatory for deep learning practitioners

$ lightweight_gan \
    --data ./path/to/images \
    --name {name of run} \
    --batch-size 16 \
    --gradient-accumulate-every 4 \
    --num-train-steps 200000
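
To make the interplay of --batch-size and --gradient-accumulate-every concrete, here is a minimal arithmetic sketch. It assumes that one training step accumulates gradients over `gradient_accumulate_every` batches before a single optimizer update, which is the usual meaning of gradient accumulation; the function names are hypothetical.

```python
def effective_batch_size(batch_size, gradient_accumulate_every):
    """Images contributing to one optimizer update under gradient accumulation."""
    return batch_size * gradient_accumulate_every


def images_seen(num_train_steps, batch_size, gradient_accumulate_every):
    """Total images processed over a run (ASSUMPTION: one update per step)."""
    return num_train_steps * effective_batch_size(batch_size, gradient_accumulate_every)


# Values from the command above
print(effective_batch_size(16, 4))    # 64
print(images_seen(200_000, 16, 4))    # 12800000
```

If memory is tight, lowering --batch-size while raising --gradient-accumulate-every keeps the product, and so the effective batch size, unchanged.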

Augmentation

Augmentation is essential for Lightweight GAN to work effectively in a low data setting

By default, the augmentation types are set to translation and cutout, with color omitted. You can include color as well with the following.

$ lightweight_gan --data ./path/to/images --aug-prob 0.25 --aug-types [translation,cutout,color]

Test augmentation

You can test and see how your images will be augmented before they are passed into the neural network (if you use augmentation). Let's see how it works on this image:

Basic usage

To augment an image, pass --aug-test and put the path to your image in --data:

lightweight_gan \
    --aug-test \
    --data ./path/to/lena.jpg

This creates the file lena_augs.jpg, which will look something like this:

Options

You can use the following options to change the result:

  • --image-size 256 to change the size of the image tiles in the result. Default: 256.
  • --aug-types [color,cutout,translation] to combine several augmentations. Default: [cutout,translation].
  • --batch-size 10 to change the number of images in the result image. Default: 10.
  • --num-image-tiles 5 to change the number of tiles in the result image. Default: 5.

Try this command:

lightweight_gan \
    --aug-test \
    --data ./path/to/lena.jpg \
    --batch-size 16 \
    --num-image-tiles 4 \
    --aug-types [color,translation]

The result will be something like this:

Types of augmentations

This library contains several built-in augmentation types.
Some of these are applied by default, and some can be selected on the command line through --aug-types:

  • Horizontal flip (applied by default, not configurable, runs in the AugWrapper class);
  • color randomly changes brightness, saturation and contrast;
  • cutout creates random black boxes on the image;
  • offset randomly shifts the image along the x and y axes, with the image repeating;
    • offset_h shifts only along the x-axis;
    • offset_v shifts only along the y-axis;
  • translation randomly moves the image on a canvas with a black background.

The full set of augmentations is --aug-types [color,cutout,offset,translation].
The general recommendation is to use augmentations suitable for your data, and as many as possible, then after some time of training disable the most destructive (for the image) augmentations.
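
The difference between offset (the image repeats) and translation (black background) is easiest to see on a tiny grid. The following is a toy sketch on nested lists, not the library's code: the function names, the explicit shift arguments, and the way the probability is applied are all assumptions for illustration.

```python
import random


def offset(img, dx, dy):
    """Shift by (dx, dy) with wrap-around, so the image repeats."""
    h, w = len(img), len(img[0])
    return [[img[(y - dy) % h][(x - dx) % w] for x in range(w)] for y in range(h)]


def translation(img, dx, dy):
    """Shift by (dx, dy) on a black (0) canvas; pixels pushed out are lost."""
    h, w = len(img), len(img[0])
    return [
        [img[y - dy][x - dx] if 0 <= y - dy < h and 0 <= x - dx < w else 0
         for x in range(w)]
        for y in range(h)
    ]


def cutout(img, top, left, size):
    """Paint a black square of side `size` with its corner at (top, left)."""
    return [
        [0 if top <= y < top + size and left <= x < left + size else v
         for x, v in enumerate(row)]
        for y, row in enumerate(img)
    ]


def maybe_offset(img, prob, rng):
    """Apply a random offset with probability `prob`.

    ASSUMPTION: a simplified stand-in for how an augmentation probability works.
    """
    if rng.random() >= prob:
        return img
    h, w = len(img), len(img[0])
    return offset(img, rng.randrange(w), rng.randrange(h))


img = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(offset(img, 1, 0))       # [[3, 1, 2], [6, 4, 5], [9, 7, 8]]
print(translation(img, 1, 0))  # [[0, 1, 2], [0, 4, 5], [0, 7, 8]]
print(cutout(img, 0, 0, 2))    # [[0, 0, 3], [0, 0, 6], [7, 8, 9]]
```

Restricting `dy` to 0 corresponds to the offset_h idea, and restricting `dx` to 0 to offset_v.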

Color

Cutout

Offset

Only x-axis:

Only y-axis:

Translation

Mixed precision

You can turn on automatic mixed precision with one flag --amp

You should expect it to be 33% faster and save up to 40% memory

Multiple GPUs

Multi-GPU training is also enabled with one flag, --multi-gpus

Generating

Once you have finished training, you can generate samples with one command. You can select which checkpoint number to load from. If --load-from is not specified, it defaults to the latest checkpoint.

$ lightweight_gan \
  --name {name of run} \
  --load-from {checkpoint num} \
  --generate \
  --generate-types {types of result, default: [default,ema]} \
  --num-image-tiles {count of image result}

After running this command you will get a folder next to the results image folder with the postfix "-generated-{checkpoint num}".

You can also generate interpolations

$ lightweight_gan --name {name of run} --generate-interpolation

Show progress

After creating several checkpoints of the model, you can generate the progress as a sequence of images with the command:

$ lightweight_gan \
  --name {name of run} \
  --show-progress \
  --generate-types {types of result, default: [default,ema]} \
  --num-image-tiles {count of image result}

After running this command you will get a new folder in the results folder, with postfix "-progress". You can convert the images to a video with ffmpeg using the command "ffmpeg -framerate 10 -pattern_type glob -i '*-ema.jpg' out.mp4".
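
If you prefer to drive the conversion from Python, a minimal sketch could look like the following. It only rebuilds the ffmpeg command quoted above as an argument list; the function names and the idea of running it inside the progress folder are assumptions, and ffmpeg must already be installed.

```python
import shutil
import subprocess
from pathlib import Path


def ffmpeg_command(pattern="*-ema.jpg", framerate=10, output="out.mp4"):
    """Argument list equivalent to the ffmpeg command quoted above."""
    return ["ffmpeg", "-framerate", str(framerate),
            "-pattern_type", "glob", "-i", pattern, output]


def render_progress_video(progress_dir):
    """Run ffmpeg inside `progress_dir`; returns False if ffmpeg is not found."""
    if shutil.which("ffmpeg") is None:
        return False
    # No shell is involved, so the glob pattern reaches ffmpeg unexpanded,
    # just like the quoted '*-ema.jpg' in the shell command.
    subprocess.run(ffmpeg_command(), cwd=Path(progress_dir), check=True)
    return True


print(" ".join(ffmpeg_command()))
```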

Show progress gif demonstration

Show progress video demonstration

Discriminator output size

The author has kindly let me know that the discriminator output size (5x5 vs 1x1) leads to different results on different datasets (for example, 5x5 works better for art than for faces). You can toggle this with a single flag.

# disc output size is by default 1x1
$ lightweight_gan --data ./path/to/art --image-size 512 --disc-output-size 5

Attention

You can add linear + axial attention to specific resolution layers with the following

# make sure there are no spaces between the values within the brackets []
$ lightweight_gan --data ./path/to/images --image-size 512 --attn-res-layers [32,64] --aug-prob 0.25
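
The reason for the no-spaces rule is that the shell splits arguments on whitespace before the program ever sees them. The sketch below uses Python's `shlex.split`, which approximates POSIX shell word splitting, to show what happens to the bracketed list; how the CLI would then interpret the broken pieces is not shown here.

```python
import shlex

good = shlex.split("lightweight_gan --attn-res-layers [32,64] --aug-prob 0.25")
bad = shlex.split("lightweight_gan --attn-res-layers [32, 64] --aug-prob 0.25")

print(good)  # ['lightweight_gan', '--attn-res-layers', '[32,64]', '--aug-prob', '0.25']
print(bad)   # ['lightweight_gan', '--attn-res-layers', '[32,', '64]', '--aug-prob', '0.25']
```

With the space, the list arrives as two separate tokens, `[32,` and `64]`, instead of one value. The same applies to --aug-types and --generate-types.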

Bonus

You can also train with transparent images

$ lightweight_gan --data ./path/to/images --transparent

Or greyscale

$ lightweight_gan --data ./path/to/images --greyscale

Alternatives

If you want the current state of the art GAN, you can find it at https://github.com/lucidrains/stylegan2-pytorch

Citations

@inproceedings{
    anonymous2021towards,
    title={Towards Faster and Stabilized {GAN} Training for High-fidelity Few-shot Image Synthesis},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=1Fqg133qRaI},
    note={under review}
}
@inproceedings{
    anonymous2021global,
    title={Global Self-Attention Networks},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=KiFeuZu24k},
    note={under review}
}
@misc{cao2020global,
    title={Global Context Networks},
    author={Yue Cao and Jiarui Xu and Stephen Lin and Fangyun Wei and Han Hu},
    year={2020},
    eprint={2012.13375},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
@misc{qin2020fcanet,
    title={FcaNet: Frequency Channel Attention Networks},
    author={Zequn Qin and Pengyi Zhang and Fei Wu and Xi Li},
    year={2020},
    eprint={2012.11879},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
@misc{sinha2020topk,
    title={Top-k Training of GANs: Improving GAN Performance by Throwing Away Bad Samples},
    author={Samarth Sinha and Zhengli Zhao and Anirudh Goyal and Colin Raffel and Augustus Odena},
    year={2020},
    eprint={2002.06224},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}

What I cannot create, I do not understand - Richard Feynman

Comments
  • Troubles with global context module in 0.15.0

    @lucidrains

    After updating to this version https://github.com/lucidrains/lightweight-gan/releases/tag/0.15.0 I can't continue training my network and had to start from zero. The previous version was at 117k batches of 4 (468k images, around 66 hours of training) and was pretty good. In the new version 0.15.0, on the same dataset with the same parameters (--image-size 1024 --aug-types [color,offset_h] --aug-prob 1 --amp --batch-size 7), after 77k batches of 7 (539k images, around 49 hours of training) I see some artifacts that look like oil puddles. Have you seen this, or do you know how to avoid it?

    image

    In the previous version with sle-spatial I didn't see anything like this.

    opened by Dok11 9
  • What is sle_spatial?

    I have seen this function argument mentioned in this issue:

    https://github.com/lucidrains/lightweight-gan/issues/14#issuecomment-733432989

    What is sle_spatial?

    opened by woctezuma 8
  • unable to load save model. please try downgrading the package to the version specified by the saved model

    I have had the following problem since today. How can I solve this?

    continuing from previous epoch - 118
    loading from version 0.21.4
    unable to load save model. please try downgrading the package to the version specified by the saved model
    Traceback (most recent call last):
      File "/opt/conda/bin/lightweight_gan", line 8, in <module>
        sys.exit(main())
      File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 193, in main
        fire.Fire(train_from_folder)
      File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 141, in Fire
        component_trace = _Fire(component, args, parsed_flag_args, context, name)
      File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 466, in _Fire
        component, remaining_args = _CallAndUpdateTrace(
      File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
        component = fn(*varargs, **kwargs)
      File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 184, in train_from_folder
        run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash)
      File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 59, in run_training
        model.load(load_from)
      File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1603, in load
        raise e
      File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1600, in load
        self.GAN.load_state_dict(load_data['GAN'])
      File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict
        raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    RuntimeError: Error(s) in loading state_dict for LightweightGAN:
    Missing key(s) in state_dict:
"G.layers.0.0.2.1.weight", "G.layers.0.0.2.1.bias", "G.layers.0.0.4.weight", "G.layers.0.0.4.bias", "G.layers.0.0.4.running_mean", "G.layers.0.0.4.running_var", "G.layers.1.0.2.1.weight", "G.layers.1.0.2.1.bias", "G.layers.1.0.4.weight", "G.layers.1.0.4.bias", "G.layers.1.0.4.running_mean", "G.layers.1.0.4.running_var", 
"G.layers.2.0.2.1.weight", "G.layers.2.0.2.1.bias", "G.layers.2.0.4.weight", "G.layers.2.0.4.bias", "G.layers.2.0.4.running_mean", "G.layers.2.0.4.running_var", "G.layers.3.0.2.1.weight", "G.layers.3.0.2.1.bias", "G.layers.3.0.4.weight", "G.layers.3.0.4.bias", "G.layers.3.0.4.running_mean", "G.layers.3.0.4.running_var", "G.layers.3.2.fn.to_lin_q.weight", "G.layers.3.2.fn.to_lin_kv.net.0.weight", "G.layers.3.2.fn.to_lin_kv.net.1.weight", "G.layers.3.2.fn.to_kv.weight", "G.layers.4.0.2.1.weight", "G.layers.4.0.2.1.bias", "G.layers.4.0.4.weight", "G.layers.4.0.4.bias", "G.layers.4.0.4.running_mean", "G.layers.4.0.4.running_var", "G.layers.5.0.2.1.weight", "G.layers.5.0.2.1.bias", "G.layers.5.0.4.weight", "G.layers.5.0.4.bias", "G.layers.5.0.4.running_mean", "G.layers.5.0.4.running_var", "D.residual_layers.3.1.fn.to_lin_q.weight", "D.residual_layers.3.1.fn.to_lin_kv.net.0.weight", "D.residual_layers.3.1.fn.to_lin_kv.net.1.weight", "D.residual_layers.3.1.fn.to_kv.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_q.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.0.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.1.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_q.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.0.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.1.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.weight", "GE.layers.0.0.2.1.weight", "GE.layers.0.0.2.1.bias", "GE.layers.0.0.4.weight", "GE.layers.0.0.4.bias", "GE.layers.0.0.4.running_mean", "GE.layers.0.0.4.running_var", "GE.layers.1.0.2.1.weight", "GE.layers.1.0.2.1.bias", "GE.layers.1.0.4.weight", "GE.layers.1.0.4.bias", "GE.layers.1.0.4.running_mean", "GE.layers.1.0.4.running_var", "GE.layers.2.0.2.1.weight", "GE.layers.2.0.2.1.bias", "GE.layers.2.0.4.weight", "GE.layers.2.0.4.bias", "GE.layers.2.0.4.running_mean", "GE.layers.2.0.4.running_var", "GE.layers.3.0.2.1.weight", "GE.layers.3.0.2.1.bias", "GE.layers.3.0.4.weight", "GE.layers.3.0.4.bias", 
"GE.layers.3.0.4.running_mean", "GE.layers.3.0.4.running_var", "GE.layers.3.2.fn.to_lin_q.weight", "GE.layers.3.2.fn.to_lin_kv.net.0.weight", "GE.layers.3.2.fn.to_lin_kv.net.1.weight", "GE.layers.3.2.fn.to_kv.weight", "GE.layers.4.0.2.1.weight", "GE.layers.4.0.2.1.bias", "GE.layers.4.0.4.weight", "GE.layers.4.0.4.bias", "GE.layers.4.0.4.running_mean", "GE.layers.4.0.4.running_var", "GE.layers.5.0.2.1.weight", "GE.layers.5.0.2.1.bias", "GE.layers.5.0.4.weight", "GE.layers.5.0.4.bias", "GE.layers.5.0.4.running_mean", "GE.layers.5.0.4.running_var", "D_aug.D.residual_layers.3.1.fn.to_lin_q.weight", "D_aug.D.residual_layers.3.1.fn.to_lin_kv.net.0.weight", "D_aug.D.residual_layers.3.1.fn.to_lin_kv.net.1.weight", "D_aug.D.residual_layers.3.1.fn.to_kv.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_q.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.0.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.1.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_q.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.0.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.1.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.weight". 
Unexpected key(s) in state_dict: "G.layers.0.0.2.weight", "G.layers.0.0.2.bias", "G.layers.0.0.3.bias", "G.layers.0.0.3.running_mean", "G.layers.0.0.3.running_var", "G.layers.0.0.3.num_batches_tracked", "G.layers.1.0.2.weight", "G.layers.1.0.2.bias", "G.layers.1.0.3.bias", "G.layers.1.0.3.running_mean", "G.layers.1.0.3.running_var", "G.layers.1.0.3.num_batches_tracked", "G.layers.2.0.2.weight", "G.layers.2.0.2.bias", "G.layers.2.0.3.bias", "G.layers.2.0.3.running_mean", "G.layers.2.0.3.running_var", "G.layers.2.0.3.num_batches_tracked", "G.layers.3.0.2.weight", "G.layers.3.0.2.bias", "G.layers.3.0.3.bias", "G.layers.3.0.3.running_mean", "G.layers.3.0.3.running_var", "G.layers.3.0.3.num_batches_tracked", "G.layers.3.2.fn.to_kv.net.0.weight", "G.layers.3.2.fn.to_kv.net.1.weight", "G.layers.4.0.2.weight", "G.layers.4.0.2.bias", "G.layers.4.0.3.bias", "G.layers.4.0.3.running_mean", "G.layers.4.0.3.running_var", "G.layers.4.0.3.num_batches_tracked", "G.layers.5.0.2.weight", "G.layers.5.0.2.bias", "G.layers.5.0.3.bias", "G.layers.5.0.3.running_mean", "G.layers.5.0.3.running_var", "G.layers.5.0.3.num_batches_tracked", "D.residual_layers.3.1.fn.to_kv.net.0.weight", "D.residual_layers.3.1.fn.to_kv.net.1.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.net.0.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.net.1.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.net.0.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.net.1.weight", "GE.layers.0.0.2.weight", "GE.layers.0.0.2.bias", "GE.layers.0.0.3.bias", "GE.layers.0.0.3.running_mean", "GE.layers.0.0.3.running_var", "GE.layers.0.0.3.num_batches_tracked", "GE.layers.1.0.2.weight", "GE.layers.1.0.2.bias", "GE.layers.1.0.3.bias", "GE.layers.1.0.3.running_mean", "GE.layers.1.0.3.running_var", "GE.layers.1.0.3.num_batches_tracked", "GE.layers.2.0.2.weight", "GE.layers.2.0.2.bias", "GE.layers.2.0.3.bias", "GE.layers.2.0.3.running_mean", "GE.layers.2.0.3.running_var", "GE.layers.2.0.3.num_batches_tracked", "GE.layers.3.0.2.weight", 
"GE.layers.3.0.2.bias", "GE.layers.3.0.3.bias", "GE.layers.3.0.3.running_mean", "GE.layers.3.0.3.running_var", "GE.layers.3.0.3.num_batches_tracked", "GE.layers.3.2.fn.to_kv.net.0.weight", "GE.layers.3.2.fn.to_kv.net.1.weight", "GE.layers.4.0.2.weight", "GE.layers.4.0.2.bias", "GE.layers.4.0.3.bias", "GE.layers.4.0.3.running_mean", "GE.layers.4.0.3.running_var", "GE.layers.4.0.3.num_batches_tracked", "GE.layers.5.0.2.weight", "GE.layers.5.0.2.bias", "GE.layers.5.0.3.bias", "GE.layers.5.0.3.running_mean", "GE.layers.5.0.3.running_var", "GE.layers.5.0.3.num_batches_tracked", "D_aug.D.residual_layers.3.1.fn.to_kv.net.0.weight", "D_aug.D.residual_layers.3.1.fn.to_kv.net.1.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.net.0.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.net.1.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.net.0.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.net.1.weight". size mismatch for G.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]). size mismatch for G.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]). 
size mismatch for D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]). size mismatch for D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]). size mismatch for D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]). size mismatch for GE.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]). size mismatch for GE.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for D_aug.D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]). 
size mismatch for D_aug.D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]). size mismatch for D_aug.D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]).

    opened by sebastiantrella 7
  • Greyscale image generation

    Hi,

    thank you for this repo, I've been playing with it a bit and it seems very good! I am trying to generate greyscale images, so I modified the channel accordingly

    init_channel = 4 if transparent else 1

    unfortunately, this seemed to have no effect as the images generated are still RGB (even though they converge towards greyscale with time), even weirder IMO is that I can modify the number of channels for the generator and keep the original 3 for the discriminator without any issue.

    I have also changed this part to no effect

    convert_image_fn = partial(convert_image_to, 'RGBA' if transparent else 'L')
    num_channels = 1 if not transparent else 4

    Am I missing something here?

    opened by stefanorosss 7
  • Getting NoneType is not subscriptable when trying to start training.

    I've been able to train models before but after changing my dataset I'm getting the error.

    My trace:
      File "/usr/local/lib/python3.6/dist-packages/lightweight_gan/lightweight_gan.py", line 1356, in load
        name = checkpoints[-1]
    TypeError: 'NoneType' object is not subscriptable

    opened by TomCallan 6
  • Optimal parameters for Google Colab

    Hello,

    First of all, thank you for sharing your code and insights with the rest of us!

    As for your code, I plan to run it for 12 hours on Google Colab, similarly to the set-up for what is shown in the README.

    My dataset consists of images of 256x256 resolution, and I have started training with the following command line:

    !lightweight_gan \
     --data {image_dir} \
     --disc-output-size 5 \
     --aug-prob 0.25 \
     --aug-types [translation,cutout,color] \
     --amp

    I have noticed that the expected training time is 112.5 hours with 150k iterations (the default setting), which is consistent with the average time of 2.7 seconds per iteration shown in the log. However, it is ~ 9 times more than what is shown in the README. So I wonder if I am doing something wrong, and I see 2 solutions.

    First, I could decrease the number of iterations so that it takes 12 hours, by choosing 16k iterations instead of 150k with:

     --num-train-steps 16000 \
    

    Is it what you have done for the results shown in the README?
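
The figures quoted in this post can be checked with a few lines of arithmetic. This is only a sanity check that assumes a constant 2.7 seconds per iteration, as reported in the log; the function names are hypothetical.

```python
SECONDS_PER_HOUR = 3600


def training_hours(num_steps, seconds_per_step):
    """Wall-clock hours for `num_steps` at a constant time per step."""
    return num_steps * seconds_per_step / SECONDS_PER_HOUR


def steps_within(hours, seconds_per_step):
    """How many steps fit in a time budget of `hours`."""
    return round(hours * SECONDS_PER_HOUR / seconds_per_step)


print(training_hours(150_000, 2.7))  # about 112.5 hours
print(steps_within(12, 2.7))         # 16000
```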

    Second, I have noticed that I am only using 3.8 GB of GPU memory, so I could increase the batch size, as you mentioned in https://github.com/lucidrains/lightweight-gan/issues/13#issuecomment-732486110. Edit: However, the training time increases with a larger batch size. For instance, I am using 7.2 GB of GPU memory, and it takes 8.2 seconds per iteration, with the following:

     --batch-size 32 \
     --gradient-accumulate-every 4 \
    
    opened by woctezuma 6
  • Added Experimentation Tracking.

    Added Experimentation Tracking using Aim.

    Now you can:

    • Track all the model hyperparameters and architectural choices.
    • Track all types of losses.
    • Filter all the experiments with respect to hyperparameters or the architecture.
    • Group and aggregate w.r.t. all the trackables to dive into granular experimentation assessment.
    • Track the generated images to track how the model improves.

    opened by hnhnarek 5
  • Aim installation error

    I'm trying to run the generator after training, to generate fake samples using the following command

    lightweight_gan --generate --load-from 299

    I get this following error:

    Traceback (most recent call last):
      File "C:\anaconda3\lib\runpy.py", line 197, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "C:\anaconda3\lib\runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "C:\anaconda3\Scripts\lightweight_gan.exe\__main__.py", line 7, in <module>
      File "C:\anaconda3\lib\site-packages\lightweight_gan\cli.py", line 195, in main
        fire.Fire(train_from_folder)
      File "C:\anaconda3\lib\site-packages\fire\core.py", line 141, in Fire
        component_trace = _Fire(component, args, parsed_flag_args, context, name)
      File "C:\anaconda3\lib\site-packages\fire\core.py", line 466, in _Fire
        component, remaining_args = _CallAndUpdateTrace(
      File "C:\anaconda3\lib\site-packages\fire\core.py", line 681, in _CallAndUpdateTrace
        component = fn(*varargs, **kwargs)
      File "C:\anaconda3\lib\site-packages\lightweight_gan\cli.py", line 158, in train_from_folder
        model = Trainer(**model_args)
      File "C:\anaconda3\lib\site-packages\lightweight_gan\lightweight_gan.py", line 1057, in __init__
        self.run = self.aim.Run(run_hash=aim_run_hash, repo=aim_repo)
    AttributeError: 'Trainer' object has no attribute 'aim'
    

    and when I try to run pip install aim, I get a dependency error with aimrocks

      ERROR: Command errored out with exit status 1:
       command: 'C:\Anaconda3\envs\aerialweb\python.exe' 'C:\Anaconda3\envs\aerialweb\lib\site-packages\pip' install --ignore-installed --no-user --prefix 'C:\Users\ahmed\AppData\Local\Temp\pip-build-env-b2ysw94t\overlay' --no-warn-script-location --no-binary :none: --only-binary :none: -i https://pypi.org/simple -- setuptools 'cython >= 3.0.0a9' 'aimrocks == 0.2.1'
           cwd: None
      Complete output (12 lines):
      WARNING: Ignoring invalid distribution -pencv-python (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -cipy (c:\anaconda3\envs\aerialweb\lib\site-packages)
      Collecting setuptools
        Using cached setuptools-59.6.0-py3-none-any.whl (952 kB)
      Collecting cython>=3.0.0a9
        Using cached Cython-3.0.0a10-py2.py3-none-any.whl (1.1 MB)
      ERROR: Could not find a version that satisfies the requirement aimrocks==0.2.1 (from versions: 0.1.3a14, 0.2.0.dev1, 0.2.0)
      ERROR: No matching distribution found for aimrocks==0.2.1
      WARNING: Ignoring invalid distribution -pencv-python (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -cipy (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -pencv-python (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -cipy (c:\anaconda3\envs\aerialweb\lib\site-packages)
    

    What is aimrocks and what does it actually do? I am unable to find a matching distribution or even a wheel file to install it manually. Please help

    opened by demiahmed 4
  • Can't find "__main__" module (sorry if noob question)

    Hello, I hope it's not too much of a noob question, I don't have any background in coding.

    After creating the env and installing PyTorch I ran "python setup.py install" and then I ran "python lightweight_gan --data /source --image-size 512" (I filled a "source" folder with pictures of fish), but I get the error "can't find '__main__' module". More exactly: C:\Programmes perso\Logiciels\Anaconda\envs\lightweightgan\python.exe: can't find '__main__' module in 'C:\Programmes perso\Logiciels\LightweightGan\lightweight_gan'. I tried to copy and rename some of the other modules (__init__, lightweight_gan...); the code seems to start to run but stops before doing anything. So I guess some file must be missing, or did I do something wrong?

    Thanks a lot for the repo and have a nice day

    opened by SPotaufeux 4
  • Hard cutoff straight lines/boxes of nothing in generated images

    Hello! Training on Google Colab with

    !lightweight_gan --data my/images/ --name my-name --image-size 256 --transparent --dual-contrast-loss --num-train-steps 250000
    

    I'm at 250k iterations over the course of 5 days at 2s/it, and have gotten strange results with boxes.

    I've circled some examples of this below. image

    My training data is 22k images of 256x256 .pngs that do not contain large hard edges or boxes like this. They're video game sprites with hard edges being limited to at most 10x10px

    Are there any arguments I can change to decrease the chance of the model learning that transparent boxes are good? Would converting to a white background help?

    Thank you!

    opened by timendez 4
  • Amount of training steps

    If I bring down the number of training steps from 150 000 to 30 000, will the trained model be overall bad? Does it really need the 100 000 or 150 000 training steps?

    opened by MartimQS 4
  • Executing with a trailing \ in the arguments sets the variable new to the truthy value '\\' and deletes all progress

    A rather frustrating issue:

    calling it with a trailing \ like lightweight_gan --data full_cleaned256/ --name c256 --models_dir models --results_dir results \

    sets the variable new to the truthy value '\' and deletes all progress.

    This might well be an issue with Fire but might be mitigated or fixed here too, I am unsure about that.

    Thanks. Jonas

    opened by deklesen 0
  • Projecting generated Images to Latent Space

    Is there any way to reverse engineer the generated images into the latent space?

    I am trying to embed fresh RGB as well as ones generated by the Generator into the latent space so I can find its nearest neighbour, pretty much like AI image editing tools.

    I plan to convert my RGB image into tensor embeddings based on my trained model and tweak the feature vectors.

    How can I achieve this with lightweight-gan?

    opened by demiahmed 0
  • Discriminator Loss converges to 0 while Generator loss pretty high

    I am trying to train with a custom image dataset for about 600,000 epochs. At about halfway, my D_loss converges to 0 while my G_loss stays put at 2.5

    My evaluation outputs are slowly starting to fade out to either black or white.

    Is there anything that I could do to tweak my model? Either by increasing the threshold for the Discriminator or by training the Generator only?

    opened by demiahmed 3
  • loss implementation differs from paper

    Hi,

    Thanks for this amazing implementation! I have a question concerning the loss implementation, as it seems to differ from the original equations. The screenshot below shows the GAN loss as presented in the paper :

    paper_losses

    • in red, the discriminator loss (D loss) on the true labels,
    • in green the D loss on labels for fake generated images,
    • and in blue, the generator loss (G loss) on labels for fake images.

    This makes sense to me. Since it is assumed that D outputs values between 0 and 1 (0 = fake, 1 = real) :

    • in red, we want D to output 1 for true images → let's assume D indeed outputs 1 for true images : -min(0, -1 + D(x)) = 0, which is indeed the minimum achievable,
    • in green, we want D to output 0 (from the discriminator perspective) for fake images → let's assume D indeed outputs 0 for fake images : -min(0, -1 - D(x^)) = 1, which is the minimum achievable if D outputs values only between 0 and 1,
    • in blue, we want D to output 1 (from the generator perspective) for fake images : the equation follows directly.

    Now, the way the authors implement this in the code provided in the supplementary materials of the paper is as follows (the colors match the ones in the above picture)

    og_code_loss_d_real og_code_loss_d_fake og_code_loss_g

    Except for the strange randomness involved (already explained in https://github.com/lucidrains/lightweight-gan/issues/11), their implementation is a one-to-one match with the paper equations.


    The way it is implemented in this repo, however, is quite different, and I do not understand why.

    lighweight_gan_losses

    Let's start with the discriminator loss :

    • in red, you want D to output small values (negative if allowed), to set this term as small as possible (0 if D can output negative values)
    • in green, you want D to output values as large as possible (larger or equal to 1) to cancel this term out as well

    For the generator loss :

    • in blue, you want the opposite of green, that is for D to output values as small as possible

    This implementation seems meaningful and yields coherent results (as the examples show). It also seems to me that D is not limited to outputting values between 0 and 1, but can output any real value (I might be wrong). I am just wondering why this choice was made. Could you perhaps elaborate on why you decided to implement the loss differently from the original paper?
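
    One way to read the difference, as a minimal sketch: if the repo's losses are the sign-flipped hinge form that the bullets above describe (an assumption drawn from that description, not a copy of the repository code), then they equal the paper's losses evaluated on a discriminator whose output is negated. The function names below are hypothetical, and plain Python floats stand in for raw, unbounded discriminator scores.

```python
# Framework-free sketch comparing two hinge-loss sign conventions.
# Assumption: scores are raw, unbounded discriminator outputs (plain floats).
# All function names here are illustrative, not taken from any library.

def mean(values):
    return sum(values) / len(values)

def relu(v):
    return max(0.0, v)

def paper_d_loss(real_scores, fake_scores):
    # -E[min(0, -1 + D(x))] - E[min(0, -1 - D(x_hat))]
    real_term = mean([-min(0.0, -1.0 + s) for s in real_scores])
    fake_term = mean([-min(0.0, -1.0 - s) for s in fake_scores])
    return real_term + fake_term

def paper_g_loss(fake_scores):
    # -E[D(x_hat)]: the generator wants high scores on fake images
    return -mean(fake_scores)

def flipped_d_loss(real_scores, fake_scores):
    # E[relu(1 + D(x))] + E[relu(1 - D(x_hat))]:
    # low scores wanted on real images, high scores on fake images
    real_term = mean([relu(1.0 + s) for s in real_scores])
    fake_term = mean([relu(1.0 - s) for s in fake_scores])
    return real_term + fake_term

def flipped_g_loss(fake_scores):
    # E[D(x_hat)]: the generator wants low scores on fake images
    return mean(fake_scores)

def negate(scores):
    return [-s for s in scores]

# Arbitrary example scores, chosen only for illustration.
real = [1.5, 0.2, -0.3]
fake = [-2.0, 0.4, 0.9]

# Flipping the sign of every score maps one convention onto the other.
d_gap = abs(paper_d_loss(real, fake) - flipped_d_loss(negate(real), negate(fake)))
g_gap = abs(paper_g_loss(fake) - flipped_g_loss(negate(fake)))
print("discriminator loss gap:", d_gap)
print("generator loss gap:", g_gap)
```

    Under that assumption the two conventions describe the same optimization problem, with "real" simply mapped to low scores instead of high ones. Neither convention requires D to stay between 0 and 1.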

    opened by maximeraafat 1
  • showing results while training ?

    How can I show generator results after every epoch during training?

    This is my current configuration:

     lightweight_gan \
      --data "/content/dataset/Dataset/" \
      --num-train-steps 100000 \
      --image-size 128 \
      --name GAN2DBlood5k \
      --batch-size 32 \
      --gradient-accumulate-every 5 \
      --disc-output-size 1 \
      --dual-contrast-loss \
      --attn-res-layers [] \
      --calculate_fid_every 1000 \
      --greyscale \
      --amp
    

    Using --show-progress only works after training. Also, it seems that checkpoints are no longer saved per epoch.

    opened by galaelized 2
Releases(1.1.1)
Owner
Phil Wang
Working with Attention. It's all we need.