当前位置:网站首页>PyTorch 学习笔记 1 —— Quick Start
PyTorch 学习笔记 1 —— Quick Start
2022-07-28 05:24:00 【我有两颗糖】
文章目录
这是 PyTorch 学习笔记 的第一篇博客,学了一点点皮毛,先记录下来!
1. 环境检查
首先确认电脑是否有 GPU,有 GPU 记得安装对应版本的 CUDA 和支持 GPU 版本的 Pytorch,参考 PyTorch 环境搭建:Win11 + mx450,使用下面的 code 检查 GPU 是否可用:
import time
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
print(torch.__version__)
print(torch.cuda.is_available())
打印结果为 True,说明可以使用 GPU,接着指定使用 GPU(如果没有 GPU 或者安装的 PyTorch 版本不支持 GPU,则会自动设置 device 为 cpu):
# Get cup or gpu devices for trianing
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {
device} device")
2. 数据集下载与预处理
2.1 Download dataset
使用 torchvision 库中的 dataset 模块在官网下载数据集:
# Download training data frm open datasets
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
# download test data from open datasets
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
其中,参数含义如下:
rootis the path where the train/test data is stored,trainspecifies training or test dataset,download=Truedownloads the data from the internet if it’s not available at root.transformand target_transform specify the feature and label transformations
执行代码后会把数据集下载到当前路径的 ./data/ 目录下,FashionMNIST 继承了 MNIST,MNIST 的构造函数中,指定 download 为 True 则会执行 self.download(),首先检查数据集是否已经在当前目录,如果在则不下载(可自行查看源代码)
2.2 读取数据集
使用 torch.util.data 中的 DataLoader 类读取数据集,指定 batchsize=64:
# create data loaders
batch_size = 64
train_dataloader = DataLoader(dataset=training_data, batch_size=batch_size)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)
print(len(train_dataloader)) # 938
print(len(train_dataloader.dataset)) # 60000
for X, y in test_dataloader:
print(f'Shape of X [N, C, H, W]: {
X.shape}') # torch.Size([64, 1, 28, 28])
print(f'Shape of y: {
y.shape} {
y.dtype}') # torch.Size([64])
break
DataLoader 对象可迭代,训练集总样本数目为 60000,batchsize=64,则可划分为 60000/64 ≈938 个 batch,每个 batch 有 64 个样本,每个样本的 shape 为 (1, 28, 28)
3. 模型构建
创建一个简单的分类 NLP 模型,
# Define model
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork().to(device)
print(model)
注意,这里创建的 NeuralNetwork 对象调用了 .to(device),这样做的目的是,让模型能够在 GPU 上跑,实际上所有数据也需要调用 .to(device)
模型的构造函数中有模型的结构,forward() 为前向传播函数,通过举例说明每个层的作用:
nn.Flatten() 可以创建Flatten的对象 flatten,flatten 可以将输入的 matrix 拉伸为 vector(默认将维度从 1 到 -1 向量化):
input_image = torch.rand(3, 28, 28)
print(input_image.size()) # torch.Size([3, 28, 28])
flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.size()) # torch.Size([3, 784])
nn.Linear() 创建 lenear 层,可以进行矩阵运算,in_features 和 out_features 分别为进行矩阵相乘使用的矩阵的行数和列数,可以理解为输入特征维数和输出特征维数:
layer1 = nn.Linear(in_features=28*28, out_features=20)
hidden1 = layer1(flat_image)
print(hidden1.size()) # torch.Size([3, 20])
nn.ReLU() 为 ReLU 激活函数,将输入中所有负元素置为零,不改变输入输出的特征维数:
print(f'Before ReLU: {
hidden1}\n\n')
relu = nn.ReLU()
hidden1 = relu(hidden1)
# hidden1 = nn.ReLU()(hidden1)
print(f'After ReLU: {
hidden1}')
nn.Sequential() 可以将不同的小的模块拼接成更大的网络:
squ_models = nn.Sequential(
flatten,
layer1,
nn.ReLU(),
nn.Linear(20, 10)
)
4. 模型训练与测试
4.1 train model
# Train model
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.3)
其中,nn.CrossEntropyLoss() 为交叉熵损失函数,optimizer 用来优化参数,使用方法为:
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
创建的 optimizer 对象将 [‘params’, ‘lr’, ‘momentum’, ‘dampening’, ‘weight_decay’, ‘nesterov’] 这 6 个参数存放到由 6 组键值对构成的字典 param_group 中,由于是传递列表,因此 optimizer 可以修改模型的参数
训练函数:
def train(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
model.train()
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
# Compute prediction error
pred = model(X)
loss = loss_fn(pred, y)
# Back propagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
# print message for each 100 batches
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f'loss: {
loss:>7f} [{
current:>5d} / {
size:>5f}]')
首先调用了前面创建的 NeuralNetwork 模型 model 的 train(),第二行的 model.train() 的具体内容并不是真正的训练模型,而是将模型包括其中的 sub models 设置为 training mode(可以查看源代码,查看 NeuralNetwork 的父类 nn.Module 的 train 函数)
接着对 dataloader 中的每个 batch 中的数据,首先将 X 和 y 转为 GPU 上可训练的格式,先计算预测值 pred 和交叉熵损失 loss, loss 是 torch.Tensor 类型,使用方法如下:
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()
反向传播
反向传播中,下面的结构很常见:
# Back propagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
在反向传播时,首先调用 optimizer.zero_grad() ,它会遍历模型的所有参数,通过p.grad.detach_() 方法截断反向传播的梯度流,再通过 p.grad.zero_() 函数将每个参数的梯度值设为0,即上一次的梯度记录被清空。因为训练的过程通常使用mini-batch方法,所以如果不将梯度清零的话,梯度会与上一个batch的数据相关,因此该函数要写在反向传播和梯度下降之前。
接着调用 loss.backward() 计算模型的所有参数的梯度,此时模型参数不再为 0
最后,optimizer.step() 函数的作用是执行一次优化步骤,通过梯度下降法来更新模型的参数。因为梯度下降是基于梯度的,所以在执行 optimizer.step() 函数前,应先执行loss.backward() 函数来计算梯度。
4.2 test model
读取每一个 test_dataset 中的 batch,根据预测值计算 loss,打印信息:
# Test model
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f'Test Error: \n Accuracy: {
(100 * correct):>0.1f}%, Average loss: {
test_loss:>8f}\n')
model.eval(): Sets the module in evaluation mode. 即进入评估模式,在评估模式下,batchNorm层,dropout层等用于优化训练而添加的网络层会被关闭,从而使得评估时不会发生偏移。
在对模型进行评估时,应该配合使用 with torch.no_grad() 与 model.eval(), torch.no_grad() 设置模型的所有 tensor 参数的 requires_grad 属性为 False,此时模型的所有参数不会自动求导,模型的性能不会变化。
4.3 训练模型主函数
设置学习率 lr 为 0.3,训练了 5 个 epochs,模型的准确率达到 85% 左右:
if __name__ == '__main__':
t1 = time.time()
epochs = 5
for t in range(epochs):
print(f'Epoch {
t+1}\n----------------------------')
train(train_dataloader, model, loss_fn, optimizer)
test(test_dataloader, model, loss_fn)
print('Done!')
t2 = time.time()
print(f'Duration: {
t2-t1:>0.2f} sec')
5. save and load models
保存模型
# Save models
torch.save(model.state_dict(), 'model.pth')
print('Saved PyTorch Model State to model.pth')
加载模型并测试
# Loading models
model = NeuralNetwork()
model.load_state_dict(torch.load('model.pth'))
# prediction
classes = [
"T-shirt/top",
"Trouser",
"Pullover",
"Dress",
"Coat",
"Sandal",
"Shirt",
"Sneaker",
"Bag",
"Ankle boot",
]
model.eval()
X, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
pred = model(X)
predicted, actual = classes[pred[0].argmax(0)], classes[y]
print(f'Predicted: "{
predicted}", Actual: "{
actual}"')
注意,直接在像 sublime 之类的编辑器上跑模型,输出有点多会导致卡顿,使用终端命令才是正确方式:
>>> python xxx.py
以上就是全部内容了!
studying…
REFERENCE:
1 . https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
3 . 理解optimizer.zero_grad(), loss.backward(), optimizer.step()的作用及原理
边栏推荐
猜你喜欢

