当前位置:网站首页>我的NVIDIA开发者之旅-Jetson Nano 2gb教你怎么训练模型(完整的模型训练套路)
我的NVIDIA开发者之旅-Jetson Nano 2gb教你怎么训练模型(完整的模型训练套路)
2022-06-28 12:07:00 【无证驾驶梁嗖嗖】
我的NVIDIA开发者之旅” | 征文活动进行中.......
模型的保存和加载
pytorch的安装方法这里就不写了,之前的文章有记录,nvidia官网的资料已经很详细了附上连接(注意你的Jetpack版本就好了,一般玄学的问题都出现在这里)

安装pychrom的方法之前的文章也有完整的记录,基本环境也就是这些。今天用到的也就是pytorch。
神经网络的训练一般要进行的步骤:
加载数据集,并做预处理。
预处理后的数据分为 feature 和 label 两部分,feature 送到模型里面,label 被当做 ground-truth。
model 接收 feature 作为 input,并通过一系列运算,向外输出 predict。
通过以 predict 和 predict 为变量,建立一个损失函数 Loss,Loss 的函数值是为了表示 predict 与 ground-truth 之间的差距。
建立 Optimizer 优化器,优化的目标就是 Loss 函数,让它的取值尽可能最小,loss 越小代表 Model 预测的准确率越高。
Optimizer 优化过程中,Model 根据规则改变自身参数的权重,这是个反复循环和持续的过程,直到 loss 值趋于稳定,不能在取得更小值。

CIFAR-10 和 CIFAR-100 是 8000 万张微小图像数据集的标记子集。它们由Alex Krizhevsky,Vinod Nair和Geoffrey Hinton收集。
CIFAR-10 数据集
CIFAR-10 数据集由 10 个类中的 60000 张 32x32 彩色图像组成,每个类包含 6000 张图像。有 50000 张训练图像和 10000 张测试图像。
数据集分为五个训练批次和一个测试批次,每个批次包含 10000 张图像。测试批次正好包含从每个类中随机选择的 1000 张图像。训练批次包含随机顺序的剩余图像,但某些训练批次可能包含来自一个类的图像多于另一个类的图像。在它们之间,训练批次正好包含来自每个类的 5000 张图像。
运行代码后,会自动下载数据集,并存放在当前目录下的 data 文件中。
# 这是一个示例 Python 脚本。
# 按 Shift+F10 执行或将其替换为您的代码。
# 按 Double Shift 在所有地方搜索类、文件、工具窗口、操作和设置。
def print_hi(name):
# 在下面的代码行中使用断点来调试脚本。
print(f'Hi, {name}') # 按 Ctrl+F8 切换断点。
# 按间距中的绿色按钮以运行脚本。
if __name__ == '__main__':
print_hi('PyCharm')
# 访问 https://www.jetbrains.com/help/pycharm/ 获取 PyCharm 帮助
import torch
from torch.utils.data import DataLoader
#import torch.nn as nn
from torch import nn
from model import *
torch.cuda.is_available()
print('CUDA available: ' + str(torch.cuda.is_available()))
a = torch.cuda.FloatTensor(2).zero_()
print('Tensor a = ' + str(a))
b = torch.randn(2).cuda()
print('Tensor b = ' + str(b))
c = a + b
print('Tensor c = ' + str(c))
import torchvision
train_data = torchvision.datasets.CIFAR10("root=../data",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_date = torchvision.datasets.CIFAR10("root=../data",train=False,transform=torchvision.transforms.ToTensor(),download=True)
trian_data_size = len(train_data)
test_data_size = len(test_date)
#train_data_size=10,xunlianshujujide changduwei10
print("xunlianshujujjidechangduwei:{}".format(trian_data_size))
print("ceshishujujjidechangduwei:{}".format(test_data_size))
#liyongdataloader laijiazaishujuji
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_date,batch_size=64)
#Build neural network
tudui = Tudui()
# sunshihanshu
loss_fn = nn.CrossEntropyLoss()
#youhuaqi
# learning_rate = 0.01
#1e-2=1x (10)^(-2) = 1/100 =0.01
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)
#shezhixunliandechishu
#jiluxunliandechichu
total_train_setp = 0
##jiluceshide cichu
total_test_step = 0
# xunlianndelunsh
epoch = 10
for i in range(epoch):
print("---------di{}lunxunlianstart".format(i+1))
#xunliankaishi
for data in train_dataloader:
imgs, targets = data
outputs = tudui(imgs)
loss = loss_fn(outputs,targets)
#youhuaqijianmoxing
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_train_setp = total_train_setp + 1
print("xunliancishu:{},loss:{}".format(total_test_step, loss.item()))
epoch 和 iteration 的区别,iteration 指的是单次 mini-batch 训练,而 epoch 和数据集的大小还有 batch size 有关。
CIFAR-10 训练集图片数量是 50000,batch size 的大小是 100,所以要经过 500 次 iteration 才算走完一个 epoch。
epoch 可以大致当成神经网络把训练集所有的照片从头看到尾都过一遍。
运行结果如下图。

