当前位置:网站首页>paddle入门-使用LeNet在MNIST实现图像分类方法一
paddle入门-使用LeNet在MNIST实现图像分类方法一
2022-07-07 22:06:00 【Vertira】
MNIST数据集实现图像分类
一、配置环境
import paddle
print(paddle.__version__)
如何配置paddle 可以网上搜 ,我的博客也有 ,这里略
加载数据:方式有两种:自定义数据加载(我之前的博客有),加载paddled官网做好的数据
我们寻找第二种方式,因为方便
手写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist 。
我们使用飞桨框架自带的 paddle.vision.datasets.MNIST 完成mnist数据集的加载。
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],
std=[127.5],
data_format='CHW')])
# 使用transform对数据集做归一化
print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')
取训练集中的一条数据看一下。
import numpy as np
import matplotlib.pyplot as plt
train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 label is: ' + str(train_label_0))
三 、组网
用paddle.nn下的API,如Conv2D、MaxPool2D、Linear完成LeNet的构建。
import paddle
import paddle.nn.functional as F
class LeNet(paddle.nn.Layer):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.max_pool2(x)
x = paddle.flatten(x, start_axis=1,stop_axis=-1)
x = self.linear1(x)
x = F.relu(x)
x = self.linear2(x)
x = F.relu(x)
x = self.linear3(x)
return x
四、方式1:基于高层API,完成模型的训练与预测
通过paddle提供的Model 构建实例,使用封装好的训练与测试接口,快速完成模型训练与测试。
4.1 使用 Model.fit来训练模型
from paddle.metric import Accuracy
model = paddle.Model(LeNet()) # 用Model封装模型
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 配置模型
model.prepare(
optim,
paddle.nn.CrossEntropyLoss(),
Accuracy()
)
# 训练模型
model.fit(train_dataset,
epochs=2,
batch_size=64,
verbose=1
)
训练结果
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/2
step 938/938 [==============================] - loss: 0.0329 - acc: 0.9399 - 10ms/step
Epoch 2/2
step 938/938 [==============================] - loss: 0.0092 - acc: 0.9798 - 10ms/step
4.2 使用 Model.evaluate 来预测模型
model.evaluate(test_dataset, batch_size=64, verbose=1)
Eval begin...
step 157/157 [==============================] - loss: 4.4728e-04 - acc: 0.9857 - 8ms/step
Eval samples: 10000
{'loss': [0.0004472804], 'acc': 0.9857}
方式一结束
以上就是方式一,可以快速、高效的完成网络模型训练与预测。
参考:
边栏推荐
- Robomaster visual tutorial (1) camera
- FFA and ICGA angiography
- Postgres timestamp to human eye time string or millisecond value
- [programming problem] [scratch Level 2] March 2019 draw a square spiral
- Emotional post station 010: things that contemporary college students should understand
- 【编程题】【Scratch二级】2019.09 绘制雪花图案
- 智慧监管入场,美团等互联网服务平台何去何从
- Traduction gratuite en un clic de plus de 300 pages de documents PDF
- Preliminary test of optical flow sensor: gl9306
- 【编程题】【Scratch二级】2019.03 绘制方形螺旋
猜你喜欢

Opengl3.3 mouse picking up objects

2022-07-07:原本数组中都是大于0、小于等于k的数字,是一个单调不减的数组, 其中可能有相等的数字,总体趋势是递增的。 但是其中有些位置的数被替换成了0,我们需要求出所有的把0替换的方案数量:
![[programming problem] [scratch Level 2] draw ten squares in December 2019](/img/4f/14ea8e786b7f8b0a263aa5c55abf15.png)
[programming problem] [scratch Level 2] draw ten squares in December 2019
![[programming problem] [scratch Level 2] December 2019 flying birds](/img/5e/a105f8615f3991635c9ffd3a8e5836.png)
[programming problem] [scratch Level 2] December 2019 flying birds

快速上手使用本地测试工具postman

80% of the people answered incorrectly. Does the leaf on the apple logo face left or right?
![[path planning] use the vertical distance limit method and Bessel to optimize the path of a star](/img/0b/e21f7ded7c854272e8cb631ff0154e.png)
[path planning] use the vertical distance limit method and Bessel to optimize the path of a star

Les mots ont été écrits, la fonction est vraiment puissante!

Go learning notes (1) environment installation and hello world

Kubectl 好用的命令行工具:oh-my-zsh 技巧和窍门
随机推荐
How did a fake offer steal $540million from "axie infinity"?
DataGuard active / standby cleanup archive settings
Sqlite数据库存储目录结构邻接表的实现2-目录树的构建
Database query - what is the highest data?
35岁真就成了职业危机?不,我的技术在积累,我还越吃越香了
Robomaster visual tutorial (0) Introduction
Using Google test in QT
【编程题】【Scratch二级】2019.12 飞翔的小鸟
[programming problem] [scratch Level 2] December 2019 flying birds
Resolve the URL of token
SQL connection problem after downloading (2)
Visual Studio Deployment Project - Create shortcut to deployed executable
在网页中打开展示pdf文件
The difference between -s and -d when downloading packages using NPM
Go learning notes (2) basic types and statements (1)
【编程题】【Scratch二级】2019.09 制作蝙蝠冲关游戏
机器人(自动化)等专业课程创新的结果
蓝桥ROS中使用fishros一键安装
Introduction knowledge system of Web front-end engineers
Install sqlserver2019