GAN model architecture in mm
2022-07-02 23:09:00 【Kun Li】
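The GAN pipeline in mmgeneration takes a real image as input; for a conditional GAN, the input is a real image together with its label. One training iteration then proceeds as follows: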
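1. First, train the discriminator. Gradient tracking is enabled for the discriminator's parameters, and zero_grad() is called on the discriminator's optimizer. In a GAN the generator and the discriminator each have their own optimizer, so these are really two relatively independent training processes.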
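2. While training the discriminator, the generator is first used to produce a fake sample. Note that the generator is not being trained at this point; it only produces the fake sample. For conditional GANs there is an additional batch_accumulation_step setting, which makes the generator phase run for several extra passes.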
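3. The discriminator scores the real sample and the fake sample separately, and the discriminator loss is computed: the regular GAN loss plus any auxiliary losses, which are kept in disc_auxiliary_loss. In my view there are not many directions in which a GAN can be extended. On the architecture side it is things like transformers or self-attention; on the theory side it is mostly the loss function and L-constraints (e.g. Lipschitz constraints). The auxiliary losses (disc_auxiliary_loss here, gen_auxiliary_loss for the generator) are therefore one of the main places where models differ. The discriminator loss is essentially a binary classification loss: a negative (fake) input gets label 0 and a positive (real) input gets label 1.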
4. Once the loss has been computed, loss_disc.backward() back-propagates it and optimizer['discriminator'].step() applies the gradient update. That completes one discriminator update.
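5. After 5 to 8 discriminator iterations, training switches to the generator. Gradients are zeroed before every update, so no gradient accumulation takes place across iterations. For the generator phase, gradient tracking is first turned off for the discriminator (recall that it was turned on when the discriminator was being trained), and then zero_grad() is called for the generator.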
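6. Train the generator. Its input is still noise. The generated image is passed through the discriminator once. The discriminator's last layer is usually a Linear(n, 1) layer, so its output has shape (N, 1). That output is fed into gen_loss. Here too, the core of gen_loss is the auxiliary loss gen_auxiliary_loss, and L-constraint terms may be added at this point as well.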
7. Once the loss has been computed, loss_gen.backward() back-propagates it and optimizer['generator'].step() applies the gradient update. That completes the generator update.
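The seven steps above can be sketched in plain PyTorch. This is a minimal illustration, not mmgeneration's actual implementation: the `train_step` and `set_requires_grad` helpers, the toy fully connected models, the `disc_steps` parameter and the hyperparameters are all assumptions made for the example, and auxiliary losses are omitted. A real loop would also draw a fresh batch of real images for each discriminator step.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def set_requires_grad(module: nn.Module, flag: bool) -> None:
    """Turn gradient tracking on or off for every parameter of a module."""
    for p in module.parameters():
        p.requires_grad_(flag)


def train_step(generator, discriminator, optimizer, real_imgs, noise_dim, disc_steps=5):
    """One GAN iteration: several discriminator updates, then one generator update."""
    batch = real_imgs.size(0)

    # Steps 1-4: discriminator phase.
    set_requires_grad(discriminator, True)
    for _ in range(disc_steps):
        optimizer['discriminator'].zero_grad()
        # The generator only produces a fake sample here; it is not trained.
        with torch.no_grad():
            fake_imgs = generator(torch.randn(batch, noise_dim))
        pred_real = discriminator(real_imgs)
        pred_fake = discriminator(fake_imgs)
        # Binary classification: real -> label 1, fake -> label 0.
        loss_disc = (
            F.binary_cross_entropy_with_logits(pred_real, torch.ones_like(pred_real))
            + F.binary_cross_entropy_with_logits(pred_fake, torch.zeros_like(pred_fake))
        )
        loss_disc.backward()
        optimizer['discriminator'].step()

    # Steps 5-7: generator phase. The discriminator's parameters are frozen.
    set_requires_grad(discriminator, False)
    optimizer['generator'].zero_grad()
    fake_imgs = generator(torch.randn(batch, noise_dim))
    pred_fake = discriminator(fake_imgs)  # shape (N, 1)
    # The generator wants its fakes to be classified as real (label 1).
    loss_gen = F.binary_cross_entropy_with_logits(pred_fake, torch.ones_like(pred_fake))
    loss_gen.backward()
    optimizer['generator'].step()

    return loss_disc.item(), loss_gen.item()


# Toy setup (hypothetical sizes): flat 16-dimensional "images", 8-dimensional noise.
torch.manual_seed(0)
noise_dim, img_dim = 8, 16
generator = nn.Sequential(nn.Linear(noise_dim, 32), nn.ReLU(), nn.Linear(32, img_dim))
discriminator = nn.Sequential(nn.Linear(img_dim, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
optimizer = {
    'generator': torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999)),
    'discriminator': torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999)),
}
real_imgs = torch.randn(4, img_dim)
loss_disc, loss_gen = train_step(generator, discriminator, optimizer, real_imgs, noise_dim)
print(f"loss_disc={loss_disc:.4f}  loss_gen={loss_gen:.4f}")
```

Freezing the discriminator during the generator phase does not block back-propagation: gradients still flow through the discriminator's operations to the generator's parameters, but no gradients are stored for the discriminator itself.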

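A typical generator and discriminator are structured as follows. The generator needs a module that turns the noise vector into a two-dimensional feature map, the noise2feat block. The features then pass through several consecutive upsampling blocks that convert low-resolution features into high-resolution ones; in DCGAN this is done with transposed convolutions. Finally, a to_rgb block maps the feature map down to 3 channels to produce the image. The discriminator is essentially the generator in reverse: an img2feat block followed by a stack of downsampling blocks progressively reduces the resolution of the feature map, and the result is passed to a decision head, which judges the input image.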
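As a rough sketch of that layout, here is a small DCGAN-style pair in PyTorch for 32x32 images. The class names, channel widths, kernel sizes and the 32x32 resolution are assumptions chosen for illustration; the attribute names follow the block names used above, but this is not mmgeneration's own module code.

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    """noise -> noise2feat -> upsampling blocks -> to_rgb (illustrative sizes)."""

    def __init__(self, noise_dim=100, ch=64):
        super().__init__()
        # noise2feat: (N, noise_dim, 1, 1) -> (N, 4*ch, 4, 4)
        self.noise2feat = nn.Sequential(
            nn.ConvTranspose2d(noise_dim, ch * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ch * 4), nn.ReLU(inplace=True))
        # Upsampling with transposed convolutions: 4x4 -> 8x8 -> 16x16
        self.upsampling = nn.Sequential(
            nn.ConvTranspose2d(ch * 4, ch * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ch * 2), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(ch * 2, ch, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ch), nn.ReLU(inplace=True))
        # to_rgb: map features to 3 channels, 16x16 -> 32x32
        self.to_rgb = nn.Sequential(
            nn.ConvTranspose2d(ch, 3, 4, 2, 1), nn.Tanh())

    def forward(self, noise):
        x = noise.view(noise.size(0), -1, 1, 1)
        return self.to_rgb(self.upsampling(self.noise2feat(x)))


class Discriminator(nn.Module):
    """image -> img2feat -> downsampling blocks -> decision head (illustrative sizes)."""

    def __init__(self, ch=64):
        super().__init__()
        # img2feat: (N, 3, 32, 32) -> (N, ch, 16, 16)
        self.img2feat = nn.Sequential(
            nn.Conv2d(3, ch, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True))
        # Downsampling: 16x16 -> 8x8 -> 4x4
        self.downsampling = nn.Sequential(
            nn.Conv2d(ch, ch * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ch * 2), nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ch * 2, ch * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ch * 4), nn.LeakyReLU(0.2, inplace=True))
        # Decision head: flatten, then Linear(n, 1) so the output is (N, 1)
        self.decision = nn.Sequential(
            nn.Flatten(), nn.Linear(ch * 4 * 4 * 4, 1))

    def forward(self, img):
        return self.decision(self.downsampling(self.img2feat(img)))


gen, disc = Generator(), Discriminator()
fake = gen(torch.randn(2, 100))
score = disc(fake)
print(fake.shape, score.shape)
```

The discriminator mirrors the generator stage by stage, which makes the "generator in reverse" description concrete: each transposed convolution that doubles the resolution has a strided convolution that halves it.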