当前位置:网站首页>paddle入门-使用LeNet在MNIST实现图像分类方法一
paddle入门-使用LeNet在MNIST实现图像分类方法一
2022-07-07 22:06:00 【Vertira】
MNIST数据集实现图像分类
一、配置环境
import paddle
print(paddle.__version__)
如何配置paddle 可以网上搜 ,我的博客也有 ,这里略
加载数据:方式有两种:自定义数据加载(我之前的博客有),加载paddled官网做好的数据
我们寻找第二种方式,因为方便
手写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist 。
我们使用飞桨框架自带的 paddle.vision.datasets.MNIST 完成mnist数据集的加载。
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],
std=[127.5],
data_format='CHW')])
# 使用transform对数据集做归一化
print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')
取训练集中的一条数据看一下。
import numpy as np
import matplotlib.pyplot as plt
train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 label is: ' + str(train_label_0))
三 、组网
用paddle.nn下的API,如Conv2D、MaxPool2D、Linear完成LeNet的构建。
import paddle
import paddle.nn.functional as F
class LeNet(paddle.nn.Layer):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.max_pool2(x)
x = paddle.flatten(x, start_axis=1,stop_axis=-1)
x = self.linear1(x)
x = F.relu(x)
x = self.linear2(x)
x = F.relu(x)
x = self.linear3(x)
return x
四、方式1:基于高层API,完成模型的训练与预测
通过paddle提供的Model 构建实例,使用封装好的训练与测试接口,快速完成模型训练与测试。
4.1 使用 Model.fit来训练模型
from paddle.metric import Accuracy
model = paddle.Model(LeNet()) # 用Model封装模型
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 配置模型
model.prepare(
optim,
paddle.nn.CrossEntropyLoss(),
Accuracy()
)
# 训练模型
model.fit(train_dataset,
epochs=2,
batch_size=64,
verbose=1
)
训练结果
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/2
step 938/938 [==============================] - loss: 0.0329 - acc: 0.9399 - 10ms/step
Epoch 2/2
step 938/938 [==============================] - loss: 0.0092 - acc: 0.9798 - 10ms/step
4.2 使用 Model.evaluate 来预测模型
model.evaluate(test_dataset, batch_size=64, verbose=1)
Eval begin...
step 157/157 [==============================] - loss: 4.4728e-04 - acc: 0.9857 - 8ms/step
Eval samples: 10000
{'loss': [0.0004472804], 'acc': 0.9857}
方式一结束
以上就是方式一,可以快速、高效的完成网络模型训练与预测。
参考:
边栏推荐
- Relevant methods of sorting arrays in JS (if you want to understand arrays, it's enough to read this article)
- Enterprise application demand-oriented development of human resources department, employee attendance records and paid wages business process cases
- Binary sort tree [BST] - create, find, delete, output
- STM32F1與STM32CubeIDE編程實例-旋轉編碼器驅動
- 2022-07-07:原本数组中都是大于0、小于等于k的数字,是一个单调不减的数组, 其中可能有相等的数字,总体趋势是递增的。 但是其中有些位置的数被替换成了0,我们需要求出所有的把0替换的方案数量:
- Two small problems in creating user registration interface
- 面试题详解:用Redis实现分布式锁的血泪史
- 【leetcode】day1
- 关于组织2021-2022全国青少年电子信息智能创新大赛西南赛区(四川)复赛的通知
- Seven years' experience of a test engineer -- to you who walk alone all the way (don't give up)
猜你喜欢

80% of the people answered incorrectly. Does the leaf on the apple logo face left or right?

Two small problems in creating user registration interface

SQL knowledge summary 004: Postgres terminal command summary

About the difference between ch32 library function and STM32 library function

Restricted linear table

Seven years' experience of a test engineer -- to you who walk alone all the way (don't give up)

机器人(自动化)等专业课程创新的结果

BSS 7230 flame retardant performance test of aviation interior materials

One click installation with fishros in blue bridge ROS

QT and OpenGL: load 3D models using the open asset import library (assimp)
随机推荐
Tools for debugging makefiles - tool for debugging makefiles
SQL knowledge summary 004: Postgres terminal command summary
Restricted linear table
Common selectors are
One click installation with fishros in blue bridge ROS
Data analysis series 3 σ Rule / eliminate outliers according to laida criterion
Pigsty:开箱即用的数据库发行版
80% of the people answered incorrectly. Does the leaf on the apple logo face left or right?
Relevant methods of sorting arrays in JS (if you want to understand arrays, it's enough to read this article)
第四期SFO销毁,Starfish OS如何对SFO价值赋能?
数据湖(十五):Spark与Iceberg整合写操作
快速上手使用本地测试工具postman
The result of innovation in professional courses such as robotics (Automation)
Stm32f1 and stm32cubeide programming example - rotary encoder drive
【编程题】【Scratch二级】2019.09 绘制雪花图案
The function is really powerful!
Binary sort tree [BST] - create, find, delete, output
An example analysis of MP4 file format parsing
C - linear table
Les mots ont été écrits, la fonction est vraiment puissante!