Testing Four Common PyTorch Optimizers
2022-07-06 09:16:00 【想成为风筝】
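A test of four commonly used PyTorch optimizers: SGD, SGD (Momentum), RMSprop, and Adam. The script below trains four identical networks on the same noisy quadratic data, one per optimizer, and plots their training losses.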
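To make the difference between the four optimizers concrete, here is a minimal sketch of their update rules in plain Python, applied to the toy objective f(w) = w**2. It follows the update formulas given in the `torch.optim` documentation, but it is a simplified illustration (single scalar parameter, no weight decay, dampening, or other options), not the library's implementation. The function names and the toy objective are assumptions made for this example.

```python
import math


def grad(w):
    # Gradient of the toy objective f(w) = w**2 (an assumption for this sketch).
    return 2.0 * w


def sgd_step(w, lr):
    # Plain SGD: move against the gradient.
    return w - lr * grad(w)


def momentum_step(w, buf, lr, momentum=0.9):
    # SGD with momentum: accumulate gradients in a velocity buffer.
    buf = momentum * buf + grad(w)
    return w - lr * buf, buf


def rmsprop_step(w, sq_avg, lr, alpha=0.9, eps=1e-8):
    # RMSprop: scale the step by a running average of squared gradients.
    g = grad(w)
    sq_avg = alpha * sq_avg + (1 - alpha) * g * g
    return w - lr * g / (math.sqrt(sq_avg) + eps), sq_avg


def adam_step(w, m, v, t, lr, betas=(0.9, 0.99), eps=1e-8):
    # Adam: bias-corrected running averages of the gradient and its square.
    # t is the 1-based step counter used for bias correction.
    g = grad(w)
    b1, b2 = betas
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps), m, v


if __name__ == "__main__":
    lr = 0.1
    w_sgd = w_mom = w_rms = w_adam = 1.0
    buf = sq_avg = m = v = 0.0
    for t in range(1, 51):
        w_sgd = sgd_step(w_sgd, lr)
        w_mom, buf = momentum_step(w_mom, buf, lr)
        w_rms, sq_avg = rmsprop_step(w_rms, sq_avg, lr)
        w_adam, m, v = adam_step(w_adam, m, v, t, lr)
    print("SGD:", w_sgd, "Momentum:", w_mom, "RMSprop:", w_rms, "Adam:", w_adam)
```

The defaults `alpha=0.9` and `betas=(0.9, 0.99)` mirror the values used in the script that follows; the learning rate of 0.1 is chosen only so the toy example moves visibly in a few steps.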
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import torch
import torch.utils.data as Data
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Hyperparameters
LR = 0.001
Batch_Size = 32
Epochs = 12
# Generate training data: y = x^2 plus Gaussian noise scaled by 0.1
x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1 * torch.normal(torch.zeros(*x.size()))
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(dataset=torch_dataset, batch_size=Batch_Size, shuffle=True)
# A small network with one hidden layer of 20 units
class Net2(torch.nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.hidden = torch.nn.Linear(1, 20)
        self.predict = torch.nn.Linear(20, 1)

    # Forward pass
    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x
net_SGD = Net2()
net_Momentum =Net2()
net_RMSprop = Net2()
net_Adam = Net2()
nets = [net_SGD,net_Momentum,net_RMSprop,net_Adam]
opt_SGD = torch.optim.SGD(net_SGD.parameters(),lr=LR)
opt_Momentum = torch.optim.SGD(net_Momentum.parameters(),lr=LR,momentum=0.9)
opt_RMSProp = torch.optim.RMSprop(net_RMSprop.parameters(),lr=LR,alpha=0.9)
opt_Adam = torch.optim.Adam(net_Adam.parameters(),lr=LR,betas=(0.9,0.99))
optimizers = [opt_SGD,opt_Momentum,opt_RMSProp,opt_Adam]
loss_func = torch.nn.MSELoss()
loss_his = [[],[],[],[]]
for epoch in range(Epochs):
    for step, (batch_x, batch_y) in enumerate(loader):
        # Train each network on the same mini-batch with its own optimizer
        for net, opt, l_his in zip(nets, optimizers, loss_his):
            output = net(batch_x)
            loss = loss_func(output, batch_y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            l_his.append(loss.item())  # loss recorder
labels = ['SGD', 'Momentum', 'RMSprop', 'Adam']
for i, l_his in enumerate(loss_his):
    plt.plot(l_his, label=labels[i])
plt.legend(loc='best')
plt.xlabel('Steps')
plt.ylabel('Loss')
plt.ylim((0, 0.2))
plt.show()
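Because every value in `loss_his` comes from a single mini-batch, the raw curves are noisy and can be hard to compare. One option is to smooth each curve with a simple moving average before plotting. The helper below is a hypothetical addition, not part of the original script; the name `moving_average` and the default window of 20 steps are assumptions.

```python
def moving_average(values, window=20):
    """Return the simple moving average of `values` over `window` points."""
    if window < 1:
        raise ValueError("window must be at least 1")
    values = [float(v) for v in values]
    if len(values) < window:
        return []
    # Keep a running sum so each output point costs O(1).
    running = sum(values[:window])
    averaged = [running / window]
    for i in range(window, len(values)):
        running += values[i] - values[i - window]
        averaged.append(running / window)
    return averaged
```

In the plotting loop, `plt.plot(moving_average(l_his), label=labels[i])` would then draw the smoothed curve in place of the raw one. Note that the smoothed curve is shorter than the original by `window - 1` points.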