当前位置:网站首页>Pytorch|GAN在手写数字集上的复现
Pytorch|GAN在手写数字集上的复现
2022-08-01 16:56:00 【Rolandxxx】
在复现开始之前需要知道几个tricky但很重要的知识点:
1.在pytorch中,神经网络层中的权值weight和偏差bias的tensor均为叶子节点,自己定义的tensor例如a=torch.tensor([1.0])定义的节点是叶子节点,中间计算产生的变量都叫非叶子节点。默认情况下,只有叶子节点的梯度值能够被保留下来,非叶子节点的梯度值在反向传播过程中使用完后就会被清除,不会被保留,除非使用 retain_grad() 方法。backward函数是计算当前tensor对计算图的叶子节点的梯度。backward函数的计算方式中,梯度是累积计算而不是被替换,所以不清0的话梯度就会累加上去。
2.fake.detach()返回的是一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量fake的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。进一步理解就是fake.detach()这个tensor变成了当前计算图的叶子节点,计算图其实就是代表程序中变量之间的关系,计算图会在backward()函数执行后被清理掉,由于叶子节点变成了fake.detach(),那么fake前的变量的计算关系是没有被清理掉的。
3.元组中只有一个数据要加逗号,
tup1 = (23) # 不是元组
print(type(tup1)) #<class 'int'>
tup2 = (23,) # 是元组
print(type(tup2))#<class 'tuple'>
现在进入正题,开始复现
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter # to print to tensorboard
class Discriminator(nn.Module):
def __init__(self,img_dim):
super().__init__()
self.disc = nn.Sequential(
nn.Linear(img_dim,128),
nn.LeakyReLU(0.1),
nn.Linear(128,1),
nn.Sigmoid(),
)
def forward(self,x):
return self.disc(x)
class Generator(nn.Module):
def __init__(self,z_dim,img_dim):
super().__init__()
self.gen = nn.Sequential(nn.Linear(z_dim,256),
nn.LeakyReLU(0.1),
nn.Linear(256,img_dim),#28*28*1 -> 784
nn.Tanh(),
)
def forward(self,x):
return self.gen(x)
device ="cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
img_dim = 28*28*1
batch_size = 32
num_epochs =50
disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim,img_dim).to(device)
fixed_noise = torch.randn((batch_size,z_dim)).to(device)
transforms = transforms.Compose(
#transforms.Normalize是channel的对图像进行标准化,当数据维数为1时,数据后面要有逗号
#因为传进去的数据类型是元组,所以要加逗号,元组中只有一个数据要加逗号,这传list也行
[transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))]#image=(image-mean)/std,img.shape(28,28)
)
dataset = datasets.MNIST(root="dataset/",transform=transforms,download=True)
loader = DataLoader(dataset,batch_size=batch_size,shuffle=True)
opt_disc = optim.Adam(disc.parameters(),lr=lr)
opt_gen = optim.Adam(gen.parameters(),lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")#for tensorboard
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")#for tensorboard
step = 0#for tensorboard
for epoch in range(num_epochs):
for batch_idx, (real,_) in enumerate(loader):
real = real.view(-1,784).to(device)
batch_size =real.shape[0]
####Traning Discriminator: 最大化 log(D(real))+log(1-D(G(z)))
noise = torch.randn(batch_size,z_dim).to(device)
fake = gen(noise)
disc_real = disc(real).view(-1) #shape: torch.Size([32])
lossD_real = criterion(disc_real,torch.ones_like(disc_real))#最大化log(D(real))这一项就等于最小化这个的bce损失
#用.detach()或者77行的retain_graph=True都可以,二选一
# disc_fake = disc(fake.detach()).view(-1) #截断fake节点前的梯度传播,所以fake.detach这个tensor此时就成了叶子节点
disc_fake = disc(fake).view(-1)
lossD_fake = criterion(disc_fake,torch.zeros_like(disc_fake))
lossD = (lossD_fake+lossD_real)/2
disc.zero_grad() #只对判别器模型的梯度清0,不清0的话梯度会叠加,这里生成器参数的梯度就没有清掉
#lossD.backward()
#backward函数是计算当前tensor对图叶子结点的梯度
#计算图在backward一次之后各个节点的值会清除,但因为我们下面还要backward一次,所以需要retain_graph=True保存这个图。
#因为进行了backward后,叶子节点的梯度值是保存了,但计算图被释放了
lossD.backward(retain_graph=True)
opt_disc.step()#只更新判别器的参数
####Traning Generator: 最小化 log(1-D(G(z))) ->最大化 log(D(G(Z))),因为这样梯度比较大
output = disc(fake).view(-1)
lossG = criterion(output,torch.ones_like(output))
gen.zero_grad() #对生成器模型参数的梯度清0,如果不清0,下次backward计算就会形成累加
lossG.backward()#由于保留了计算图,这样就可以求叶子结点生成器的参数:gen.grad
opt_gen.step()#只更新生成器的参数
if batch_idx == 0:
print(
f"Epoch [{
epoch}/{
num_epochs}] Batch {
batch_idx}/{
len(loader)} \
Loss D: {
lossD:.4f}, loss G: {
lossG:.4f}"
)
with torch.no_grad():
fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
data = real.reshape(-1, 1, 28, 28)
img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
img_grid_real = torchvision.utils.make_grid(data, normalize=True)
writer_fake.add_image(
"Mnist Fake Images", img_grid_fake, global_step=step
)
writer_real.add_image(
"Mnist Real Images", img_grid_real, global_step=step
)
step += 1
边栏推荐
猜你喜欢
DOM series of touch screen events
C# LibUsbDotNet 在USB-CDC设备的上位机应用
第一次改开源中间件keycloak总个结
[Dark Horse Morning Post] Hu Jun's endorsement of Wukong's financial management is suspected of fraud, which is suspected to involve 39 billion yuan; Fuling mustard responded that mustard ate toenails
zabbix部署和简单使用
棕榈油罐区数字化转型
今晚直播!
LeetCode第 303 场周赛
22年镜头“卷”史,智能手机之战卷进死胡同
好家伙,公司服务器直接热崩掉了!
随机推荐
变量交换;复合赋值;增递减运算符
经验|如何做好业务测试?
【R语言】对图片进行裁剪 图片批量裁剪
我的新书销量1万册了!
百度网盘下载速度提升100倍
Winform的消息提示框帮助类
第一次改开源中间件keycloak总个结
【Unity,C#】哨兵点位循迹模板代码
ROS2系列知识(5):【参数】如何管理?
70后夫妻给苹果华为做“雨衣”,三年进账7.91亿
SQL函数 TIMESTAMPDIFF
直播app开发,是优化直播体验不得不关注的两大指标
显示为弹出窗口是什么意思(电脑总是弹出广告)
Go unit tests
今晚直播!
11 Publish a series as soon as it is released
云商店携手快报税,解锁财务服务新体验!
【建议收藏】技术面必考题:多线程、多进程
搭建云计算平台(云计算管理平台搭建)
Ali's official Redis development specification