当前位置:网站首页>paddle入门-使用LeNet在MNIST实现图像分类方法一
paddle入门-使用LeNet在MNIST实现图像分类方法一
2022-07-07 22:06:00 【Vertira】
MNIST数据集实现图像分类
一、配置环境
import paddle
print(paddle.__version__)
如何配置paddle 可以网上搜 ,我的博客也有 ,这里略
加载数据:方式有两种:自定义数据加载(我之前的博客有),加载paddled官网做好的数据
我们寻找第二种方式,因为方便
手写数字的MNIST数据集,包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心,图像是固定大小(28x28像素),其值为0到1。该数据集的官方地址为:http://yann.lecun.com/exdb/mnist 。
我们使用飞桨框架自带的 paddle.vision.datasets.MNIST
完成mnist数据集的加载。
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],
std=[127.5],
data_format='CHW')])
# 使用transform对数据集做归一化
print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')
取训练集中的一条数据看一下。
import numpy as np
import matplotlib.pyplot as plt
train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 label is: ' + str(train_label_0))
三 、组网
用paddle.nn下的API,如Conv2D
、MaxPool2D
、Linear
完成LeNet的构建。
import paddle
import paddle.nn.functional as F
class LeNet(paddle.nn.Layer):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.max_pool2(x)
x = paddle.flatten(x, start_axis=1,stop_axis=-1)
x = self.linear1(x)
x = F.relu(x)
x = self.linear2(x)
x = F.relu(x)
x = self.linear3(x)
return x
四、方式1:基于高层API,完成模型的训练与预测
通过paddle提供的Model
构建实例,使用封装好的训练与测试接口,快速完成模型训练与测试。
4.1 使用 Model.fit
来训练模型
from paddle.metric import Accuracy
model = paddle.Model(LeNet()) # 用Model封装模型
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 配置模型
model.prepare(
optim,
paddle.nn.CrossEntropyLoss(),
Accuracy()
)
# 训练模型
model.fit(train_dataset,
epochs=2,
batch_size=64,
verbose=1
)
训练结果
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/2
step 938/938 [==============================] - loss: 0.0329 - acc: 0.9399 - 10ms/step
Epoch 2/2
step 938/938 [==============================] - loss: 0.0092 - acc: 0.9798 - 10ms/step
4.2 使用 Model.evaluate
来预测模型
model.evaluate(test_dataset, batch_size=64, verbose=1)
Eval begin...
step 157/157 [==============================] - loss: 4.4728e-04 - acc: 0.9857 - 8ms/step
Eval samples: 10000
{'loss': [0.0004472804], 'acc': 0.9857}
方式一结束
以上就是方式一,可以快速、高效的完成网络模型训练与预测。
参考:
边栏推荐
- LinkedBlockingQueue源码分析-新增和删除
- SQL connection problem after downloading (2)
- Redis caching tool class, worth owning~
- 手写一个模拟的ReentrantLock
- Set up personal network disk with nextcloud
- Problems faced when connecting to sqlserver after downloading (I)
- Enterprise application demand-oriented development of human resources department, employee attendance records and paid wages business process cases
- Pigsty: out of the box database distribution
- 【编程题】【Scratch二级】2019.09 制作蝙蝠冲关游戏
- Data Lake (XV): spark and iceberg integrate write operations
猜你喜欢
Solutions to problems in sqlserver deleting data in tables
Basic learning of SQL Server -- creating databases and tables with code
How does starfish OS enable the value of SFO in the fourth phase of SFO destruction?
Traduction gratuite en un clic de plus de 300 pages de documents PDF
[programming problem] [scratch Level 2] March 2019 draw a square spiral
Stm32f1 and stm32cubeide programming example - rotary encoder drive
Binary sort tree [BST] - create, find, delete, output
一个测试工程师的7年感悟 ---- 致在一路独行的你(别放弃)
QT creator add JSON based Wizard
QT and OpenGL: load 3D models using the open asset import library (assimp)
随机推荐
【史上最详细】信贷中逾期天数统计说明
商品的设计等整个生命周期,都可以将其纳入到产业互联网的范畴内
Robomaster visual tutorial (0) Introduction
redis你到底懂不懂之list
某马旅游网站开发(登录注册退出功能的实现)
Stm32f1 and stm32cubeide programming example - rotary encoder drive
Daily question brushing record (16)
Usage of limit and offset (Reprint)
Anaconda+pycharm+pyqt5 configuration problem: pyuic5 cannot be found exe
Emotional post station 010: things that contemporary college students should understand
Uic564-2 Appendix 4 - flame retardant fire test: flame diffusion
Kubectl 好用的命令行工具:oh-my-zsh 技巧和窍门
SQL uses the in keyword to query multiple fields
Two small problems in creating user registration interface
Postgres timestamp to human eye time string or millisecond value
【编程题】【Scratch二级】2019.12 飞翔的小鸟
串联二极管,提高耐压
Chisel tutorial - 04 Control flow in chisel
Is it safe for tongdaxin to buy funds?
【编程题】【Scratch二级】2019.03 绘制方形螺旋