当前位置:网站首页>mm中的GAN模型架构
mm中的GAN模型架构
2022-07-02 23:09:00 【Kun Li】
mmgeneration中的GAN的架构包括,输入一张真实图像,如果是条件gan的话,输入的是一张真实图像和相应的标签,
1.首先对判别器进行训练,对判别器中梯度在训练时保存,和对判别器优化器进行zero_grad(),gan中生成器和判别器各有一个优化器,其实是两个相对独立的训练过程。
2.在训练判别器时也先使用生成器生成一个假样本,注意此时的生成器是不训练,仅仅是生成一个假样本,在条件gan时,还存在一个batch_accumulation_step的步骤,它在生成器训练时会多训练几次。
3.用判别器分别对真样本和假样本进行判别,计算判别器训练时的损失,正常的ganloss以及相应的辅助损失,放在disc_auxiliary_loss中,感觉gan中能够扩展的方向并不多,从架构上讲是transformers或者self-attention之类的,在理论上大体就是损失函数,l约束上的一些操作,所以此时的gen_auxiliary_loss是不同的一个点。判别器的loss本质上是个二分类,负样本输入,给的标签是0,正样本输入给的标签是1.
4.计算完损失之后,loss_disc.bachward()反向传播,optimizer['disciminator'].step()梯度更新,到这里判别器就训练完成了。
5.进行5-8轮的判别器训练之后开始转向生成器训练,每次训练前其实都对梯度进行了清零,相当于在每一轮时并不进行梯度累计,首先对判别器的梯度不进行保存,注意一开始训练判别器时,判别器的梯度是保存的,然后对生成器梯度进行zero_grad()。
6.训练生成器,生成器的输入还是噪声,得到生成器的图像之后,进行一次判别器判定,判别器最后一层一般是个Linear(n,1)的层,也就是输出是个N,1维的,输入到gen_loss中,这里gen_loss中核心也是辅助损失gen_auxiliary_loss中,此处也可能添加一些l约束之类的。
7.计算完损失之后,loss_gen.backward()反向传播,optimizer['generator'].step()梯度更新,生成器训练完成。

上图是一个典型的生成器和判别器的结构,在生成器中,我们需要一个将噪声向量转换成为二维特征的模块,也就是noise2feat block。接下来需要连续经过几个上采样块将低分辨率的特征转成高分辨率的特征,在 DCGAN 中,我们使用的是 transposed convolution 来实现。最后,需要一个 to_rgb 块来将特征图的通道数映射为3通道,从而生成图片。那判别器其实就是生成器的一个反转,我们需要通过 img2feat 和大量的下采样块将特征图不断降低分辨率,最后输送给 decision head,来对当前的输入图片进行评判。
边栏推荐
- Go custom sort
- [shutter] Introduction to the official example of shutter Gallery (learning example | email application | retail application | wealth management application | travel application | news application | a
- MySQL advanced learning notes (III)
- Shell脚本基本使用
- Bloom filter
- 95 pages of smart education solutions 2022
- Multi process programming (III): message queue
- 毕业总结
- Array de duplication
- Several methods of the minimum value in the maximum value of group query
猜你喜欢

redis21道经典面试题,极限拉扯面试官

Digital twin visualization solution digital twin visualization 3D platform

67 page overall planning and construction plan for a new smart city (download attached)
![MATLAB signal processing [Q & a notes-1]](/img/53/ae081820fe81ce28e1f04914678a6f.png)
MATLAB signal processing [Q & a notes-1]

秒杀系统设计

哪些软件可以整篇翻译英文论文?

Markdown使用教程

MySQL 23 classic interview hanging interviewer

What website can you find English literature on?

CMake基本使用
随机推荐
The privatization deployment of SaaS services is the most efficient | cloud efficiency engineer points north
S12. Verify multi host SSH mutual access script based on key
NC50965 Largest Rectangle in a Histogram
Multiprocess programming (I): basic concepts
Understanding and application of least square method
大学生课堂作业2000~3000字的小论文,标准格式是什么?
Install docker and use docker to install MySQL
Should you study kubernetes?
国外的论文在那找?
redis21道经典面试题,极限拉扯面试官
Go自定义排序
Introduction of UART, RS232, RS485, I2C and SPI
Realization of mask recognition based on OpenCV
MySQL 23 classic interview hanging interviewer
In February 2022, the ranking list of domestic databases: oceanbase regained its popularity with "three consecutive increases", and gaussdb is expected to achieve the largest increase this month
AcWing_ 188. Warrior cattle_ bfs
MySQL advanced learning notes (III)
关于Unity屏幕相关Screen的练习题目,Unity内部环绕某点做运动
pod生命周期详解
[shutter] shutter open source project reference