当前位置:网站首页>[代码阅读] CycleGAN: Unpaired Image-To-Image Translation Using Cycle-Consistent Adversarial Networks
[代码阅读] CycleGAN: Unpaired Image-To-Image Translation Using Cycle-Consistent Adversarial Networks
2022-08-04 09:51:00 【xiongxyowo】
损失函数
为了简单起见,只考虑Cycle中的半个循环,另一半是对称的。
Adversarial Loss
对抗损失的作用是判断生成的图像是否"像"目标域的图像,本文使用Least Square Loss(平方误差)来进行这一判断过程,公式如下:
L LSGAN ( G , D Y , X , Y ) = E y ∼ p data ( y ) [ ( D Y ( y ) − 1 ) 2 ] + E x ∼ p data ( x ) [ D Y ( G ( x ) ) 2 ] \mathcal{L}_{\text{LSGAN}}(G,\ D_{Y},\ X,\ Y)=\mathbb{E}_{y\sim p_{\text{data}}(y)}[(D_{Y}(y)-1)^{2}]+\mathbb{E}_{x\sim p_{\text{data}}(x)}[D_{Y}(G(x))^{2}] LLSGAN(G, DY, X, Y)=Ey∼pdata(y)[(DY(y)−1)2]+Ex∼pdata(x)[DY(G(x))2] 其中,G代表生成器, D Y D_{Y} DY表示目标域的判别器(用于区分生成的 G ( x ) G(x) G(x)与真实的 y y y), X X X表示源域, Y Y Y表示目标域。代码如下:
- 实现Least Square Loss:
# cycle_gan_model.py 90行
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss.
# networks.py 231行
if gan_mode == 'lsgan':
self.loss = nn.MSELoss()
- 将源域A的图像x(real_A)送入生成器,得到假的目标域B的图像(fake_B):
# cycle_gan_model.py 114行
self.fake_B = self.netG_A(self.real_A) # G_A(A)
- 将真实的目标域B的图像送入判别器,得到一个分数。我们希望判别器能够区分出真实的图像:
# cycle_gan_model.py 131行
pred_real = netD(real)
loss_D_real = self.criterionGAN(pred_real, True)
- 将生成的目标域B的图像送入判别器,得到一个分数,我们希望判别器能够区分虚假的图像:
# cycle_gan_model.py 134行
pred_fake = netD(fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
至此对抗损失的判别器部分结束。生成器部分如下,我们希望生成的图像能够骗过判别器:
# cycle_gan_model.py 169行
self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
Cycle Loss
循环损失的作用在于令源域A的图像在转换至目标域B后,仍然能转换回源域A,这就间接要求A->B的过程不去丢失图像中原有的信息,从而实现图像风格转换。本文使用L1 Loss(绝对误差)来实现这一过程: L cyc ( G , F ) = E x ∼ p data ( x ) [ ∥ F ( G ( x ) ) − x ∥ 1 ] \mathcal{L}_{\text{cyc}}(G,\ F)=\mathbb{E}_{x\sim p_{\text{data}}(x)}[\Vert F(G(x))-x \Vert_{1}] Lcyc(G, F)=Ex∼pdata(x)[∥F(G(x))−x∥1] 代码如下:
- 实现L1 Loss:
# cycle_gan_model.py 91行
self.criterionCycle = torch.nn.L1Loss()
- 将源域A的图像x(real_A)送入生成器,得到假的目标域B的图像(fake_B),再转换回源域A的图像(rec_A):
# cycle_gan_model.py 114行
self.fake_B = self.netG_A(self.real_A) # G_A(A)
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
判断重建的图像与原始图像是否相同:
# cycle_gan_model.py 173行
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A)
边栏推荐
- 一文带你了解 ESLint
- 【正点原子STM32连载】第四章 STM32初体验 摘自【正点原子】MiniPro STM32H750 开发指南_V1.1
- NAT/NAPT地址转换(内外网通信)技术详解【华为eNSP】
- [Punctuality Atom STM32 Serial] Chapter 2 STM32 Introduction Excerpted from [Punctual Atom] MiniPro STM32H750 Development Guide_V1.1
- Since his 97, I roll but he...
- I am 37 this year, and I was rushed by a big factory to...
- HTB-Sense
- 【COS 加码福利】COS 用户实践有奖征文,等你来投稿!
- 参数优化。
- 多媒体和物联网技术让版本“活”起来 129张黑胶唱片“百年留声”
猜你喜欢
cannot import name 'import_string' from 'werkzeug' [bug solution]
redis解决分布式session问题
v-model原理,在“radio”、“checkbox”、“select”、修饰符
leetcode动态规划系列(求路径篇)
LVS+Keepalived群集部署
Detailed Explanation of Addresses Delivered by DHCP on Routing/Layer 3 Switches [Huawei eNSP]
leetcode经典例题——49.字母异位词分组
请你谈谈网站是如何进行访问的?【web领域面试题】
IDEA 自动导入的配置(Auto import)
leetcode每天5题-Day06
随机推荐
关于DSP驱动外挂flash
命里有时终须有--记与TiDB的一次次擦肩而过
张朝阳对话俞敏洪:谈宇宙、谈焦虑、谈创业、谈退休、谈人生
HTB-Nibbles
Four common methods of network attacks and their protection
LVS-DR集群部署
After four years of outsourcing, the autumn recruits finally landed
路由/三层交换机DHCP下发地址详解【华为eNSP】
Win11如何隐藏输入法悬浮窗?
请问下Flink SQL如何写hologres分区表?我想要每天一个分区
matlab练习程序(多线段交点)
双指针方法
参数优化文档介绍
常用的输入对象
IDEA 自动导入的配置(Auto import)
MindSpore:model.train中的dataset_sink_mode该如何理解?
How to restore the Youxuan database with only data files
I am 37 this year, and I was rushed by a big factory to...
Layer 3 Switch/Router OSPF Configuration Details [Huawei eNSP Experiment]
cannot import name ‘import_string‘ from ‘werkzeug‘【bug解决】