Using PyTorch to quickly train a network model
2022-07-28 23:52:00 【The next day is expected 1314】
A note before we begin:
This post does not explain model principles; treat it as a pure engineering exercise. I had read a lot of paper-model code before without really understanding it; I would grasp the general flow and then put it aside. This experiment is about working through the details carefully.
As everyone knows, PyTorch's internals are well encapsulated, so we only need to write a little code to get a model running. This experiment therefore has a second purpose: trying to make the code reusable once it is written.
1. The SVHN Dataset
The first step before any experiment is choosing a dataset. I had seen this one used in many top-conference papers, so let's follow the trend; if you want to download it, you can click here. SVHN is a dataset of digits cropped from natural color images (street-view house numbers), and can be thought of as a more complex MNIST. To give a sense of its difficulty: some samples are hard for me to make out at all, and I honestly don't know how the experts reach 90%+ accuracy on it. Impressive!
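To get a first feel for the data, here is a minimal sketch of loading SVHN through torchvision (the root path and download flag below are just illustrative); conveniently, torchvision remaps the original label 10 back to the digit 0:
import torchvision

# A sketch of loading SVHN via torchvision; download=True fetches the
# .mat files on first use (the root path is a placeholder)
train_set = torchvision.datasets.SVHN(root="./data/SVHN", split="train", download=True)
test_set = torchvision.datasets.SVHN(root="./data/SVHN", split="test", download=True)

print(len(train_set), len(test_set))  # 73257 train images, 26032 test images
img, label = train_set[0]             # img is a 32x32 RGB PIL image, label is 0-9
print(img.size, label)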
2. Dataset and DataLoader
These two classes encapsulate dataset loading and preprocessing, so the layers above can ignore the implementation details below. Dataset:
import scipy.io as sio
from torch.utils.data import Dataset
from torch.utils.data.dataset import T_co


class SVHN(Dataset):
    def __init__(self, file_path) -> None:
        super().__init__()
        self.file_path = file_path
        # The .mat file stores images in X with shape (32, 32, 3, N) and
        # labels in y with shape (N, 1); note that in the raw .mat the
        # labels run 1-10, with 10 standing for the digit 0
        data_mat = sio.loadmat(self.file_path)
        self.X = data_mat["X"]
        self.y = data_mat["y"]

    def __getitem__(self, index) -> T_co:
        # The sample index is the last axis of X
        return self.X[:, :, :, index], self.y[index]

    def __len__(self):
        return self.y.shape[0]
Note that we must override two methods of the parent class Dataset: __getitem__, which returns a single training sample and its label, and __len__, which returns the length of the dataset.
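As a quick check, indexing such a dataset behaves like indexing a list (the .mat path below is a placeholder):
svhn_train = SVHN("train_32x32.mat")  # placeholder path to the downloaded .mat file
img, label = svhn_train[0]            # __getitem__ is called under the hood
print(len(svhn_train), img.shape, label)  # e.g. 73257 (32, 32, 3) [1]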
DataLoader:
dataLoader = DataLoader(dataset, batch_size=batchSize, shuffle=True)
Some readers may ask at this point: Dataset already exposes an interface for returning data, so why wrap it in another layer, the DataLoader? The reason is that during training, samples are not fed into the network one by one but one batch at a time, where a batch is simply a group of training samples packed together. DataLoader has many optional parameters that I won't cover in detail here; interested readers can consult the PyTorch API documentation.
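To make the batching concrete, here is a small sketch (assuming the torchvision SVHN dataset with a ToTensor transform applied) showing what one iteration of the DataLoader yields:
import torchvision
from torch.utils.data import DataLoader

train_set = torchvision.datasets.SVHN(root="./data/SVHN", split="train", download=True,
                                      transform=torchvision.transforms.ToTensor())
loader = DataLoader(train_set, batch_size=256, shuffle=True)
images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([256, 3, 32, 32]) torch.Size([256])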
3. The ResNet Model
There is no need to write the model architecture ourselves here: torchvision ships an official implementation, so we can be a little lazy.
import torch
from torch import nn
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet18 = models.resnet18()
# Replace the fully connected output layer
num_ftrs = resnet18.fc.in_features
# Ten digit classes, so the output layer gets 10 units
resnet18.fc = nn.Linear(num_ftrs, 10)
# Move the model parameters onto the GPU to speed up training
resnet18 = resnet18.to(device)
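A quick sanity check, sketched with a random dummy batch, confirms that the modified head produces ten logits per image and that ResNet-18 handles SVHN's 32x32 inputs:
dummy = torch.randn(4, 3, 32, 32).to(device)  # a fake batch of four SVHN-sized images
with torch.no_grad():
    out = resnet18(dummy)
print(out.shape)  # torch.Size([4, 10])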
4. Training
This part is actually the bulk of the work, and it is full of template code that almost every model uses. It mainly covers computing the loss, backpropagation, and the optimizer, which updates the parameters using the gradients from backpropagation. Somewhat anticlimactically, all of this has already been implemented for us; we just call it, which is very convenient.
import time
from datetime import timedelta

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter


def get_time_dif(start_time):
    # Elapsed wall-clock time as a timedelta (helper assumed from the original template)
    return timedelta(seconds=int(round(time.time() - start_time)))


def train(model, dataLoader, optimizer, lossFunc, n_epoch):
    start_time = time.time()
    test_best_loss = float('inf')
    last_improve = 0  # batch count at the last drop in validation loss
    flag = False      # whether there has been no improvement for a long time
    total_batch = 0   # how many batches have been processed
    writer = SummaryWriter(log_dir=log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))
    for epoch in range(n_epoch):
        print('Epoch [{}/{}]'.format(epoch + 1, n_epoch))
        model.train()
        sum_loss = 0.0
        correct = 0.0
        total = 0.0
        length = len(dataLoader)
        for batch_idx, dataset in enumerate(dataLoader):
            optimizer.zero_grad()
            data, labelOrg = dataset
            data = data.to(device)
            # One-hot targets; note that nn.CrossEntropyLoss also accepts raw
            # class indices, and float (soft) targets require PyTorch >= 1.10
            label = F.one_hot(labelOrg.to(torch.long), 10).to(torch.float).to(device)
            predict = model(data)
            loss = lossFunc(predict, label)
            loss.backward()
            optimizer.step()
            # Tensor.item() converts a one-element tensor to a Python number
            sum_loss += loss.item()
            # torch.max returns (values, indices) along the given dimension
            _, predicted = torch.max(predict.data, dim=1)
            total += label.size(0)
            correct += predicted.cpu().eq(labelOrg.data).sum()
            # Note that the statistics here are accumulated per batch
            print("[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% "
                  % (epoch + 1, (batch_idx + 1 + epoch * length), sum_loss / (batch_idx + 1), 100. * correct / total))
            # Every 100 batches, evaluate on the test/validation set
            if total_batch % 100 == 0:
                testDataLoss, testDataAcc = evalTestAcc(model)
                time_dif = get_time_dif(start_time)
                if testDataLoss < test_best_loss:
                    test_best_loss = testDataLoss
                    torch.save(model.state_dict(), save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Test Loss: {3:>5.2}, Test Acc: {4:>6.2%}, Time: {5} {6}'
                print(msg.format(total_batch, sum_loss / (batch_idx + 1), correct / total, testDataLoss, testDataAcc, time_dif, improve))
                writer.add_scalar("loss/train", loss.item(), total_batch)
                writer.add_scalar("loss/dev", testDataLoss, total_batch)
                writer.add_scalar("acc/train", correct / total, total_batch)
                writer.add_scalar("acc/dev", testDataAcc, total_batch)
                # evalTestAcc switched the model to eval mode; switch back
                model.train()
            # Two exits for the training loop: n_epoch epochs finish,
            # or no improvement for require_improvement batches
            total_batch += 1
            if total_batch - last_improve > require_improvement:
                # Validation loss has not dropped for a long time; stop training
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    writer.close()


def evalTestAcc(net):
    net.eval()
    totalAcc = 0.0
    sumLoss = 0.0
    total = 0.0
    with torch.no_grad():
        for data, labelOrg in testDataLoader:
            predict = net(data.to(device))
            _, predicted = torch.max(predict.data, dim=1)
            totalAcc += predicted.cpu().eq(labelOrg).sum()
            label = F.one_hot(labelOrg.to(torch.long), 10).to(torch.float).to(device)
            sumLoss += lossFunc(predict, label).item()
            total += label.size(0)
    return sumLoss / len(testDataLoader), totalAcc / total
Looking at it now, there is not much to say: it is almost all template code and can be reused for any model. One thing worth noting is that this experiment does not distinguish between a test set and a validation set; you can think of it as having no test set, with testDataset used as the validation set to tune the training parameters.
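A side note on the loss: the one-hot conversion above is not strictly required. nn.CrossEntropyLoss accepts integer class indices directly, and passing float probability targets (as the code does) is only supported from PyTorch 1.10 onward. A minimal sketch of the two equivalent calls:
import torch
import torch.nn.functional as F
from torch import nn

lossFunc = nn.CrossEntropyLoss()
logits = torch.randn(4, 10)           # fake model outputs for four samples
targets = torch.tensor([3, 1, 0, 9])  # integer class indices
loss_a = lossFunc(logits, targets)                         # the usual index form
loss_b = lossFunc(logits, F.one_hot(targets, 10).float())  # soft-target form, PyTorch >= 1.10
print(loss_a.item(), loss_b.item())   # the two values match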
5. Putting It All Together
import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import models

if __name__ == '__main__':
    # filePath = r"E:\dataset\SVHN\train_32x32.mat"
    save_path = r"model_save/net.pt"
    log_path = r"logs"
    require_improvement = 1000
    batchSize = 256
    n_epoch = 10
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    resnet18 = models.resnet18()
    # Replace the fully connected output layer with ten classes
    num_ftrs = resnet18.fc.in_features
    resnet18.fc = nn.Linear(num_ftrs, 10)
    resnet18 = resnet18.to(device)
    # SVHNTrainData = SVHN(filePath)
    train_dataset = torchvision.datasets.SVHN(
        root=r'E:\dataset\SVHN',
        split='train',
        download=False,
        transform=torchvision.transforms.ToTensor()
    )
    test_dataset = torchvision.datasets.SVHN(
        root=r'E:\dataset\SVHN',
        split='test',
        download=False,
        transform=torchvision.transforms.ToTensor()
    )
    dataLoader = DataLoader(train_dataset, batch_size=batchSize, shuffle=True)
    testDataLoader = DataLoader(test_dataset, batch_size=batchSize, shuffle=True)
    optimizer = optim.SGD(resnet18.parameters(), lr=0.01, momentum=0.9)
    lossFunc = nn.CrossEntropyLoss()
    train(resnet18, dataLoader, optimizer, lossFunc, n_epoch)
This puts all the pieces together. After the run finishes, a logs folder is generated in the current directory; run tensorboard --logdir <logs folder> and you will see plots like the one below, which record the training process: loss and accuracy curves on the training and validation sets.
6. Serialization and Deserialization
Serialization and deserialization can be understood as saving and loading. Once our model has been trained, it can be used directly for prediction tasks; at that point the parameters are no longer updated by backpropagation.
For reference, this blog explains it very clearly and covers almost everything, so I won't repeat it. Here I just record my own deserialization process.
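Before the full script, here is a minimal sketch of the save/load pair used in this post (state_dict-based serialization; trained_model is a stand-in name for the trained network and the path is a placeholder):
import torch
from torch import nn
from torchvision import models

# Serialize: after training, save only the learned parameters (the state dict)
torch.save(trained_model.state_dict(), "model_save/net.pt")

# Deserialize: rebuild the same architecture, then load the parameters back
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 10)
model.load_state_dict(torch.load("model_save/net.pt", map_location="cpu"))
model.eval()  # disable dropout/batch-norm updates before predicting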
import random

import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models

if __name__ == '__main__':
    path = r"model_save/net.pt"
    batchSize = 256
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    resnet18 = models.resnet18()
    # Replace the fully connected output layer, exactly as at training time
    num_ftrs = resnet18.fc.in_features
    resnet18.fc = nn.Linear(num_ftrs, 10)
    # resnet18 = resnet18.to(device)
    resnet18.load_state_dict(torch.load(path, map_location=torch.device("cpu")))
    resnet18.eval()
    test_dataset = torchvision.datasets.SVHN(
        root=r'E:\dataset\SVHN',
        split='test',
        download=False,
        transform=torchvision.transforms.ToTensor()
    )
    testDataLoader = DataLoader(test_dataset, batch_size=batchSize, shuffle=True)
    trains, labels = next(iter(testDataLoader))
    predicts = resnet18(trains)
    # A single sample can also be predicted instead of a whole batch:
    # resnet18(trains[0].unsqueeze(0))
    _, predictLabels = torch.max(predicts, dim=1)
    fig, axs = plt.subplots(1, 5, figsize=(10, 10))  # create the subplots
    print("predictLabels: {}".format(predictLabels))
    print("labels: {}".format(labels))
    print("Acc: {:.2f}".format(predictLabels.data.eq(labels).sum() / labels.shape[0]))
    for i in range(5):
        # Pick five random indices; the upper bound is the actual batch size
        # minus one, since randint is inclusive and the last batch may be smaller
        num = random.randint(0, trains.shape[0] - 1)
        npimg, nplabel = trains[num], labels[num]
        axs[i].imshow(np.transpose(npimg.numpy(), (1, 2, 0)))
        axs[i].set_title("GroundTruth: {}, Predict: {}".format(nplabel, predictLabels[num]))  # label each subplot
        axs[i].axis("off")  # hide the axes of each subplot
    plt.show()