[email protected]:~/PycharmProjects/pythonProject$ cat model.py
import torch
from torch import nn
class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(3, 32, 5, 1, 2),
nn.MaxPool2d(2),
nn.Conv2d(32, 32, 5, 1, 2),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, 1, 2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(64 * 4 * 4, 64),
nn.Linear(64, 10)
)
def forward(self, x):
x = self.model(x)
return x
if __name__ == '__main__':
tudui = Tudui()
input = torch.ones((64, 3, 32, 32))
output = tudui(input)
print(output.shape)
Jetson nano用来测试这个是很吃力的,希望各位同学还是换个性能强悍的nx或者orin试试。
边栏推荐
- [vi/vim] basic usage and command summary
- Convert black mask picture to color annotation file
- Setting overridesorting for canvas does not take effect
- Django -- MySQL database reflects the mapping data model to models
- Function and principle of remoteviews
- Privilege management of vivo mobile phone
- Build your own website (18)
- 【JS】斐波那契数列实现(递归与循环)
- IDEA全局搜索快捷设置
- 内部振荡器、无源晶振、有源晶振有什么区别?
猜你喜欢

【Unity编辑器扩展基础】、EditorGUILayout(二)

Share the easy-to-use fastadmin open source system - practical part

【vi/vim】基本使用及命令汇总

【C语言】二叉树的实现及三种遍历

Tips for using ugui (V) using scroll rect component

ArrayList源码解析

Redis 原理 - List
![[C language] use of nested secondary pointer of structure](/img/59/8b61805431e152995c250f6dd08e29.png)
[C language] use of nested secondary pointer of structure

【北京航空航天大学】考研初试复试资料分享

【C语言】关于scanf()与scanf_s()的一些问题
随机推荐
RemoteViews的作用及原理
Unity加载设置:Application.backgroundLoadingPriority
Prefix and (one dimension)
零基础C语言(一)
【经验分享】Django开发中常用到的数据库操作总结
AcWing 605. Simple product (implemented in C language)
Leetcode 48. 旋转图像(可以,已解决)
C语言 sprintf函数使用详解
In less than an hour, apple destroyed 15 startups
【Unity编辑器扩展基础】、GUI
EMC RS485接口EMC电路设计方案
Map sorting tool class
js 期约与异步函数 Promise
[unity Editor Extension Foundation], editorguilayout (I)
Is it feasible to be a programmer at the age of 26?
Function and principle of remoteviews
Setting overridesorting for canvas does not take effect
. Net hybrid development solution 24 webview2's superior advantages over cefsharp
【Unity编辑器扩展实践】、利用txt模板动态生成UI代码
Unity Editor Extension Foundation, editorguilayout (II)