当前位置:网站首页>SwinIR combat: record the training process of SwinIR in detail
SwinIR combat: record the training process of SwinIR in detail
2022-08-03 16:42:00 【HUAWEI CLOUD】
@[toc]
SwinIR实战:详细记录SwinIR的训练过程.
论文地址:https://arxiv.org/pdf/2108.10257.pdf
预训练模型下载:https://github.com/JingyunLiang/SwinIR/releases
训练代码下载:https://github.com/cszn/KAIR
测试代码:https://github.com/JingyunLiang/SwinIR
论文翻译:https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/124434886
测试:https://wanghao.blog.csdn.net/article/details/124517210
在写这边文章之前,我已经翻译了论文,讲解了如何使用SWinIR进行测试?
接下来,我们讲讲如何SwinIR完成训练,有于作者训练了很多任务,我只复现其中的一种任务.
下载训练代码
地址:https://github.com/cszn/KAIR
这是个超分的库,里面包含多个超分的模型,比如SCUNet、VRT、SwinIR、BSRGGAN、USRNet等模型.
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-B5Md9i7H-1651410061139)(https://gitee.com/wanghao1090220084/cloud-image/raw/master/img/face_09_comparison.png)]
下载后解压,训练SwinIR的REANDME.md,路径:./docs/README_SwinIR.md
数据集
训练和测试集可以下载如下. 请将它们分别放在 trainsets
和 testsets
中.
任务 | 训练集 | 测试集 |
---|---|---|
classical/lightweight image SR | DIV2K (800 training images) or DIV2K +Flickr2K (2650 images) | set5 + Set14 + BSD100 + Urban100 + Manga109 download all |
real-world image SR | SwinIR-M (middle size): DIV2K (800 training images) +Flickr2K (2650 images) + OST (10324 images,sky,water,grass,mountain,building,plant,animal) SwinIR-L (large size): DIV2K + Flickr2K + OST + WED(4744 images) + FFHQ (first 2000 images, face) + Manga109 (manga) + SCUT-CTW1500 (first 100 training images, texts) | RealSRSet+5images |
color/grayscale image denoising | DIV2K (800 training images) + Flickr2K (2650 images) + BSD500 (400 training&testing images) + WED(4744 images) | grayscale: Set12 + BSD68 + Urban100 color: CBSD68 + Kodak24 + McMaster + Urban100 download all |
JPEG compression artifact reduction | DIV2K (800 training images) + Flickr2K (2650 images) + BSD500 (400 training&testing images) + WED(4744 images) | grayscale: Classic5 +LIVE1 download all |
我下载了DIV2K数据集和 Flickr2K数据集,DIV2K大小有7G+,Flickr2K约20G.如果网速不好建议只下载DIV2K.
注:在选用classical任务,做训练时,只能使用DIV2K或者Flickr2K,不能把两种数据集放在一起训练,否则就出现维度对不上的情况,如下图:
暂时没有找到原因.
构建测试集,测试集的路径如下图:
由于表格中的测试集放在google,我不能下载,但是SwinIR的测试代码中有测试集,代码链接:https://github.com/JingyunLiang/SwinIR,下载下来直接复制到testsets文件夹下面.
构建训练集,将下载下来的DIV2K解压.将DIV2K_train_HR复制到trainsets文件夹下面,将其改为trainH.
将DIV2K_train_LR_bicubic文件夹的X2文件夹复制到trainsets文件夹下面,然后将其改名为trainL.
到这里,数据集部分就完成了,接下来开始训练.
训练
首先,打开options/swinir/train_swinir_sr_classical.json文件,查看里面的内容.
"task": "swinir_sr_classical_patch48_x2"
训练任务的名字.
"gpu_ids": [0,1]
选择GPU的ID,如果只有一快GPU,改为 [0].如果有更多的GPU,直接往后面添加即可.
"scale": 2 //2,3,48
放大的倍数,可以设置为2、3、4、8.
"datasets": { "train": { "name": "train_dataset" // just name , "dataset_type": "sr" // "dncnn" | "dnpatch" | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" | "jpeg" , "dataroot_H": "trainsets/trainH"// path of H training dataset. DIV2K (800 training images) , "dataroot_L": "trainsets/trainL" // path of L training dataset , "H_size": 96 // 96/144|192/384 | 128/192/256/512. LR patch size is set to 48 or 64 when compared with RCAN or RRDB. , "dataloader_shuffle": true , "dataloader_num_workers": 4 , "dataloader_batch_size": 1 // batch size 1 | 16 | 32 | 48 | 64 | 128. Total batch size =4x8=32 in SwinIR } , "test": { "name": "test_dataset" // just name , "dataset_type": "sr" // "dncnn" | "dnpatch" | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" | "jpeg" , "dataroot_H": "testsets/Set5/HR" // path of H testing dataset , "dataroot_L": "testsets/Set5/LR_bicubic/X2" // path of L testing dataset }}
上面的参数是对数据集的设置.
“H_size”: 96 ,HR图像的大小,和下面的img_size有对应关系,大小设置为img_size×scale.
“dataloader_num_workers”: 4,CPU的核数设置.
“dataloader_batch_size”: 32 ,设置训练的batch_size.
dataset_type:sr,指的是数据集类型SwinIR.
"netG": { "net_type": "swinir" , "upscale": 2 // 2 | 3 | 4 | 8 , "in_chans": 3 , "img_size": 48 // For fair comparison, LR patch size is set to 48 or 64 when compared with RCAN or RRDB. , "window_size": 8 , "img_range": 1.0 , "depths": [6, 6, 6, 6, 6, 6] , "embed_dim": 180 , "num_heads": [6, 6, 6, 6, 6, 6] , "mlp_ratio": 2 , "upsampler": "pixelshuffle" // "pixelshuffle" | "pixelshuffledirect" | "nearest+conv" | null , "resi_connection": "1conv" // "1conv" | "3conv" , "init_type": "default" }
upscale:2,放大的倍数,和上面的scale参数对应.
img_size:48,这里可以设置两个数值,48和64.和测试的training_patch_size参数对应.
官方提供的指令是基于DDP方式,比较复杂一下,好处是速度快.如下:
# 001 Classical Image SR (middle size)python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 main_train_psnr.py --opt options/swinir/train_swinir_sr_classical.json --dist True# 002 Lightweight Image SR (small size)python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 main_train_psnr.py --opt options/swinir/train_swinir_sr_lightweight.json --dist True# 003 Real-World Image SR (middle size)python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 main_train_psnr.py --opt options/swinir/train_swinir_sr_realworld_psnr.json --dist True# before training gan, put the PSNR-oriented model into superresolution/swinir_sr_realworld_x4_gan/models/python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 main_train_psnr.py --opt options/swinir/train_swinir_sr_realworld_gan.json --dist True# 004 Grayscale Image Deoising (middle size)python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 main_train_psnr.py --opt options/swinir/train_swinir_denoising_gray.json --dist True# 005 Color Image Deoising (middle size)python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 main_train_psnr.py --opt options/swinir/train_swinir_denoising_color.json --dist True# 006 JPEG Compression Artifact Reduction (middle size)python -m torch.distributed.launch --nproc_per_node=8 --master_port=1234 main_train_psnr.py --opt options/swinir/train_swinir_car_jpeg.json --dist True
我没有使用上面的方式,而是选择用DP的方式,虽然慢一点,但是简单,更稳定.
在Terminal里面输入:
python main_train_psnr.py --opt options/swinir/train_swinir_sr_classical.json
即可开始训练.
运行结果如下:
等待训练完成后,我们使用测试代码测试.将模型复制到./model_zoo/swinir文件夹下面
输入命令:
python main_test_swinir.py --task classical_sr --scale 2 --training_patch_size 48 --model_path model_zoo/swinir/45000_G.pth --folder_lq testsets/Set5/LR_bicubic/X2
然后在result下面可以看到测试结果.
完整的代码:
https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/85258387
边栏推荐
猜你喜欢
MySQL相关介绍
leetcode SVM
使用 PowerShell 将 Windows 转发事件导入 SQL Server
error:Illegal instruction (core dumped),离线下载安装这个other版本numpy
Windows 事件查看器记录到 MYSQL
2年开发经验去面试,吊打面试官,即将面试的程序员这些笔记建议复习
华为、联想、北汽等入选工信部“企业数字化转型和安全能力提升”首批实训基地
MySQL窗口函数
Understand the recommendation system in one article: Outline 02: The link of the recommendation system, from recalling rough sorting, to fine sorting, to rearranging, and finally showing the recommend
deepstresam的插件配置说明,通过配置osd,设置字体的背景为透明
随机推荐
C专家编程 第1章 C:穿越时空的迷雾 1.7 编译限制
Detailed ReentrantLock
如何使用MATLAB绘制极坐标堆叠柱状图
简易网络传输方法
Cookie和Session的关系
MySQL窗口函数 PARTITION BY()函数介绍
Kubernetes 笔记 / 入门 / 生产环境 / 用部署工具安装 Kubernetes / 用 kubeadm 启动集群 / 安装 kubeadm
Windows 事件转发到 SQL 数据库
组件通信--下拉菜单案例
生产环境如何删除表呢?只能在SQL脚本里执行 drop table 吗
通俗理解apt-get 和pip的区别是什么
socket快速理解
C专家编程 第1章 C:穿越时空的迷雾 1.10 “安静的改变”究竟有多少安静
[Unity Getting Started Plan] Basic Concepts (6) - Sprite Renderer Sprite Renderer
视频人脸识别和图片人脸识别的关系
C语言03、数组
node连接mongoose数据库流程
devops-2:Jenkins的使用及Pipeline语法讲解
smp,numa和mpp体系结构总结
C语言04、操作符