Introduction to Paddle: Using LeNet for Image Classification on MNIST, Method II
2022-07-08 00:52:00 【Vertira】
Using LeNet on MNIST for image classification, method II
Compared with method 1, method 2 is a little more involved and is better suited to advanced users.
5. Method 2: complete model training and prediction with the basic API
5.1 Model training
After building the network, start training the model: first construct a train_loader to load the training data, then define a train function, set the loss function, and feed the data in batches to complete the training of the model.
import paddle.nn.functional as F

# Load the training set with batch_size set to 64
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)

def train(model):
    model.train()
    epochs = 2
    # Use Adam as the optimizer
    optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
    for epoch in range(epochs):
        for batch_id, data in enumerate(train_loader()):
            x_data = data[0]
            y_data = data[1]
            predicts = model(x_data)
            # Compute the loss
            loss = F.cross_entropy(predicts, y_data)
            acc = paddle.metric.accuracy(predicts, y_data)
            loss.backward()
            if batch_id % 300 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
            optim.step()
            optim.clear_grad()

model = LeNet()
train(model)
Running results
epoch: 0, batch_id: 0, loss is: [3.2611141], acc is: [0.078125]
epoch: 0, batch_id: 300, loss is: [0.24404016], acc is: [0.921875]
epoch: 0, batch_id: 600, loss is: [0.03953885], acc is: [1.]
epoch: 0, batch_id: 900, loss is: [0.03700985], acc is: [0.984375]
epoch: 1, batch_id: 0, loss is: [0.05806625], acc is: [0.96875]
epoch: 1, batch_id: 300, loss is: [0.06538856], acc is: [0.953125]
epoch: 1, batch_id: 600, loss is: [0.03884572], acc is: [0.984375]
epoch: 1, batch_id: 900, loss is: [0.01922364], acc is: [0.984375]
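The trained weights above only live in memory. If you want to keep them for later use, one common pattern (not shown in the original post; the file name is chosen arbitrarily) is to persist and restore the state_dict:

# Sketch: save the trained parameters, then restore them into a fresh network
paddle.save(model.state_dict(), 'lenet.pdparams')

loaded_model = LeNet()
loaded_model.set_state_dict(paddle.load('lenet.pdparams'))
loaded_model.eval()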
5.2 Model validation
After training, the model's performance needs to be verified. Here, load the test dataset, run the trained model on it, and compute the loss and accuracy.
# Load the test dataset
test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)

def test(model):
    model.eval()
    batch_size = 64
    for batch_id, data in enumerate(test_loader()):
        x_data = data[0]
        y_data = data[1]
        # Get the predictions
        predicts = model(x_data)
        loss = F.cross_entropy(predicts, y_data)
        acc = paddle.metric.accuracy(predicts, y_data)
        if batch_id % 20 == 0:
            print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id, loss.numpy(), acc.numpy()))

test(model)
Running results
batch_id: 0, loss is: [0.01972857], acc is: [0.984375]
batch_id: 20, loss is: [0.19958115], acc is: [0.9375]
batch_id: 40, loss is: [0.23575728], acc is: [0.953125]
batch_id: 60, loss is: [0.07018849], acc is: [0.984375]
batch_id: 80, loss is: [0.02309197], acc is: [0.984375]
batch_id: 100, loss is: [0.00239462], acc is: [1.]
batch_id: 120, loss is: [0.01583934], acc is: [1.]
batch_id: 140, loss is: [0.00399609], acc is: [1.]
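The per-batch accuracy printed above fluctuates from batch to batch. If you want a single number for the whole test set, one option (a sketch, not part of the original post; it reuses the model and test_loader defined above) is to weight each batch's accuracy by its batch size:

# Sketch: aggregate per-batch accuracy into one overall test accuracy
def overall_accuracy(model, loader):
    model.eval()
    total_correct, total_samples = 0.0, 0
    with paddle.no_grad():
        for x_data, y_data in loader():
            predicts = model(x_data)
            batch_acc = paddle.metric.accuracy(predicts, y_data).numpy().item()
            total_correct += batch_acc * x_data.shape[0]
            total_samples += x_data.shape[0]
    return total_correct / total_samples

print("overall test accuracy: {:.4f}".format(overall_accuracy(model, test_loader)))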
End of method 2
The above is method 2. Working through the basic (low-level) API makes every step of the training and testing process explicit, but it is also more verbose. Paddle therefore also provides a high-level API for training and prediction; compared with the low-level API, the high-level API completes model training and testing faster and more efficiently.
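For reference, the high-level API path (method 1 of this series) condenses the whole loop above into a few calls. The sketch below is illustrative and assumes the train_dataset, test_dataset, and LeNet defined earlier; the hyperparameters simply mirror the ones used above.

from paddle.metric import Accuracy

# Wrap the network in the high-level Model API
model = paddle.Model(LeNet())

# Configure the optimizer, loss, and metric
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy())

# Train, then evaluate on the test set
model.fit(train_dataset, epochs=2, batch_size=64, verbose=1)
model.evaluate(test_dataset, batch_size=64, verbose=1)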
6. Summary
The above shows how to use LeNet to classify the MNIST handwritten digit dataset. This example provides two ways to train the model: one completes model building and prediction quickly and is well suited to beginners; the other walks through the individual training steps and is suited to advanced users.
The complete code of method 2 is as follows:
import os
import cv2
import numpy as np
from paddle.io import Dataset
import paddle.vision.transforms as T
import matplotlib.pyplot as plt
import paddle.nn.functional as F
import paddle
from paddle.metric import Accuracy
from paddle.vision.transforms import Compose, Normalize

# Use transform to normalize the dataset
transform = Compose([Normalize(mean=[127.5],
                               std=[127.5],
                               data_format='CHW')])

print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')

class LeNet(paddle.nn.Layer):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
        self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
        self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
        self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
        self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
        self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.max_pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.max_pool2(x)
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        return x

# After building the network, train the model: construct train_loader to load the training data,
# define the train function, set the loss function, and train the model batch by batch.
# Load the training set with batch_size set to 64
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)

def train(model):
    model.train()
    epochs = 2
    # Use Adam as the optimizer
    optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
    for epoch in range(epochs):
        for batch_id, data in enumerate(train_loader()):
            x_data = data[0]
            y_data = data[1]
            predicts = model(x_data)
            # Compute the loss
            loss = F.cross_entropy(predicts, y_data)
            acc = paddle.metric.accuracy(predicts, y_data)
            loss.backward()
            if batch_id % 300 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
            optim.step()
            optim.clear_grad()

model = LeNet()
train(model)
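To actually run a prediction with the trained network (the listing above ends after training), the following minimal sketch takes one sample from test_dataset and classifies it; the sample index and variable names are illustrative, and it assumes the model trained by the code above.

# Inference sketch (not part of the original post): predict one test image
model.eval()
img, label = test_dataset[0]                                 # img: (1, 28, 28) after Normalize
x = paddle.to_tensor(img[np.newaxis, ...], dtype='float32')  # add batch dim -> (1, 1, 28, 28)
logits = model(x)
pred = paddle.argmax(logits, axis=1).numpy()[0]
print("predicted digit: {}, ground-truth label: {}".format(pred, int(label)))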