当前位置:网站首页>paddle入门-使用LeNet在MNIST实现图像分类方法二
paddle入门-使用LeNet在MNIST实现图像分类方法二
2022-07-07 22:06:00 【Vertira】
使用LeNet在MNIST实现图像分类方法二
与方法一相比,方法二 有点麻烦,适合进阶
五、方式2:基于基础API,完成模型的训练与预测
5.1 模型训练¶
组网后,开始对模型进行训练,先构建train_loader
,加载训练数据,然后定义train
函数,设置好损失函数后,按batch加载数据,完成模型的训练。
import paddle.nn.functional as F
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)
# 加载训练集 batch_size 设为 64
def train(model):
model.train()
epochs = 2
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 用Adam作为优化函数
for epoch in range(epochs):
for batch_id, data in enumerate(train_loader()):
x_data = data[0]
y_data = data[1]
predicts = model(x_data)
loss = F.cross_entropy(predicts, y_data)
# 计算损失
acc = paddle.metric.accuracy(predicts, y_data)
loss.backward()
if batch_id % 300 == 0:
print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
optim.step()
optim.clear_grad()
model = LeNet()
train(model)
运行结果
epoch: 0, batch_id: 0, loss is: [3.2611141], acc is: [0.078125]
epoch: 0, batch_id: 300, loss is: [0.24404016], acc is: [0.921875]
epoch: 0, batch_id: 600, loss is: [0.03953885], acc is: [1.]
epoch: 0, batch_id: 900, loss is: [0.03700985], acc is: [0.984375]
epoch: 1, batch_id: 0, loss is: [0.05806625], acc is: [0.96875]
epoch: 1, batch_id: 300, loss is: [0.06538856], acc is: [0.953125]
epoch: 1, batch_id: 600, loss is: [0.03884572], acc is: [0.984375]
epoch: 1, batch_id: 900, loss is: [0.01922364], acc is: [0.984375]
5.2 模型验证¶
训练完成后,需要验证模型的效果,此时,加载测试数据集,然后用训练好的模对测试集进行预测,计算损失与精度。
test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)
# 加载测试数据集
def test(model):
model.eval()
batch_size = 64
for batch_id, data in enumerate(test_loader()):
x_data = data[0]
y_data = data[1]
predicts = model(x_data)
# 获取预测结果
loss = F.cross_entropy(predicts, y_data)
acc = paddle.metric.accuracy(predicts, y_data)
if batch_id % 20 == 0:
print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id, loss.numpy(), acc.numpy()))
test(model)
运行结果
batch_id: 0, loss is: [0.01972857], acc is: [0.984375]
batch_id: 20, loss is: [0.19958115], acc is: [0.9375]
batch_id: 40, loss is: [0.23575728], acc is: [0.953125]
batch_id: 60, loss is: [0.07018849], acc is: [0.984375]
batch_id: 80, loss is: [0.02309197], acc is: [0.984375]
batch_id: 100, loss is: [0.00239462], acc is: [1.]
batch_id: 120, loss is: [0.01583934], acc is: [1.]
batch_id: 140, loss is: [0.00399609], acc is: [1.]
方式二结束¶
以上就是方式二,通过底层API,可以清楚的看到训练和测试中的每一步过程。但是,这种方式比较复杂。因此,我们提供了训练方式一,使用高层API来完成模型的训练与预测。对比底层API,高层API能够更加快速、高效的完成模型的训练与测试。
六、总结¶
以上就是用LeNet对手写数字数据及MNIST进行分类。本示例提供了两种训练模型的方式,一种可以快速完成模型的组建与预测,非常适合新手用户上手。另一种则需要多个步骤来完成模型的训练,适合进阶用户使用。
方法二的完整代码如下:
import os
import cv2
import numpy as np
from paddle.io import Dataset
import paddle.vision.transforms as T
import matplotlib.pyplot as plt
import paddle.nn.functional as F
import paddle
from paddle.metric import Accuracy
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],
std=[127.5],
data_format='CHW')])
# 使用transform对数据集做归一化
print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')
class LeNet(paddle.nn.Layer):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.max_pool2(x)
x = paddle.flatten(x, start_axis=1,stop_axis=-1)
x = self.linear1(x)
x = F.relu(x)
x = self.linear2(x)
x = F.relu(x)
x = self.linear3(x)
return x
#组网后,开始对模型进行训练,先构建train_loader,加载训练数据,然后定义train函数,设置好损失函数后,按batch加载数据,完成模型的训练。
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)
# 加载训练集 batch_size 设为 64
def train(model):
model.train()
epochs = 2
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 用Adam作为优化函数
for epoch in range(epochs):
for batch_id, data in enumerate(train_loader()):
x_data = data[0]
y_data = data[1]
predicts = model(x_data)
loss = F.cross_entropy(predicts, y_data)
# 计算损失
acc = paddle.metric.accuracy(predicts, y_data)
loss.backward()
if batch_id % 300 == 0:
print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
optim.step()
optim.clear_grad()
model = LeNet()
train(model)
边栏推荐
- Scrapy framework
- 面试题详解:用Redis实现分布式锁的血泪史
- 80% of the people answered incorrectly. Does the leaf on the apple logo face left or right?
- 自动化测试:Robot FrameWork框架90%的人都想知道的实用技巧
- Fully automated processing of monthly card shortage data and output of card shortage personnel information
- Go learning notes (2) basic types and statements (1)
- 【编程题】【Scratch二级】2019.03 绘制方形螺旋
- mysql8.0 ubuntu20.4
- PostGIS learning
- Basic learning of SQL Server -- creating databases and tables with the mouse
猜你喜欢
BSS 7230 flame retardant performance test of aviation interior materials
Chisel tutorial - 02 Chisel environment configuration and implementation and testing of the first chisel module
One click free translation of more than 300 pages of PDF documents
[path planning] use the vertical distance limit method and Bessel to optimize the path of a star
Archery installation test
Pycharm basic settings latest version 2022
95. (cesium chapter) cesium dynamic monomer-3d building (building)
Go learning notes (1) environment installation and hello world
Magic fast power
[programming questions] [scratch Level 2] March 2019 garbage classification
随机推荐
【编程题】【Scratch二级】2019.03 绘制方形螺旋
Robomaster visual tutorial (11) summary
Codeworks 5 questions per day (average 1500) - day 8
One click free translation of more than 300 pages of PDF documents
PostGIS learning
52岁的周鸿祎,还年轻吗?
Alibaba cloud MySQL cannot connect
Traduction gratuite en un clic de plus de 300 pages de documents PDF
Trust orbtk development issues 2022
数据库查询——第几高的数据?
How did a fake offer steal $540million from "axie infinity"?
Scrapy framework
Anaconda+pycharm+pyqt5 configuration problem: pyuic5 cannot be found exe
一鍵免費翻譯300多頁的pdf文檔
Automated testing: robot framework is a practical skill that 90% of people want to know
Go learning notes (1) environment installation and hello world
Orthodontic precautions (continuously updated)
Is it safe to buy funds online?
Uic564-2 Appendix 4 - flame retardant fire test: flame diffusion
Aitm3.0005 smoke toxicity test