Introduction to Paddle: image classification on MNIST with LeNet, method 2
2022-07-08 00:52:00 【Vertira】
Using LeNet for image classification on MNIST: method 2
Compared with method 1, method 2 is somewhat more involved and is better suited to advanced users.
5. Method 2: train and predict with the basic API
5.1 Model training
Once the network is defined, the model can be trained. First build train_loader to load the training data, then define the train function, which sets the loss function, loads the data batch by batch, and runs the training loop.
import paddle
import paddle.nn.functional as F

# Load the training set with a batch size of 64
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)

def train(model):
    model.train()
    epochs = 2
    # Use Adam as the optimizer
    optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
    for epoch in range(epochs):
        for batch_id, data in enumerate(train_loader()):
            x_data = data[0]
            y_data = data[1]
            predicts = model(x_data)
            # Compute the loss and the batch accuracy
            loss = F.cross_entropy(predicts, y_data)
            acc = paddle.metric.accuracy(predicts, y_data)
            loss.backward()
            if batch_id % 300 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
            # Update the parameters, then clear the gradients for the next batch
            optim.step()
            optim.clear_grad()

model = LeNet()
train(model)
Running results
epoch: 0, batch_id: 0, loss is: [3.2611141], acc is: [0.078125]
epoch: 0, batch_id: 300, loss is: [0.24404016], acc is: [0.921875]
epoch: 0, batch_id: 600, loss is: [0.03953885], acc is: [1.]
epoch: 0, batch_id: 900, loss is: [0.03700985], acc is: [0.984375]
epoch: 1, batch_id: 0, loss is: [0.05806625], acc is: [0.96875]
epoch: 1, batch_id: 300, loss is: [0.06538856], acc is: [0.953125]
epoch: 1, batch_id: 600, loss is: [0.03884572], acc is: [0.984375]
epoch: 1, batch_id: 900, loss is: [0.01922364], acc is: [0.984375]
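To make the printed loss and acc values concrete, the sketch below recomputes both quantities in plain Python for a tiny, made-up batch. It assumes the default behavior of F.cross_entropy (softmax over the raw model outputs, followed by the mean negative log-likelihood) and top-1 accuracy. The helpers softmax, cross_entropy, and accuracy, as well as the example logits, are illustrative only and are not Paddle functions.

```python
import math

def softmax(logits):
    """Turn raw scores into probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(batch_logits, labels):
    """Mean negative log-probability assigned to the correct class."""
    losses = [-math.log(softmax(logits)[label])
              for logits, label in zip(batch_logits, labels)]
    return sum(losses) / len(losses)

def accuracy(batch_logits, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    correct = sum(1 for logits, label in zip(batch_logits, labels)
                  if logits.index(max(logits)) == label)
    return correct / len(labels)

# Hypothetical batch: 3 samples, 3 classes
batch_logits = [[2.0, 0.5, 0.1], [0.2, 0.1, 3.0], [1.0, 1.5, 0.3]]
labels = [0, 2, 0]

print(cross_entropy(batch_logits, labels))
print(accuracy(batch_logits, labels))  # 2 of 3 samples are classified correctly
```

With a batch of 64 samples, the accuracy of 0.078125 in the first log line corresponds to 5 correct predictions (5 / 64 = 0.078125), which is about what an untrained 10-class model is expected to reach.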
5.2 Model validation
After training, the model needs to be evaluated. Load the test dataset, run the trained model on it, and compute the loss and accuracy.
# Load the test dataset
test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)

def test(model):
    model.eval()
    batch_size = 64
    for batch_id, data in enumerate(test_loader()):
        x_data = data[0]
        y_data = data[1]
        # Get the predictions
        predicts = model(x_data)
        loss = F.cross_entropy(predicts, y_data)
        acc = paddle.metric.accuracy(predicts, y_data)
        if batch_id % 20 == 0:
            print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id, loss.numpy(), acc.numpy()))

test(model)
Running results
batch_id: 0, loss is: [0.01972857], acc is: [0.984375]
batch_id: 20, loss is: [0.19958115], acc is: [0.9375]
batch_id: 40, loss is: [0.23575728], acc is: [0.953125]
batch_id: 60, loss is: [0.07018849], acc is: [0.984375]
batch_id: 80, loss is: [0.02309197], acc is: [0.984375]
batch_id: 100, loss is: [0.00239462], acc is: [1.]
batch_id: 120, loss is: [0.01583934], acc is: [1.]
batch_id: 140, loss is: [0.00399609], acc is: [1.]
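The test function reports metrics for individual batches only. One possible way to summarize the whole test set, sketched here in plain Python, is a sample-weighted running mean. Weighting matters because the MNIST test set has 10,000 images, so with a batch size of 64 the final batch holds only 16 samples (assuming the loader keeps the last partial batch). RunningMean and batch_sizes are hypothetical helpers written for this illustration, not Paddle APIs.

```python
def batch_sizes(num_samples, batch_size):
    """Sizes of the batches produced when the last partial batch is kept."""
    full, rest = divmod(num_samples, batch_size)
    return [batch_size] * full + ([rest] if rest else [])

class RunningMean:
    """Accumulates a mean in which every sample counts equally."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value, n):
        # value is a per-batch mean, n is the number of samples in that batch
        self.total += value * n
        self.count += n

    @property
    def value(self):
        return self.total / self.count if self.count else 0.0

sizes = batch_sizes(10000, 64)
print(len(sizes), sizes[-1])  # 157 16

# Two made-up batches: a full one and a short final one
acc_mean = RunningMean()
for batch_acc, n in [(1.0, 64), (0.5, 16)]:
    acc_mean.update(batch_acc, n)
print(acc_mean.value)  # 0.9, whereas an unweighted mean would give 0.75
```

Inside test, update would be called once per batch with the batch loss or accuracy converted to a Python float and n set to the number of samples in that batch.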
End of method 2
That concludes method 2. With the low-level API, every step of the training and testing process is clearly visible. However, this approach is more verbose. For that reason method 1 is also provided, which uses the high-level API to train the model and make predictions. Compared with the low-level API, the high-level API completes training and testing more quickly and efficiently.
6. Summary
This example used LeNet to classify the MNIST handwritten-digit dataset. It provides two ways to train the model: one builds the model and makes predictions quickly and is well suited to newcomers; the other completes training in several explicit steps and is suited to advanced users.
The complete code of method 2 is as follows:
import paddle
import paddle.nn.functional as F
from paddle.vision.transforms import Compose, Normalize

# Use transform to normalize the dataset
transform = Compose([Normalize(mean=[127.5],
                               std=[127.5],
                               data_format='CHW')])

print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')
class LeNet(paddle.nn.Layer):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
        self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
        self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
        self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
        self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
        self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.max_pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.max_pool2(x)
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        return x
# With the network defined, build train_loader to load the training data,
# then define the train function, which sets the loss function and trains batch by batch.
# Load the training set with a batch size of 64
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)

def train(model):
    model.train()
    epochs = 2
    # Use Adam as the optimizer
    optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
    for epoch in range(epochs):
        for batch_id, data in enumerate(train_loader()):
            x_data = data[0]
            y_data = data[1]
            predicts = model(x_data)
            # Compute the loss and the batch accuracy
            loss = F.cross_entropy(predicts, y_data)
            acc = paddle.metric.accuracy(predicts, y_data)
            loss.backward()
            if batch_id % 300 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
            # Update the parameters, then clear the gradients for the next batch
            optim.step()
            optim.clear_grad()

model = LeNet()
train(model)
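The in_features=16*5*5 argument of linear1 follows from how each layer changes the spatial size of a 28x28 MNIST image. The sketch below checks that arithmetic with the standard output-size formula floor((n + 2 * padding - kernel) / stride) + 1, using the kernel, stride, and padding values from the LeNet class above. out_size is a hypothetical helper written for this illustration.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer along one axis."""
    return (n + 2 * padding - kernel) // stride + 1

size = 28                                             # MNIST images are 28x28
size = out_size(size, kernel=5, stride=1, padding=2)  # conv1     -> 28
size = out_size(size, kernel=2, stride=2)             # max_pool1 -> 14
size = out_size(size, kernel=5, stride=1)             # conv2     -> 10
size = out_size(size, kernel=2, stride=2)             # max_pool2 -> 5

# conv2 produces 16 channels, so flattening yields 16 * 5 * 5 features
flat_features = 16 * size * size
print(size, flat_features)  # 5 400
```

If the input resolution or any kernel size changes, the same calculation gives the new value that linear1 needs.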