当前位置:网站首页>Pytorch|GAN在手写数字集上的复现
Pytorch|GAN在手写数字集上的复现
2022-08-01 16:56:00 【Rolandxxx】
在复现开始之前需要知道几个tricky但很重要的知识点:
1.在pytorch中,神经网络层中的权值weight和偏差bias的tensor均为叶子节点,自己定义的tensor例如a=torch.tensor([1.0])定义的节点是叶子节点,中间计算产生的变量都叫非叶子节点。默认情况下,只有叶子节点的梯度值能够被保留下来,非叶子节点的梯度值在反向传播过程中使用完后就会被清除,不会被保留,除非使用 retain_grad() 方法。backward函数是计算当前tensor对计算图的叶子节点的梯度。backward函数的计算方式中,梯度是累积计算而不是被替换,所以不清0的话梯度就会累加上去。
2.fake.detach()返回的是一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量fake的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。进一步理解就是fake.detach()这个tensor变成了当前计算图的叶子节点,计算图其实就是代表程序中变量之间的关系,计算图会在backward()函数执行后被清理掉,由于叶子节点变成了fake.detach(),那么fake前的变量的计算关系是没有被清理掉的。
3.元组中只有一个数据要加逗号,
tup1 = (23) # 不是元组
print(type(tup1)) #<class 'int'>
tup2 = (23,) # 是元组
print(type(tup2))#<class 'tuple'>
现在进入正题,开始复现
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter # to print to tensorboard
class Discriminator(nn.Module):
def __init__(self,img_dim):
super().__init__()
self.disc = nn.Sequential(
nn.Linear(img_dim,128),
nn.LeakyReLU(0.1),
nn.Linear(128,1),
nn.Sigmoid(),
)
def forward(self,x):
return self.disc(x)
class Generator(nn.Module):
def __init__(self,z_dim,img_dim):
super().__init__()
self.gen = nn.Sequential(nn.Linear(z_dim,256),
nn.LeakyReLU(0.1),
nn.Linear(256,img_dim),#28*28*1 -> 784
nn.Tanh(),
)
def forward(self,x):
return self.gen(x)
device ="cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
img_dim = 28*28*1
batch_size = 32
num_epochs =50
disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim,img_dim).to(device)
fixed_noise = torch.randn((batch_size,z_dim)).to(device)
transforms = transforms.Compose(
#transforms.Normalize是channel的对图像进行标准化,当数据维数为1时,数据后面要有逗号
#因为传进去的数据类型是元组,所以要加逗号,元组中只有一个数据要加逗号,这传list也行
[transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))]#image=(image-mean)/std,img.shape(28,28)
)
dataset = datasets.MNIST(root="dataset/",transform=transforms,download=True)
loader = DataLoader(dataset,batch_size=batch_size,shuffle=True)
opt_disc = optim.Adam(disc.parameters(),lr=lr)
opt_gen = optim.Adam(gen.parameters(),lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")#for tensorboard
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")#for tensorboard
step = 0#for tensorboard
for epoch in range(num_epochs):
for batch_idx, (real,_) in enumerate(loader):
real = real.view(-1,784).to(device)
batch_size =real.shape[0]
####Traning Discriminator: 最大化 log(D(real))+log(1-D(G(z)))
noise = torch.randn(batch_size,z_dim).to(device)
fake = gen(noise)
disc_real = disc(real).view(-1) #shape: torch.Size([32])
lossD_real = criterion(disc_real,torch.ones_like(disc_real))#最大化log(D(real))这一项就等于最小化这个的bce损失
#用.detach()或者77行的retain_graph=True都可以,二选一
# disc_fake = disc(fake.detach()).view(-1) #截断fake节点前的梯度传播,所以fake.detach这个tensor此时就成了叶子节点
disc_fake = disc(fake).view(-1)
lossD_fake = criterion(disc_fake,torch.zeros_like(disc_fake))
lossD = (lossD_fake+lossD_real)/2
disc.zero_grad() #只对判别器模型的梯度清0,不清0的话梯度会叠加,这里生成器参数的梯度就没有清掉
#lossD.backward()
#backward函数是计算当前tensor对图叶子结点的梯度
#计算图在backward一次之后各个节点的值会清除,但因为我们下面还要backward一次,所以需要retain_graph=True保存这个图。
#因为进行了backward后,叶子节点的梯度值是保存了,但计算图被释放了
lossD.backward(retain_graph=True)
opt_disc.step()#只更新判别器的参数
####Traning Generator: 最小化 log(1-D(G(z))) ->最大化 log(D(G(Z))),因为这样梯度比较大
output = disc(fake).view(-1)
lossG = criterion(output,torch.ones_like(output))
gen.zero_grad() #对生成器模型参数的梯度清0,如果不清0,下次backward计算就会形成累加
lossG.backward()#由于保留了计算图,这样就可以求叶子结点生成器的参数:gen.grad
opt_gen.step()#只更新生成器的参数
if batch_idx == 0:
print(
f"Epoch [{
epoch}/{
num_epochs}] Batch {
batch_idx}/{
len(loader)} \
Loss D: {
lossD:.4f}, loss G: {
lossG:.4f}"
)
with torch.no_grad():
fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
data = real.reshape(-1, 1, 28, 28)
img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
img_grid_real = torchvision.utils.make_grid(data, normalize=True)
writer_fake.add_image(
"Mnist Fake Images", img_grid_fake, global_step=step
)
writer_real.add_image(
"Mnist Real Images", img_grid_real, global_step=step
)
step += 1
边栏推荐
- DateTime Helper Class for C#
- 参观首钢园
- DOM series of touch screen events
- 金仓数据库 MySQL 至 KingbaseES 迁移最佳实践(2. 概述)
- 90后的焦虑,被菜市场治好了
- MySQL最大建议行数2000w, 靠谱吗?
- 工业制造行业的低代码开发平台思维架构图
- Complete knapsack problem to find the number of combinations and permutations
- 金仓数据库KingbaseES安全指南--6.3. Kerberos身份验证
- 06 redis cluster structures
猜你喜欢

MySQL locking case analysis

The anxiety of the post-90s was cured by the vegetable market

我的新书销量1万册了!

2022 Strong Net Cup CTF---Strong Net Pioneer ASR wp

How to Efficiently Develop Jmix Extension Components

谁还敢买影视股?

Rancher 部署 DataKit 最佳实践

好家伙,公司服务器直接热崩掉了!

Unity ui点击事件只响应最上层ui的方式

04 flink cluster construction
随机推荐
第一次改开源中间件keycloak总个结
短剧正在抢长剧的生意
每日优鲜大败局
显示为弹出窗口是什么意思(电脑总是弹出广告)
素域和扩域
网上开户佣金万一靠谱吗,网上开户安全吗
金仓数据库KingbaseES安全指南--6.5. LDAP身份验证
谁还敢买影视股?
Ali's official Redis development specification
沈腾拯救暑期档
【硬核拆解】50块2个的2022年夏季款智能节电器到底能不能省电?
08 spark 集群搭建
在码云拉取代码后,调整了seata版本1.5.2。出现如下异常。是因为数据库表缺少字段导致的吗?
银行案例|Zabbix跨版本升级指南,4.2-6.0不香吗?
缓存一致性MESI与内存屏障
Winform的UI帮助类——部分组件会使用到DevExpress组件
工业制造行业的低代码开发平台思维架构图
云商店携手快报税,解锁财务服务新体验!
08 Spark cluster construction
搭建云计算平台(云计算管理平台搭建)