当前位置:网站首页>[代码阅读] CycleGAN: Unpaired Image-To-Image Translation Using Cycle-Consistent Adversarial Networks
[代码阅读] CycleGAN: Unpaired Image-To-Image Translation Using Cycle-Consistent Adversarial Networks
2022-08-04 09:51:00 【xiongxyowo】
损失函数
为了简单起见,只考虑Cycle中的半个循环,另一半是对称的。
Adversarial Loss
对抗损失的作用是判断生成的图像是否"像"目标域的图像,本文使用Least Square Loss(平方误差)来进行这一判断过程,公式如下:
L LSGAN ( G , D Y , X , Y ) = E y ∼ p data ( y ) [ ( D Y ( y ) − 1 ) 2 ] + E x ∼ p data ( x ) [ D Y ( G ( x ) ) 2 ] \mathcal{L}_{\text{LSGAN}}(G,\ D_{Y},\ X,\ Y)=\mathbb{E}_{y\sim p_{\text{data}}(y)}[(D_{Y}(y)-1)^{2}]+\mathbb{E}_{x\sim p_{\text{data}}(x)}[D_{Y}(G(x))^{2}] LLSGAN(G, DY, X, Y)=Ey∼pdata(y)[(DY(y)−1)2]+Ex∼pdata(x)[DY(G(x))2] 其中,G代表生成器, D Y D_{Y} DY表示目标域的判别器(用于区分生成的 G ( x ) G(x) G(x)与真实的 y y y), X X X表示源域, Y Y Y表示目标域。代码如下:
- 实现Least Square Loss:
# cycle_gan_model.py 90行
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss.
# networks.py 231行
if gan_mode == 'lsgan':
self.loss = nn.MSELoss()
- 将源域A的图像x(real_A)送入生成器,得到假的目标域B的图像(fake_B):
# cycle_gan_model.py 114行
self.fake_B = self.netG_A(self.real_A) # G_A(A)
- 将真实的目标域B的图像送入判别器,得到一个分数。我们希望判别器能够区分出真实的图像:
# cycle_gan_model.py 131行
pred_real = netD(real)
loss_D_real = self.criterionGAN(pred_real, True)
- 将生成的目标域B的图像送入判别器,得到一个分数,我们希望判别器能够区分虚假的图像:
# cycle_gan_model.py 134行
pred_fake = netD(fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
至此对抗损失的判别器部分结束。生成器部分如下,我们希望生成的图像能够骗过判别器:
# cycle_gan_model.py 169行
self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
Cycle Loss
循环损失的作用在于令源域A的图像在转换至目标域B后,仍然能转换回源域A,这就间接要求A->B的过程不去丢失图像中原有的信息,从而实现图像风格转换。本文使用L1 Loss(绝对误差)来实现这一过程: L cyc ( G , F ) = E x ∼ p data ( x ) [ ∥ F ( G ( x ) ) − x ∥ 1 ] \mathcal{L}_{\text{cyc}}(G,\ F)=\mathbb{E}_{x\sim p_{\text{data}}(x)}[\Vert F(G(x))-x \Vert_{1}] Lcyc(G, F)=Ex∼pdata(x)[∥F(G(x))−x∥1] 代码如下:
- 实现L1 Loss:
# cycle_gan_model.py 91行
self.criterionCycle = torch.nn.L1Loss()
- 将源域A的图像x(real_A)送入生成器,得到假的目标域B的图像(fake_B),再转换回源域A的图像(rec_A):
# cycle_gan_model.py 114行
self.fake_B = self.netG_A(self.real_A) # G_A(A)
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
判断重建的图像与原始图像是否相同:
# cycle_gan_model.py 173行
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A)
边栏推荐
- 学习使用php把stdClass Object转array的方法整理
- 学习在php中分析switch与ifelse的执行效率
- Libtomcrypt AES 加密及解密
- LeetCode简单题之最好的扑克手牌
- [Punctuality Atom STM32 Serial] Chapter 3 Development Environment Construction Excerpted from [Punctual Atom] MiniPro STM32H750 Development Guide_V1.1
- safe-point(safepoint 安全点) 和 safe-region(安全区域)「建议收藏」
- 【正点原子STM32连载】第一章 本书学习方法 摘自【正点原子】MiniPro STM32H750 开发指南_V1.1
- [Punctuality Atomic STM32 Serial] Chapter 1 Learning Method of the Book Excerpted from [Punctuality Atomic] MiniPro STM32H750 Development Guide_V1.1
- 数据万象内容审核 — 共建安全互联网,专项开展“清朗”直播整治行动
- 什么是元宇宙?
猜你喜欢
随机推荐
NAT/NAPT地址转换(内外网通信)技术详解【华为eNSP】
使用ClickHouse分析COS的清单和访问日志
leetcode经典例题——49.字母异位词分组
冰蝎工具开发实现动态二进制加密WebShell
Win11怎么进行左右键对调?
学习在微信小程序中判断url的文件后缀格式
TiDB升级与案例分享(TiDB v4.0.1 → v5.4.1)
leetcode二叉树系列(一)
MindSpore:【model_zoo】【resnet】尝试用THOR优化器运行时报cannot import name ‘THOR‘
学习在php中将特大数字转成带有千/万/亿为单位的字符串
MindSpore:损失函数问题
在测试集上训练,还能中CVPR?这篇IEEE批判论文是否合理?
MindSpore:model.train中的dataset_sink_mode该如何理解?
字符串与正则表达式(C#)
LeetCode中等题之设计循环队列
MindSpore:mirrorpad算子速度过慢的问题
MindSpore:图算融合报错
一文带你了解 ESLint
我和 TiDB 的故事 | 缘份在,那就终是能相遇的
DOM简述