硬件电路设计学习笔记2--降压电源电路

dsp和fpga的通讯

杭州某公司福禄克FLUKE DTX-SFM2单模模块-修复案例

Best practices to ensure successful deployment of Poe devices

TVS管参数与选型

Electric fast burst (EFT) design - EMC series hardware design notes 4

Reversible watermarking method based on difference expansion

IMS-FACNN(Improved Multi-Scale Convolution Neural Network integrated with a Feature Attention Mecha

clock tree分析实例

USB Network Native Driver for ESXi更新到支持ESXi7.0.1
随机推荐
Chinese display problem of calendarextender control
Photovoltaic power generation system MPPT maximum power point tracking
Deep learning (I): enter the theoretical part of machine learning and deep learning
How does fluke dtx-1800 test cat7 network cable?
Shuffle Net_ v1-shuffle_ v2
ConNeXt
Reversible digital watermarking method based on histogram modification
2、 Openvino brief introduction and construction process
In asp Usage of cookies in. Net
(PHP graduation project) based on PHP Gansu tourism website management system to obtain
8类网线测试仪AEM testpro CV100 和FLUKE DSX-8000哪些事?
4、 Model optimizer and inference engine
雷达成像 Matlab 仿真 3 —— 多目标检测
USB Network Native Driver for ESXi更新到支持ESXi7.0.1
VB-ocx应用于Web
怎么看SIMULINK直接搭的模块的传递函数
ICC2(三)Clock Tree Synthesis
Uniapp problem: "navigationbartextstyle" error: invalid prop: custom validator check failed for prop "Navigator
压敏电阻设计参数及经典电路记录 硬件学习笔记5
AEM-TESTpro K50和南粤勘察结下的缘分