当前位置:网站首页>paddle入门-使用LeNet在MNIST实现图像分类方法一
paddle入门-使用LeNet在MNIST实现图像分类方法一
2022-07-07 22:06:00 【Vertira】
MNIST数据集实现图像分类
一、配置环境
import paddle
print(paddle.__version__)
如何配置paddle 可以网上搜 ,我的博客也有 ,这里略
加载数据:方式有两种:自定义数据加载(我之前的博客有),加载paddled官网做好的数据
我们寻找第二种方式,因为方便
手写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist 。
我们使用飞桨框架自带的 paddle.vision.datasets.MNIST 完成mnist数据集的加载。
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],
std=[127.5],
data_format='CHW')])
# 使用transform对数据集做归一化
print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')
取训练集中的一条数据看一下。
import numpy as np
import matplotlib.pyplot as plt
train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 label is: ' + str(train_label_0))
三 、组网
用paddle.nn下的API,如Conv2D、MaxPool2D、Linear完成LeNet的构建。
import paddle
import paddle.nn.functional as F
class LeNet(paddle.nn.Layer):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.max_pool2(x)
x = paddle.flatten(x, start_axis=1,stop_axis=-1)
x = self.linear1(x)
x = F.relu(x)
x = self.linear2(x)
x = F.relu(x)
x = self.linear3(x)
return x
四、方式1:基于高层API,完成模型的训练与预测
通过paddle提供的Model 构建实例,使用封装好的训练与测试接口,快速完成模型训练与测试。
4.1 使用 Model.fit来训练模型
from paddle.metric import Accuracy
model = paddle.Model(LeNet()) # 用Model封装模型
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 配置模型
model.prepare(
optim,
paddle.nn.CrossEntropyLoss(),
Accuracy()
)
# 训练模型
model.fit(train_dataset,
epochs=2,
batch_size=64,
verbose=1
)
训练结果
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/2
step 938/938 [==============================] - loss: 0.0329 - acc: 0.9399 - 10ms/step
Epoch 2/2
step 938/938 [==============================] - loss: 0.0092 - acc: 0.9798 - 10ms/step
4.2 使用 Model.evaluate 来预测模型
model.evaluate(test_dataset, batch_size=64, verbose=1)
Eval begin...
step 157/157 [==============================] - loss: 4.4728e-04 - acc: 0.9857 - 8ms/step
Eval samples: 10000
{'loss': [0.0004472804], 'acc': 0.9857}
方式一结束
以上就是方式一,可以快速、高效的完成网络模型训练与预测。
参考:
边栏推荐
- ROS from entry to mastery (IX) initial experience of visual simulation: turtlebot3
- SQL uses the in keyword to query multiple fields
- 某马旅游网站开发(对servlet的优化)
- [path planning] use the vertical distance limit method and Bessel to optimize the path of a star
- 蓝桥ROS中使用fishros一键安装
- 【编程题】【Scratch二级】2019.12 绘制十个正方形
- Postgres timestamp to human eye time string or millisecond value
- Solutions to problems in sqlserver deleting data in tables
- QT and OpenGL: load 3D models using the open asset import library (assimp)
- FFA and ICGA angiography
猜你喜欢

Stm32f1 and stm32cubeide programming example - rotary encoder drive

How does starfish OS enable the value of SFO in the fourth phase of SFO destruction?

The result of innovation in professional courses such as robotics (Automation)

Chisel tutorial - 03 Combinatorial logic in chisel (chisel3 cheat sheet is attached at the end)

用語雀寫文章了,功能真心强大!

One click free translation of more than 300 pages of PDF documents

Set up personal network disk with nextcloud

Detailed explanation of interview questions: the history of blood and tears in implementing distributed locks with redis

Go learning notes (1) environment installation and hello world

redis你到底懂不懂之list
随机推荐
Database interview questions + analysis
52歲的周鴻禕,還年輕嗎?
全自动化处理每月缺卡数据,输出缺卡人员信息
快速上手使用本地测试工具postman
Connect diodes in series to improve voltage withstand
When creating body middleware, express Is there any difference between setting extended to true and false in urlencoded?
Go learning notes (1) environment installation and hello world
The difference between -s and -d when downloading packages using NPM
[programming problem] [scratch Level 2] 2019.09 make bat Challenge Game
ROS from entry to mastery (IX) initial experience of visual simulation: turtlebot3
第四期SFO销毁,Starfish OS如何对SFO价值赋能?
STM32F1與STM32CubeIDE編程實例-旋轉編碼器驅動
Is it safe for tongdaxin to buy funds?
C language 005: common examples
关于组织2021-2022全国青少年电子信息智能创新大赛西南赛区(四川)复赛的通知
Two small problems in creating user registration interface
[programming problem] [scratch Level 2] draw ten squares in December 2019
Redis caching tool class, worth owning~
Is 35 really a career crisis? No, my skills are accumulating, and the more I eat, the better
C - linear table