当前位置:网站首页>paddle入门-使用LeNet在MNIST实现图像分类方法一
paddle入门-使用LeNet在MNIST实现图像分类方法一
2022-07-07 22:06:00 【Vertira】
MNIST数据集实现图像分类
一、配置环境
import paddle
print(paddle.__version__)
如何配置paddle 可以网上搜 ,我的博客也有 ,这里略
加载数据:方式有两种:自定义数据加载(我之前的博客有),加载paddled官网做好的数据
我们寻找第二种方式,因为方便
手写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist 。
我们使用飞桨框架自带的 paddle.vision.datasets.MNIST 完成mnist数据集的加载。
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],
std=[127.5],
data_format='CHW')])
# 使用transform对数据集做归一化
print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')
取训练集中的一条数据看一下。
import numpy as np
import matplotlib.pyplot as plt
train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 label is: ' + str(train_label_0))
三 、组网
用paddle.nn下的API,如Conv2D、MaxPool2D、Linear完成LeNet的构建。
import paddle
import paddle.nn.functional as F
class LeNet(paddle.nn.Layer):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.max_pool2(x)
x = paddle.flatten(x, start_axis=1,stop_axis=-1)
x = self.linear1(x)
x = F.relu(x)
x = self.linear2(x)
x = F.relu(x)
x = self.linear3(x)
return x
四、方式1:基于高层API,完成模型的训练与预测
通过paddle提供的Model 构建实例,使用封装好的训练与测试接口,快速完成模型训练与测试。
4.1 使用 Model.fit来训练模型
from paddle.metric import Accuracy
model = paddle.Model(LeNet()) # 用Model封装模型
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 配置模型
model.prepare(
optim,
paddle.nn.CrossEntropyLoss(),
Accuracy()
)
# 训练模型
model.fit(train_dataset,
epochs=2,
batch_size=64,
verbose=1
)
训练结果
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/2
step 938/938 [==============================] - loss: 0.0329 - acc: 0.9399 - 10ms/step
Epoch 2/2
step 938/938 [==============================] - loss: 0.0092 - acc: 0.9798 - 10ms/step
4.2 使用 Model.evaluate 来预测模型
model.evaluate(test_dataset, batch_size=64, verbose=1)
Eval begin...
step 157/157 [==============================] - loss: 4.4728e-04 - acc: 0.9857 - 8ms/step
Eval samples: 10000
{'loss': [0.0004472804], 'acc': 0.9857}
方式一结束
以上就是方式一,可以快速、高效的完成网络模型训练与预测。
参考:
边栏推荐
- Go learning notes (1) environment installation and hello world
- [leetcode] 20. Valid brackets
- Seven years' experience of a test engineer -- to you who walk alone all the way (don't give up)
- Set up personal network disk with nextcloud
- Les mots ont été écrits, la fonction est vraiment puissante!
- Archery installation test
- 单机高并发模型设计
- How to put recyclerview in nestedscrollview- How to put RecyclerView inside NestedScrollView?
- 【编程题】【Scratch二级】2019.09 制作蝙蝠冲关游戏
- Restricted linear table
猜你喜欢

串联二极管,提高耐压

When creating body middleware, express Is there any difference between setting extended to true and false in urlencoded?

Is 35 really a career crisis? No, my skills are accumulating, and the more I eat, the better

Chisel tutorial - 03 Combinatorial logic in chisel (chisel3 cheat sheet is attached at the end)

快速回复二极管整流特性
![[programming problem] [scratch Level 2] March 2019 draw a square spiral](/img/fa/ae9dabdd36ba77b1f4644dd23bee93.png)
[programming problem] [scratch Level 2] March 2019 draw a square spiral
![[basis of recommendation system] sampling and construction of positive and negative samples](/img/4b/753a61b583cf38826b597fd31e5d20.png)
[basis of recommendation system] sampling and construction of positive and negative samples

Trust orbtk development issues 2022

Install sqlserver2019

Kubectl's handy command line tool: Oh my Zsh tips and tricks
随机推荐
腾讯安全发布《BOT管理白皮书》|解读BOT攻击,探索防护之道
At the age of 35, I made a decision to face unemployment
Basic learning of SQL Server -- creating databases and tables with code
QT and OpenGL: load 3D models using the open asset import library (assimp)
Pypharm uses, and the third-party library has errors due to version problems
CoinDesk评波场去中心化进程:让人们看到互联网的未来
Data analysis series 3 σ Rule / eliminate outliers according to laida criterion
Anaconda+pycharm+pyqt5 configuration problem: pyuic5 cannot be found exe
Detailed explanation of interview questions: the history of blood and tears in implementing distributed locks with redis
全自动化处理每月缺卡数据,输出缺卡人员信息
自动化测试:Robot FrameWork框架90%的人都想知道的实用技巧
Chisel tutorial - 05 Sequential logic in chisel (including explicit multi clock, explicit synchronous reset and explicit asynchronous reset)
某马旅游网站开发(登录注册退出功能的实现)
Is it safe to buy funds online?
Two small problems in creating user registration interface
The difference between get and post
串联二极管,提高耐压
QT creator add JSON based Wizard
About the difference between ch32 library function and STM32 library function
Use filters to count URL request time