Getting started with Paddle: image classification on MNIST with LeNet (Method 1)
2022-07-08 00:17:00 【Vertira】
Image classification on the MNIST dataset
1. Configure the environment
import paddle
print(paddle.__version__)
For how to install and configure Paddle, you can search online; my earlier blog posts also cover it, so I'll keep it brief here.
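If Paddle is not installed yet, the standard CPU build can usually be installed with pip (a minimal sketch; for GPU builds or a specific version, follow the official installation guide):

pip install paddlepaddle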
2. Load the data
There are two ways to load data: custom data loading (covered in my previous blog), or loading the datasets that Paddle provides on its official site. We use the second way here, because it is convenient.
The handwritten-digit MNIST dataset contains 60,000 examples for training and 10,000 examples for testing. The digits have been size-normalized and centered in fixed-size images (28x28 pixels), with grayscale values from 0 to 255. The official site for the dataset is http://yann.lecun.com/exdb/mnist.
We use the PaddlePaddle framework's built-in paddle.vision.datasets.MNIST to load the MNIST dataset, normalizing each image with Normalize(mean=[127.5], std=[127.5]), which maps pixel values from [0, 255] to [-1, 1].
from paddle.vision.transforms import Compose, Normalize

transform = Compose([Normalize(mean=[127.5],
                               std=[127.5],
                               data_format='CHW')])
# use the transform to normalize the dataset
print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')
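As a quick sanity check (len() works on these dataset objects), we can confirm the sizes stated above:

# should print 60000 and 10000
print(len(train_dataset))
print(len(test_dataset))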
Let's take one piece of data from the training set and have a look.
import numpy as np
import matplotlib.pyplot as plt

train_data0, train_label_0 = train_dataset[0][0], train_dataset[0][1]
train_data0 = train_data0.reshape([28, 28])
plt.figure(figsize=(2, 2))
plt.imshow(train_data0, cmap=plt.cm.binary)
plt.show()  # needed to display the figure when running as a plain script
print('train_data0 label is: ' + str(train_label_0))
3. Build the network
Use the APIs under paddle.nn, such as Conv2D, MaxPool2D, and Linear, to build LeNet.
import paddle
import paddle.nn.functional as F

class LeNet(paddle.nn.Layer):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
        self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
        self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
        self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
        self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
        self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        # conv -> relu -> pool, twice
        x = self.conv1(x)
        x = F.relu(x)
        x = self.max_pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.max_pool2(x)
        # flatten the 16x5x5 feature map before the fully connected layers
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        return x
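To double-check the layer shapes, you can print a summary of the network with paddle.summary (a quick sketch; available in Paddle 2.x):

# print per-layer output shapes and parameter counts for a single 1x28x28 input
paddle.summary(LeNet(), (1, 1, 28, 28))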
4. Method 1: complete model training and prediction with the high-level API
Wrap the network in a paddle.Model instance, then use its encapsulated training and testing interfaces to quickly complete model training and testing.
4.1 Use Model.fit to train the model
from paddle.metric import Accuracy

model = paddle.Model(LeNet())   # wrap the network with the high-level Model API
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

# configure the model with an optimizer, a loss function, and a metric
model.prepare(
    optim,
    paddle.nn.CrossEntropyLoss(),
    Accuracy()
)

# train the model
model.fit(train_dataset,
          epochs=2,
          batch_size=64,
          verbose=1)
Training output:
The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/2
step 938/938 [==============================] - loss: 0.0329 - acc: 0.9399 - 10ms/step
Epoch 2/2
step 938/938 [==============================] - loss: 0.0092 - acc: 0.9798 - 10ms/step
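If you want to keep the trained weights, the high-level API can save them directly (an optional sketch; the path 'output/mnist_lenet' is just an example name):

# save parameters and optimizer state for later evaluation or fine-tuning
model.save('output/mnist_lenet')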
4.2 Use Model.evaluate to evaluate the model
model.evaluate(test_dataset, batch_size=64, verbose=1)
Eval begin...
step 157/157 [==============================] - loss: 4.4728e-04 - acc: 0.9857 - 8ms/step
Eval samples: 10000
{'loss': [0.0004472804], 'acc': 0.9857}
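Beyond aggregate metrics, you can also run inference and inspect individual predictions (a brief sketch; Model.predict returns, for each model output, a list of batched numpy arrays, so the indexing below assumes that nesting):

import numpy as np

# run inference over the whole test set
results = model.predict(test_dataset, batch_size=64)
# results[0] is the list of batched logits; take the first sample of the first batch
pred_label = np.argmax(results[0][0][0])
print('predicted label for the first test sample:', pred_label)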
That concludes Method 1: it lets you complete network model training and prediction quickly and efficiently.