当前位置:网站首页>pytorch基本操作:使用神经网络进行分类任务
pytorch基本操作:使用神经网络进行分类任务
2022-08-02 05:11:00 【樱花的浪漫】
1.读取Mnist数据
首先,读取Mnist数据,在深度学习框架中,数据的基本结构是tensor,据需转换成tensor才能参与后续建模训练,可用map函数将数据转换为tensor格式
import torch
x_train, y_train, x_valid, y_valid = map(
torch.tensor, (x_train, y_train, x_valid, y_valid)
)
n, c = x_train.shape
x_train, x_train.shape, y_train.min(), y_train.max()
print(x_train, y_train)
print(x_train.shape)
print(y_train.min(), y_train.max())
2.torch.nn.functional
torch.nn.functional中有很多功能, 比如,常见的激活函数、损失函数,一般情况下,如果模型有可学习的参数,最好用nn.Module,其他情况nn.functional相对更简单一些
3.创建一个model
- 必须继承nn.Module且在其构造函数中需调用nn.Module的构造函数
- 无需写反向传播函数,nn.Module能够利用autograd自动实现反向传播
- Module中的可学习参数可以通过named_parameters()或者parameters()返回迭代器
from torch import nn
class Mnist_NN(nn.Module):
def __init__(self):
super().__init__()
self.hidden1 = nn.Linear(784, 128)
self.hidden2 = nn.Linear(128, 256)
self.out = nn.Linear(256, 10)
def forward(self, x):
x = F.relu(self.hidden1(x))
x = F.relu(self.hidden2(x))
x = self.out(x)
return x
打印出来:
通过named_parameters()或者parameters()返回迭代器
4.使用TensorDataset和DataLoader加载数据
TensorDataset:将训练数据的特征和标签组合
DataLoader:随机读取小批量
5.训练模块
梯度下降方法和损失函数
torch默认会叠加梯度,所以结束后需要将梯度置零
- 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
- 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as np
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
for step in range(steps):
model.train()
for xb, yb in train_dl:
loss_batch(model, loss_func, xb, yb, opt)
model.eval()
with torch.no_grad(): # 验证时不进行梯度下降
losses, nums = zip(
*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
)
val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums) # 平均损失
print('当前step:'+str(step), '验证集损失:'+str(val_loss))
边栏推荐
- [PSQL] window function, GROUPING operator
- Matlab paper illustration drawing template No. 41 - bubble chart (bubblechart)
- golang的time包:时间间隔格式化和秒、毫秒、纳秒等时间戳格式输出的方法
- 服务器的单机防御与集群防御
- C语言中i++和++i在循环中的差异性
- 6W+字记录实验全过程 | 探索Alluxio经济化数据存储策略
- Contents of encoding-indexes.js file printed with Bluetooth:
- Detailed explanation of interface in Go language
- 【漫画】2021满分程序员行为对照表(最新版)
- 提高软件测试能力的方法有哪些?看完这篇文章让你提升一个档次
猜你喜欢
Automated operation and maintenance tools - ansible, overview, installation, module introduction
51 microcontroller peripherals article: dot-matrix LCD
ATM系统
leetcode一步解决链表反转问题
51单片机外设篇:点阵式LCD
Redis-cluster mode (master-slave replication mode, sentinel mode, clustering mode)
eggjs controller层调用controller层解决方案
51单片机外设篇:ADC
MySql将一张表的数据copy到另一张表中
51单片机外设篇:DS18B20
随机推荐
JUC(一)- JUC学习概览 - 对JUC有一个整体的认识
如何优化OpenSumi终端性能?
TikTok平台的两种账户有什么区别?
What do interview test engineers usually ask?The test supervisor tells you
golang's time package: methods for time interval formatting and output of timestamp formats such as seconds, milliseconds, and nanoseconds
网安学习-内网渗透4
测试环境要多少?从成本与效率说起
Automated operation and maintenance tools - ansible, overview, installation, module introduction
整合ssm(一)
C language entry combat (13): decimal number to binary
Mysql common commands
Google notes cut hidden plug-in installation impression
Alluxio为Presto赋能跨云的自助服务能力
Features and installation of non-relational database MongoDB
51单片机外设篇:点阵式LCD
Mysql implements optimistic locking
51 microcontroller peripherals article: dot-matrix LCD
MySQL导入sql文件的三种方法
eggjs controller层调用controller层解决方案
mysql实现按照自定义(指定顺序)排序