当前位置:网站首页>PyTorch(13)---优化器_随机梯度下降法
PyTorch(13)---优化器_随机梯度下降法
2022-08-02 14:07:00 【伏月三十】
优化器:随机梯度下降法
反向传播—梯度下降
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
dataset=torchvision.datasets.CIFAR10("dataset_CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader=DataLoader(dataset,batch_size=1)
class Demo(nn.Module):
def __init__(self) -> None:
super().__init__()
self.model1=Sequential(
Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2, dilation=1, ),
MaxPool2d(kernel_size=2, ),
Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2, ),
MaxPool2d(kernel_size=2),
Conv2d(32, 64, 5, 1, 2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10),
)
def forward(self,x):
x=self.model1(x)
return x
demo=Demo()
loss=nn.CrossEntropyLoss()
'''优化器:随机梯度下降'''
optim=torch.optim.SGD(params=demo.parameters(),lr=0.01,)
for epoch in range(20):#通常训练很多轮
running_loss=0.0
for data in dataloader:
imgs,targets=data
'''送入网络进行训练'''
outputs=demo(imgs)
'''损失函数'''
result_loss=loss(outputs,targets)
'''优化器进行调优 1、调用一个优化器:optim=torch.optim.SGD(params=demo.parameters(),lr=0.01,) 2、将梯度初始化为0,在循环里这一步都要清0 3、损失函数调用反向传播 4、optim.step() '''
optim.zero_grad()#首先将梯度设置为0
result_loss.backward()#反向传播,算出梯度(梯度下降法),目的:求出最小的loss 得到需要调节的梯度
optim.step()
running_loss=running_loss+result_loss#在每一轮开始之前,将loss都设置为0,整体误差的总和
#print(result_loss)
print(running_loss)
边栏推荐
猜你喜欢
MySQL知识总结 (三) 索引
绕过正则实现SQL注入
St. Regis Takeaway Notes - Lecture 05 Getting Started with Redis
我理解的学习金字塔
Policy Evaluation收敛性、炼丹与数学家
利用红外-可见光图像数据集OTCBVS打通图像融合、目标检测和目标跟踪
无人驾驶综述:国外国内发展历程
国内IT市场还有发展吗?有哪些创新好用的IT运维工具可以推荐?
[论文阅读] ACT: An Attentive Convolutional Transformer for Efficient Text Classification
拥抱Jetpack之印象篇
随机推荐
Win10不能启动WampServer图标呈橘黄色的解决方法
关于Flink
LLVM系列第三章:函数Function
App signature in flutter
LLVM系列第二十五章:简单统计一下LLVM源码行数
Kubernetes介绍
语言模型(NNLM)
记录Yolo-tiny-v4的权重提取和中间层结果提取
Spark_Core
spark优化
The Handler you really understand?
LLVM系列第十七章:控制流语句for
LLVM系列第二十四章:用Xcode编译调试LLVM源码
6.如何使用CardView制作卡片布局效果
Flink前期代码结构
Kubernetes资源编排系列之三: Kustomize篇
UIWindow的makeKeyAndVisible不调用rootviewController 的viewDidLoad的问题
宝塔搭建PHP自适应懒人网址导航源码实测
spark资源调度和任务调度
预训练模型 Bert