PyTorch basics: using a neural network for a classification task
2022-08-02 05:11:00 【樱花的浪漫】
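1. Reading the MNIST data
First, read the MNIST data. In deep learning frameworks the basic data structure is the tensor, so the data must be converted to tensors before it can be used for modeling and training. The map function can convert all of the arrays to tensors in one call.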
```python
import torch

# x_train, y_train, x_valid, y_valid are assumed to be already loaded as NumPy arrays
x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape  # number of samples, number of features (28*28 = 784 pixels)
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())
```
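To see what the conversion does, here is a minimal sketch on small random NumPy arrays that stand in for MNIST. The array names and sizes are made up for illustration; they only mimic the layout of flattened MNIST (784 float pixels per image, integer labels 0 to 9).

```python
import numpy as np
import torch

# Hypothetical stand-in data: 100 flattened "images" of 28*28 = 784 pixels
rng = np.random.default_rng(0)
x_np = rng.random((100, 784), dtype=np.float32)
y_np = rng.integers(0, 10, size=100)  # integer class labels 0..9 (int64)

# map applies torch.tensor to each array; unpacking consumes the iterator
x_t, y_t = map(torch.tensor, (x_np, y_np))

print(type(x_t), x_t.shape, x_t.dtype)           # float32 tensor of shape (100, 784)
print(y_t.dtype, y_t.min().item(), y_t.max().item())
```

The dtypes carry over from NumPy, which matters later: the loss function expects float inputs and int64 class labels.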
2. torch.nn.functional
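torch.nn.functional contains many functions, such as the common activation functions and loss functions. As a rule of thumb, use nn.Module when a part of the model has learnable parameters; otherwise nn.functional is usually simpler.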
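As a sketch of the difference, the same loss is computed below once with the function F.cross_entropy and once with the equivalent module nn.CrossEntropyLoss; the random logits and labels are made-up inputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)          # made-up scores: batch of 8 samples, 10 classes
labels = torch.randint(0, 10, (8,))  # made-up integer class labels

# Functional form: a plain function call, nothing to construct
loss_f = F.cross_entropy(logits, labels)

# Module form: construct the object first, then call it
loss_m = nn.CrossEntropyLoss()(logits, labels)

# Parameter-free activations work the same way: F.relu(x) vs nn.ReLU()(x)
relu_same = torch.equal(F.relu(logits), nn.ReLU()(logits))

print(loss_f.item(), loss_m.item(), relu_same)
```

Layers with weights, such as nn.Linear, keep their parameters inside the module object, which is why the model below builds its layers from nn.Module classes and takes the parameter-free ReLU from nn.functional.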

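3. Creating a model
- The model must subclass nn.Module and call the nn.Module constructor (super().__init__()) in its own constructor
- There is no need to write a backward function: nn.Module relies on autograd to perform backpropagation automatically
- The learnable parameters of a Module are returned as iterators by named_parameters() or parameters()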
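```python
import torch.nn.functional as F
from torch import nn

class Mnist_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)  # 784 input pixels -> 128 hidden units
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)       # one score per digit class

    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        x = self.out(x)  # raw logits; F.cross_entropy applies log-softmax itself
        return x
```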
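Printing an instance of the model lists its layers, and its weights and biases can be inspected through the iterators returned by named_parameters() or parameters().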
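A minimal sketch of inspecting the model, assuming the Mnist_NN class defined above (repeated so the snippet runs on its own); the batch of 64 random inputs is made up.

```python
import torch
import torch.nn.functional as F
from torch import nn

class Mnist_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        return self.out(x)

net = Mnist_NN()
print(net)  # lists hidden1, hidden2 and out with their input/output sizes

# named_parameters() yields (name, tensor) pairs: one weight and one bias per Linear layer
for name, param in net.named_parameters():
    print(name, tuple(param.shape))

# Total number of learnable values across all layers
n_params = sum(p.numel() for p in net.parameters())
print(n_params)

# Forward pass on a hypothetical batch of 64 flattened images
out = net(torch.randn(64, 784))
print(out.shape)  # one score per class for each sample
```

An nn.Linear(in, out) layer stores a weight of shape (out, in) plus a bias of shape (out,), so this network has 136,074 learnable values in total.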

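4. Loading the data with TensorDataset and DataLoader
- TensorDataset: combines the features and the labels of the training data into one dataset
- DataLoader: reads the dataset in mini-batches, in random order when shuffle=True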
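The two classes can be combined as in the sketch below. The random tensors stand in for the MNIST tensors, and get_data is a helper name chosen here for illustration, not a PyTorch function.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data: 1000 training and 200 validation samples of 784 features
x_train, y_train = torch.randn(1000, 784), torch.randint(0, 10, (1000,))
x_valid, y_valid = torch.randn(200, 784), torch.randint(0, 10, (200,))

# TensorDataset pairs row i of the features with entry i of the labels
train_ds = TensorDataset(x_train, y_train)
valid_ds = TensorDataset(x_valid, y_valid)

def get_data(train_ds, valid_ds, bs):
    # Shuffle the training set on every pass; the validation set needs no shuffling
    # and can use larger batches because no gradients are stored during validation
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )

train_dl, valid_dl = get_data(train_ds, valid_ds, bs=64)

xb, yb = next(iter(train_dl))  # one mini-batch of features and labels
print(len(train_ds), train_ds[0][0].shape, xb.shape, yb.shape)
print(len(train_dl), len(valid_dl))  # mini-batches per pass (the last one may be smaller)
```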
5. Training
The gradient descent optimizer and the loss function
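A minimal sketch of how the optimizer and the loss function work together, using torch.optim.SGD and F.cross_entropy. To keep it small, a single nn.Linear layer and a made-up batch of 32 samples with 20 features and 5 classes stand in for the MNIST model and data.

```python
import torch
import torch.nn.functional as F
from torch import nn, optim

torch.manual_seed(0)
# Hypothetical tiny problem: 32 samples, 20 features, 5 classes
xb = torch.randn(32, 20)
yb = torch.randint(0, 5, (32,))

model = nn.Linear(20, 5)
opt = optim.SGD(model.parameters(), lr=0.1)  # plain gradient descent
loss_func = F.cross_entropy                  # log-softmax + negative log-likelihood

history = []
for _ in range(50):
    loss = loss_func(model(xb), yb)
    loss.backward()   # autograd fills .grad of every parameter
    opt.step()        # w <- w - lr * w.grad
    opt.zero_grad()   # clear gradients before the next step
    history.append(loss.item())

print(history[0], history[-1])  # the loss on this fixed batch goes down
```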

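By default, PyTorch accumulates gradients across backward() calls, so the gradients must be reset to zero after each update step.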
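The accumulation is easy to observe on a single scalar parameter; this small sketch calls backward() twice without resetting in between.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)

(3 * w).backward()
first = w.grad.item()
print(w.grad)   # tensor(3.)

(3 * w).backward()  # second backward without resetting
second = w.grad.item()
print(w.grad)   # tensor(6.): the new gradient was added to the old one

w.grad.zero_()  # manual reset; opt.zero_grad() clears the gradients of all its parameters
print(w.grad)   # tensor(0.)
```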
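- When training, call model.train() so that Batch Normalization and Dropout run in training mode (Dropout randomly zeroes activations, and BatchNorm normalizes with the statistics of the current batch)
- When evaluating, call model.eval(): Dropout is disabled and BatchNorm uses its running statistics instead of per-batch statistics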
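A small sketch of the mode switch using a standalone nn.Dropout layer; model.train() and model.eval() set the same flag on every submodule of a model, and BatchNorm layers react to it in the way described above.

```python
import torch
from torch import nn

torch.manual_seed(0)
layer = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

layer.train()       # training mode
y_train = layer(x)  # some entries zeroed, survivors scaled by 1 / (1 - p) = 2
print(y_train)

layer.eval()        # evaluation mode
y_eval = layer(x)   # dropout is a no-op: the output equals the input
print(y_eval)
```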
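```python
import numpy as np
import torch

def loss_batch(model, loss_func, xb, yb, opt=None):
    # Loss on one mini-batch; the weights are updated only when an optimizer is passed
    loss = loss_func(model(xb), yb)
    if opt is not None:
        loss.backward()  # compute gradients
        opt.step()       # update the parameters
        opt.zero_grad()  # reset gradients, since PyTorch accumulates them
    return loss.item(), len(xb)

def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()  # training mode: Dropout active, BatchNorm uses batch statistics
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)
        model.eval()  # evaluation mode: Dropout off, BatchNorm uses running statistics
        with torch.no_grad():  # no gradient tracking during validation
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)  # average loss weighted by batch size
        print('step: ' + str(step), 'validation loss: ' + str(val_loss))
```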
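The loop above reports only the validation loss. For a classification task it is also common to check accuracy, the fraction of samples whose highest-scoring class matches the label. The sketch below is one way to compute it; the accuracy helper is a hypothetical name, and the one-hot inputs with an identity "model" are made-up data chosen so the expected result is known in advance.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def accuracy(model, dl):
    # Fraction of samples whose highest-scoring class equals the label
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for xb, yb in dl:
            preds = model(xb).argmax(dim=1)
            correct += (preds == yb).sum().item()
            total += len(yb)
    return correct / total

# Sanity check: one-hot inputs through an identity "model" predict their own label
labels = torch.arange(10).repeat(5)  # 50 labels
inputs = nn.functional.one_hot(labels, 10).float()
dl_right = DataLoader(TensorDataset(inputs, labels), batch_size=16)
dl_wrong = DataLoader(TensorDataset(inputs, (labels + 1) % 10), batch_size=16)

acc_right = accuracy(nn.Identity(), dl_right)
acc_wrong = accuracy(nn.Identity(), dl_wrong)
print(acc_right, acc_wrong)  # 1.0 0.0
```

With the model from section 3, accuracy(model, valid_dl) could be printed next to the validation loss at the end of each step in fit.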