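Testing Four Common PyTorch Optimizers
Posted 2022-07-06 09:16:00 by 想成为风筝
This post compares four commonly used PyTorch optimizers: SGD, SGD with momentum, RMSprop, and Adam. Each one trains its own copy of the same small network to fit a noisy y = x² curve, and the per-step training loss of all four is plotted for comparison.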
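For reference, the update rules behind the four optimizers can be summarized as below. They follow the formulas in the PyTorch `torch.optim` documentation, simplified by assuming no weight decay, zero dampening, and no Nesterov, centered, or AMSGrad variants. Here $g_t$ is the gradient at step $t$ and $\eta$ is the learning rate:

$$
\begin{aligned}
\text{SGD:}\quad & \theta_t = \theta_{t-1} - \eta\, g_t \\
\text{Momentum:}\quad & b_t = \mu\, b_{t-1} + g_t,\qquad \theta_t = \theta_{t-1} - \eta\, b_t \\
\text{RMSprop:}\quad & v_t = \alpha\, v_{t-1} + (1-\alpha)\, g_t^2,\qquad \theta_t = \theta_{t-1} - \eta\, \frac{g_t}{\sqrt{v_t} + \epsilon} \\
\text{Adam:}\quad & m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t,\qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2, \\
& \hat m_t = \frac{m_t}{1-\beta_1^t},\qquad \hat v_t = \frac{v_t}{1-\beta_2^t},\qquad \theta_t = \theta_{t-1} - \eta\, \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}
\end{aligned}
$$

In the training code, $\eta$ = `LR` = 0.001, $\mu$ = `momentum` = 0.9, $\alpha$ = `alpha` = 0.9 (PyTorch's default is 0.99), and $(\beta_1, \beta_2)$ = `betas` = (0.9, 0.99) (PyTorch's default is (0.9, 0.999)). $\epsilon$ is left at PyTorch's default of 1e-8 for both RMSprop and Adam.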
```python
import os
# Workaround for the "OMP: Error #15 ... already initialized" crash that can occur
# when PyTorch and NumPy/matplotlib load different copies of the OpenMP runtime.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import torch
import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Hyperparameters
LR = 0.001
Batch_Size = 32
Epochs = 12

# Generate training data: y = x^2 plus Gaussian noise (std 0.1) at 1000 points in [-1, 1]
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)   # shape (1000, 1)
y = x.pow(2) + 0.1 * torch.normal(torch.zeros(*x.size()))
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(dataset=torch_dataset, batch_size=Batch_Size, shuffle=True)


class Net2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = torch.nn.Linear(1, 20)   # one hidden layer with 20 units
        self.predict = torch.nn.Linear(20, 1)  # output layer

    # Forward pass
    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x


# One independently initialized network per optimizer
net_SGD = Net2()
net_Momentum = Net2()
net_RMSprop = Net2()
net_Adam = Net2()
nets = [net_SGD, net_Momentum, net_RMSprop, net_Adam]

opt_SGD = torch.optim.SGD(net_SGD.parameters(), lr=LR)
opt_Momentum = torch.optim.SGD(net_Momentum.parameters(), lr=LR, momentum=0.9)
opt_RMSprop = torch.optim.RMSprop(net_RMSprop.parameters(), lr=LR, alpha=0.9)
opt_Adam = torch.optim.Adam(net_Adam.parameters(), lr=LR, betas=(0.9, 0.99))
optimizers = [opt_SGD, opt_Momentum, opt_RMSprop, opt_Adam]

loss_func = torch.nn.MSELoss()
loss_his = [[], [], [], []]   # per-step loss history, one list per optimizer

for epoch in range(Epochs):
    for step, (batch_x, batch_y) in enumerate(loader):
        # All four networks are trained on the same mini-batch at each step
        for net, opt, l_his in zip(nets, optimizers, loss_his):
            output = net(batch_x)
            loss = loss_func(output, batch_y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            l_his.append(loss.item())   # record the loss

labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
for i, l_his in enumerate(loss_his):
    plt.plot(l_his, label=labels[i])
plt.legend(loc='best')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.ylim((0, 0.2))
plt.show()
```
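To see what the four optimizers do with the hyperparameters used above (`momentum=0.9`, `alpha=0.9`, `betas=(0.9, 0.99)`, learning rate 0.001), here is a minimal pure-Python sketch that re-implements their update rules on a hypothetical one-dimensional objective f(w) = (w - 3)². It is an illustration under stated assumptions, not `torch.optim` itself: weight decay, dampening, and the Nesterov/centered/AMSGrad options are left out, and the objective, the starting point w = 0, and the step count are chosen only for illustration.

```python
import math

# Hypothetical toy objective (not from the original post): f(w) = (w - 3)^2,
# whose gradient is 2 * (w - 3) and whose minimum is at w = 3.
def grad(w):
    return 2.0 * (w - 3.0)


def run_sgd(w, lr, steps):
    # Plain SGD: w <- w - lr * g
    for _ in range(steps):
        w -= lr * grad(w)
    return w


def run_momentum(w, lr, steps, momentum=0.9):
    # PyTorch-style momentum with dampening=0: buf <- momentum * buf + g; w <- w - lr * buf
    buf = 0.0
    for _ in range(steps):
        buf = momentum * buf + grad(w)
        w -= lr * buf
    return w


def run_rmsprop(w, lr, steps, alpha=0.9, eps=1e-8):
    # A running average of squared gradients rescales every step
    sq_avg = 0.0
    for _ in range(steps):
        g = grad(w)
        sq_avg = alpha * sq_avg + (1 - alpha) * g * g
        w -= lr * g / (math.sqrt(sq_avg) + eps)
    return w


def run_adam(w, lr, steps, betas=(0.9, 0.99), eps=1e-8):
    # First and second moment estimates with bias correction
    beta1, beta2 = betas
    m = v = 0.0
    for t in range(1, steps + 1):
        g = grad(w)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return w


if __name__ == "__main__":
    LR = 0.001      # same learning rate as the post
    STEPS = 20000   # assumption: enough iterations for all four to settle on this toy problem
    for name, run in [("SGD", run_sgd), ("Momentum", run_momentum),
                      ("RMSprop", run_rmsprop), ("Adam", run_adam)]:
        print(f"{name:8s} w after {STEPS} steps: {run(0.0, LR, STEPS):.6f}")
```

The first step already shows the design difference: SGD moves by `lr * |g|` (0.006 here, since the initial gradient is -6), while RMSprop and Adam divide by a running estimate of the gradient magnitude, so their step size stays on the order of `lr` regardless of gradient scale. With a constant learning rate this also means RMSprop and Adam tend to hover in a small neighborhood of the minimum instead of converging exactly.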
