当前位置:网站首页>paddle入门-使用LeNet在MNIST实现图像分类方法一
paddle入门-使用LeNet在MNIST实现图像分类方法一
2022-07-07 22:06:00 【Vertira】
MNIST数据集实现图像分类
一、配置环境
import paddle
print(paddle.__version__)
如何配置paddle 可以网上搜 ,我的博客也有 ,这里略
加载数据:方式有两种:自定义数据加载(我之前的博客有),加载paddled官网做好的数据
我们寻找第二种方式,因为方便
手写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist 。
我们使用飞桨框架自带的 paddle.vision.datasets.MNIST 完成mnist数据集的加载。
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],
std=[127.5],
data_format='CHW')])
# 使用transform对数据集做归一化
print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')
取训练集中的一条数据看一下。
import numpy as np
import matplotlib.pyplot as plt
train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 label is: ' + str(train_label_0))
三 、组网
用paddle.nn下的API,如Conv2D、MaxPool2D、Linear完成LeNet的构建。
import paddle
import paddle.nn.functional as F
class LeNet(paddle.nn.Layer):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.max_pool2(x)
x = paddle.flatten(x, start_axis=1,stop_axis=-1)
x = self.linear1(x)
x = F.relu(x)
x = self.linear2(x)
x = F.relu(x)
x = self.linear3(x)
return x
四、方式1:基于高层API,完成模型的训练与预测
通过paddle提供的Model 构建实例,使用封装好的训练与测试接口,快速完成模型训练与测试。
4.1 使用 Model.fit来训练模型
from paddle.metric import Accuracy
model = paddle.Model(LeNet()) # 用Model封装模型
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 配置模型
model.prepare(
optim,
paddle.nn.CrossEntropyLoss(),
Accuracy()
)
# 训练模型
model.fit(train_dataset,
epochs=2,
batch_size=64,
verbose=1
)
训练结果
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/2
step 938/938 [==============================] - loss: 0.0329 - acc: 0.9399 - 10ms/step
Epoch 2/2
step 938/938 [==============================] - loss: 0.0092 - acc: 0.9798 - 10ms/step
4.2 使用 Model.evaluate 来预测模型
model.evaluate(test_dataset, batch_size=64, verbose=1)
Eval begin...
step 157/157 [==============================] - loss: 4.4728e-04 - acc: 0.9857 - 8ms/step
Eval samples: 10000
{'loss': [0.0004472804], 'acc': 0.9857}
方式一结束
以上就是方式一,可以快速、高效的完成网络模型训练与预测。
参考:
边栏推荐
- FFA and ICGA angiography
- Data analysis series 3 σ Rule / eliminate outliers according to laida criterion
- 【编程题】【Scratch二级】2019.09 绘制雪花图案
- Codeworks 5 questions per day (average 1500) - day 8
- Preliminary test of optical flow sensor: gl9306
- 自动化测试:Robot FrameWork框架90%的人都想知道的实用技巧
- Seven years' experience of a test engineer -- to you who walk alone all the way (don't give up)
- STM32F1與STM32CubeIDE編程實例-旋轉編碼器驅動
- 【leetcode】day1
- AWS AWS help error
猜你喜欢

某马旅游网站开发(登录注册退出功能的实现)

How to measure whether the product is "just needed, high frequency, pain points"

Kubectl's handy command line tool: Oh my Zsh tips and tricks

ROS从入门到精通(九) 可视化仿真初体验之TurtleBot3

Preliminary test of optical flow sensor: gl9306

Chisel tutorial - 05 Sequential logic in chisel (including explicit multi clock, explicit synchronous reset and explicit asynchronous reset)

Stm32f1 and stm32cubeide programming example - rotary encoder drive

Automated testing: robot framework is a practical skill that 90% of people want to know

Traduction gratuite en un clic de plus de 300 pages de documents PDF

Laser slam learning (2d/3d, partial practice)
随机推荐
When creating body middleware, express Is there any difference between setting extended to true and false in urlencoded?
Resolve the URL of token
Pigsty: out of the box database distribution
webflux - webclient Connect reset by peer Error
Rectification characteristics of fast recovery diode
52岁的周鸿祎,还年轻吗?
Open display PDF file in web page
SQL uses the in keyword to query multiple fields
Cmake learning notes (1) compile single source programs with cmake
The difference between get and post
Traduction gratuite en un clic de plus de 300 pages de documents PDF
单机高并发模型设计
Robomaster visual tutorial (11) summary
35岁真就成了职业危机?不,我的技术在积累,我还越吃越香了
How did a fake offer steal $540million from "axie infinity"?
Usage of limit and offset (Reprint)
Basic learning of SQL Server -- creating databases and tables with the mouse
Use filters to count URL request time
ROS从入门到精通(九) 可视化仿真初体验之TurtleBot3
Go time package common functions