Introduction to Paddle: Image Classification on MNIST with LeNet (Method 1)
2022-07-08 00:17:00 【Vertira】
Image classification on the MNIST dataset
1. Configure the environment
import paddle
print(paddle.__version__)
To install and configure Paddle, search online or see my earlier blog posts; I'll skip that here.
2. Load the data
There are two ways to load the data: write a custom data loader (covered in an earlier post of mine), or use the datasets that Paddle provides officially. We use the second way here because it is more convenient.
The MNIST handwritten digit dataset (official page: http://yann.lecun.com/exdb/mnist) contains 60,000 training examples and 10,000 test examples. The digits have been size-normalized and centered in a fixed-size 28x28-pixel grayscale image. Raw pixel values range from 0 to 255; the Normalize transform below rescales them.
We use PaddlePaddle's built-in paddle.vision.datasets.MNIST to load the MNIST dataset.
from paddle.vision.transforms import Compose, Normalize

# Normalize each pixel as (x - 127.5) / 127.5, mapping [0, 255] to [-1, 1]
transform = Compose([Normalize(mean=[127.5],
                               std=[127.5],
                               data_format='CHW')])
print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')
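To make the transform concrete: Normalize subtracts the mean and divides by the standard deviation, pixel by pixel. The pure-Python sketch below reproduces that arithmetic for mean = std = 127.5, assuming raw 0 to 255 input. The helper names `normalize_pixel` and `normalize_image` are hypothetical and are not part of Paddle.

```python
# Hypothetical helpers (not Paddle API): the per-pixel arithmetic that
# Normalize(mean=[127.5], std=[127.5]) is expected to apply:
#     out = (x - mean) / std
MEAN = 127.5
STD = 127.5


def normalize_pixel(x, mean=MEAN, std=STD):
    """Map one raw grayscale value in [0, 255] to [-1, 1]."""
    return (x - mean) / std


def normalize_image(rows, mean=MEAN, std=STD):
    """Apply the same formula to every pixel of a 2-D list."""
    return [[normalize_pixel(v, mean, std) for v in row] for row in rows]


if __name__ == "__main__":
    print(normalize_pixel(0))    # -1.0 (background)
    print(normalize_pixel(255))  # 1.0 (full-intensity stroke)
    print(normalize_image([[0, 255], [127.5, 0]]))
```

Centering the inputs around zero this way is a common choice for training with Adam; dividing by 255 instead (range [0, 1]) would also work but is not what this tutorial does.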
Take one sample from the training set and have a look.
import numpy as np
import matplotlib.pyplot as plt

# Each sample is an (image, label) pair
train_data0, train_label_0 = train_dataset[0][0], train_dataset[0][1]
# Reshape to 28x28 (dropping the channel axis) so imshow gets a 2-D array
train_data0 = train_data0.reshape([28, 28])
plt.figure(figsize=(2, 2))
plt.imshow(train_data0, cmap=plt.cm.binary)
print('train_data0 label is: ' + str(train_label_0))
plt.show()
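The `reshape([28, 28])` call above removes the leading channel axis, assuming the normalized sample comes back in CHW layout as `[1, 28, 28]`. If matplotlib is unavailable, a text preview works too. This is a minimal sketch; `squeeze_channel` and `to_ascii` are hypothetical helpers, not Paddle functions.

```python
# Hypothetical helpers (not Paddle API) for a text-mode preview of a sample.


def squeeze_channel(chw):
    """Drop a leading channel axis of size 1: [1][H][W] -> [H][W]."""
    if len(chw) != 1:
        raise ValueError("expected exactly one channel")
    return chw[0]


def to_ascii(hw, threshold=0.0):
    """Render '#' where a normalized pixel exceeds threshold, '.' elsewhere."""
    return "\n".join(
        "".join("#" if v > threshold else "." for v in row) for row in hw
    )


if __name__ == "__main__":
    # A tiny 1x3x3 stand-in for a normalized [1, 28, 28] MNIST sample
    sample = [[[-1.0, 1.0, -1.0],
               [-1.0, 1.0, -1.0],
               [-1.0, 1.0, -1.0]]]
    print(to_ascii(squeeze_channel(sample)))
```

With the real data, something like `print(to_ascii(train_data0.tolist()))` after the reshape should print the digit as a 28x28 grid of characters.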
3. Build the network
Use the APIs under paddle.nn, such as Conv2D, MaxPool2D and Linear, to build LeNet.
import paddle
import paddle.nn.functional as F

class LeNet(paddle.nn.Layer):
    def __init__(self):
        super(LeNet, self).__init__()
        # Two convolution + max-pooling stages
        self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
        self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
        self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
        # Three fully connected layers; 16*5*5 is the flattened size after max_pool2
        self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
        self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
        self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        # x: [N, 1, 28, 28]
        x = self.conv1(x)          # -> [N, 6, 28, 28]
        x = F.relu(x)
        x = self.max_pool1(x)      # -> [N, 6, 14, 14]
        x = self.conv2(x)          # -> [N, 16, 10, 10]
        x = F.relu(x)
        x = self.max_pool2(x)      # -> [N, 16, 5, 5]
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)  # -> [N, 400]
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)        # -> [N, 10] raw logits (no softmax)
        return x
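Where does `in_features=16*5*5` come from? Each convolution or pooling layer changes the spatial size by the standard rule out = (in + 2*padding - kernel) // stride + 1. The sketch below traces a 28x28 input through the layers above and also counts weights plus biases (assuming every Conv2D and Linear keeps its default bias). The function names are illustrative only; `paddle.summary(LeNet(), (1, 1, 28, 28))` should report matching shapes and totals if you want to cross-check.

```python
# Illustrative sketch: trace feature-map sizes and parameter counts
# for the LeNet defined above.


def out_size(size, kernel, stride=1, padding=0):
    """Output height/width of a conv or pooling layer (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet_shapes(size=28):
    """Return (layer name, channels, height/width) after each stage."""
    stages = []
    size = out_size(size, kernel=5, stride=1, padding=2)  # conv1
    stages.append(("conv1", 6, size))
    size = out_size(size, kernel=2, stride=2)             # max_pool1
    stages.append(("max_pool1", 6, size))
    size = out_size(size, kernel=5, stride=1)             # conv2
    stages.append(("conv2", 16, size))
    size = out_size(size, kernel=2, stride=2)             # max_pool2
    stages.append(("max_pool2", 16, size))
    return stages


def count_params():
    """Weights + biases of each layer in the LeNet above."""
    conv1 = 6 * 1 * 5 * 5 + 6
    conv2 = 16 * 6 * 5 * 5 + 16
    linear1 = 16 * 5 * 5 * 120 + 120
    linear2 = 120 * 84 + 84
    linear3 = 84 * 10 + 10
    return conv1 + conv2 + linear1 + linear2 + linear3


if __name__ == "__main__":
    for name, channels, size in lenet_shapes():
        print(f"{name}: [N, {channels}, {size}, {size}]")
    _, c, s = lenet_shapes()[-1]
    print("flatten ->", c * s * s)          # 400, the in_features of linear1
    print("total params:", count_params())  # 61706
```

The padding=2 on conv1 is what keeps the first feature map at 28x28, so that two poolings and one unpadded 5x5 convolution end at exactly 5x5.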
4. Method 1: Train and evaluate the model with the high-level API
Build an instance with paddle.Model, then use its encapsulated training and testing interfaces to train and test the model quickly.
4.1 Train the model with Model.fit
from paddle.metric import Accuracy

model = paddle.Model(LeNet())  # Wrap the network with paddle.Model
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

# Configure the optimizer, loss function and evaluation metric
model.prepare(
    optim,
    paddle.nn.CrossEntropyLoss(),
    Accuracy()
)

# Train the model
model.fit(train_dataset,
          epochs=2,
          batch_size=64,
          verbose=1)
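Note that `LeNet.forward` returns raw logits with no softmax. `paddle.nn.CrossEntropyLoss` with its default settings (`use_softmax=True`, `reduction='mean'`) applies softmax internally and averages over the batch. The pure-Python sketch below mirrors that computation for hard integer labels; it illustrates the math and is not Paddle's implementation.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def cross_entropy(batch_logits, labels):
    """Mean negative log-likelihood of the true class over a batch."""
    losses = [-log_softmax(row)[y] for row, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


if __name__ == "__main__":
    uniform = [[0.0] * 10]                # no preference among 10 classes
    print(cross_entropy(uniform, [3]))    # about 2.3026, i.e. ln(10)
    confident = [[0.0] * 10]
    confident[0][3] = 10.0                # strongly predicts the true class
    print(cross_entropy(confident, [3]))  # close to 0
```

This is why an untrained 10-class classifier typically starts with a loss around 2.3, and why the per-step losses in the log below are far smaller once the network has learned.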
Training results
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/2
step 938/938 [==============================] - loss: 0.0329 - acc: 0.9399 - 10ms/step
Epoch 2/2
step 938/938 [==============================] - loss: 0.0092 - acc: 0.9798 - 10ms/step
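The step count and the two columns of this log can be reproduced by hand. With 60,000 training images and batch_size=64, an epoch has ceil(60000 / 64) = 938 steps, since the last partial batch still counts (assuming `fit`'s default `drop_last=False`). As the log header says, `loss` is the current step's value while `acc` is averaged over the steps so far. The sketch below models that accumulation; `RunningAccuracy` is a hypothetical class, not `paddle.metric.Accuracy` itself.

```python
import math


def steps_per_epoch(num_samples, batch_size):
    """Number of batches when the last partial batch is kept."""
    return math.ceil(num_samples / batch_size)


class RunningAccuracy:
    """Hypothetical accumulator: total correct / total seen across steps."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, correct, total):
        self.correct += correct
        self.total += total
        return self.correct / self.total


if __name__ == "__main__":
    print(steps_per_epoch(60000, 64))  # 938, matching "step 938/938"
    print(steps_per_epoch(10000, 64))  # 157, matching the eval log
    acc = RunningAccuracy()
    print(acc.update(50, 64))  # 0.78125 after the first batch
    print(acc.update(64, 64))  # 0.890625 averaged over both batches
```

Because `acc` is a running average that includes early, less accurate batches, the epoch-1 value (0.9399) understates how well the model performs by the end of that epoch.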
4.2 Evaluate the model with Model.evaluate
model.evaluate(test_dataset, batch_size=64, verbose=1)
Eval begin...
step 157/157 [==============================] - loss: 4.4728e-04 - acc: 0.9857 - 8ms/step
Eval samples: 10000
{'loss': [0.0004472804], 'acc': 0.9857}
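The reported `acc` is top-1 accuracy: a prediction counts as correct when the class with the highest logit equals the label. On 10,000 test images, acc = 0.9857 means 9,857 correct predictions. A minimal pure-Python sketch of that rule follows; `top1_correct` and `accuracy` are hypothetical helpers, not Paddle's API.

```python
def top1_correct(batch_logits, labels):
    """Count samples whose highest-scoring class equals the label."""
    correct = 0
    for row, y in zip(batch_logits, labels):
        pred = max(range(len(row)), key=row.__getitem__)  # argmax
        correct += int(pred == y)
    return correct


def accuracy(batch_logits, labels):
    """Fraction of samples predicted correctly."""
    return top1_correct(batch_logits, labels) / len(labels)


if __name__ == "__main__":
    logits = [[0.1, 2.0, -1.0],  # predicts class 1
              [3.0, 0.5, 0.2],   # predicts class 0
              [0.0, 0.1, 0.9]]   # predicts class 2
    labels = [1, 2, 2]
    print(top1_correct(logits, labels))  # 2 of 3 correct
    print(accuracy(logits, labels))
```

No softmax is needed here: softmax is monotonic, so the argmax of the logits equals the argmax of the probabilities.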
End of Method 1
That's Method 1: with the high-level API, a network model can be trained and evaluated quickly and efficiently.