当前位置:网站首页>CNN循环训练的解释 | PyTorch系列(二十二)
CNN循环训练的解释 | PyTorch系列(二十二)
2022-07-28 01:45:00 【51CTO】
文 |AI_study

原标题:CNN Training Loop Explained - Neural Network Code Project
- 准备数据
- 建立模型
- 训练模型
- 建立训练 loop
- 分析模型的结果
单个 batch 进行训练
我们可以将单个 batch 训练的代码总结如下:
输出
您会注意到的一件事是,每次运行这段代码都会得到不同的结果。这是因为模型每次都是在顶部创建的,我们从以前的文章中知道模型的权重是随机初始化的。
现在让我们看看如何修改这段代码来使用所有的batch,从而使用整个训练集进行训练。
所有 batch的训练 (epoch)
现在,为了训练我们的数据加载器中可用的所有批次,我们需要做一些更改并添加额外的一行代码:
我们将创建一个for循环来迭代所有batch 处理,而不是从数据加载器获取单个batch 处理。
因为我们的训练集中有60,000个样本,所以我们将有60,000 / 100 = 600次迭代。由于这个原因,我们将从循环中删除print语句,并跟踪总损失和最后打印它们的正确预测总数。
关于这600次迭代需要注意的一点是,到循环结束时,我们的权重将更新600次。如果我们提高batch_size这个数字会下降如果我们降低batch_size这个数字会上升。
最后,在我们对loss张量调用backward() 方法之后,我们知道梯度将被计算出来并添加到网络参数的grad属性中。因为这个原因,我们需要把这些梯度归零。我们可以使用优化器附带的zero_grad()方法来实现这一点。
我们已经准备好运行这段代码。这一次代码将花费更长的时间,因为循环将处理600个批。
我们得到了结果,我们可以看到60000中正确的总数是42104。
在只有一个epoch(一次完整的数据传递)之后,这已经很好了。即使我们做了一个epoch,我们仍然需要记住,权重被更新了600次,这取决于我们的批大小。如果让batch_batch的大小更大一些,比如10,000,那么权重只会更新 6 次,结果也不会很好。
多个 epoch的 训练
要执行多个epoch,我们所要做的就是将此代码放入for循环中。我们还将把epoch数添加到print语句中。
运行这段代码后,我们得到每个epoch的结果:
我们可以看到正确值的数量增加了,而loss减少了。
完整的训练 loop
将所有这些放在一起,我们可以将网络、优化器和train_loader从训练循环单元中提取出来。
optimizer = optim.Adam(network.parameters(), lr=0.01)
接下来是可视化结果
我们现在应该很好地理解了训练循环以及如何使用PyTorch来构建它们。PyTorch很酷的一点是,我们可以像调试forward()函数那样调试训练循环代码。
在下一篇文章中,我们将看到如何获得训练集中每个样本的预测,并使用这些预测创建一个混淆矩阵。下节课见!
文章中内容都是经过仔细研究的,本人水平有限,翻译无法做到完美,但是真的是费了很大功夫,希望小伙伴能动动你性感的小手,分享朋友圈,支持一下我 ^_^
英文原文链接是:
https://deeplizard.com/learn/video/XfYmia3q2Ow




边栏推荐
- 修改MySQL密码的四种方法(适合初学者)
- How do you use the jar package sent by others (how to use the jar package sent by others)
- 【英雄哥七月集训】第 27天:图
- unordered_ The hash function of map and the storage mode of hash bucket
- 基于stm32的恒功率无线充电
- Interviewer: what is the factory method mode?
- Use try-with-resources or close this
- "The faster the code is written, the slower the program runs"
- Representation of children and brothers of trees
- [in depth study of 4g/5g/6g topic -42]: urllc-14 - in depth interpretation of 3GPP urllc related protocols, specifications and technical principles -8-low delay technology-2-slot based scheduling and
猜你喜欢

【ELM分类】基于核极限学习机和极限学习机实现UCI数据集分类附matlab代码

Design of edit memory path of edit box in Gui

MySQL is shown in the figure. The existing tables a and B need to be associated with a and B tables through projectcode to find idcardnum with different addresses.

Find - block search

Use try-with-resources or close this
![[leetcode] 13. linked list cycle · circular linked list](/img/58/c8796bb5ed96d09325b8f2fa6a709e.png)
[leetcode] 13. linked list cycle · circular linked list

【微信小程序开发(六)】绘制音乐播放器环形进度条

"The faster the code is written, the slower the program runs"

Learn this trick and never be afraid to let the code collapse by mistake

Pycharm 快速给整页全部相同名称修改的快捷键
随机推荐
基于FPGA的64位8级流水线加法器
New infrastructure helps the transformation and development of intelligent road transportation
2022.7.8 eth price analysis
超参数调整和实验-训练深度神经网络 | PyTorch系列(二十六)
Pycharm 快速给整页全部相同名称修改的快捷键
[elm classification] classification of UCI data sets based on nuclear limit learning machine and limit learning machine, with matlab code
0动态规划中等 LeetCode873. 最长的斐波那契子序列的长度
[software testing] - unittest framework for automated testing
Some shortest path problems solved by hierarchical graph
Common SQL statement query
【信号去噪】基于卡尔曼滤波实现信号去噪附matlab代码
[data processing] boxplot drawing
Three core issues of concurrent programming (glory Collection Edition)
【LeetCode】13. Linked List Cycle·环形链表
【微信小程序开发(五)】接口按照根据开发版体验版正式版智能配置
Usage of delegate
JS event loop synchronous task, asynchronous task (micro task, macro task) problem analysis
[brother hero's July training] day 27: picture
Email security report in the second quarter: email attacks have soared fourfold, and well-known brands have been used to gain trust
Eigenvalues and eigenvectors