当前位置:网站首页>paddle入门-使用LeNet在MNIST实现图像分类方法二
paddle入门-使用LeNet在MNIST实现图像分类方法二
2022-07-07 22:06:00 【Vertira】
使用LeNet在MNIST实现图像分类方法二
与方法一相比,方法二 有点麻烦,适合进阶
五、方式2:基于基础API,完成模型的训练与预测
5.1 模型训练¶
组网后,开始对模型进行训练,先构建train_loader,加载训练数据,然后定义train函数,设置好损失函数后,按batch加载数据,完成模型的训练。
import paddle.nn.functional as F
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)
# 加载训练集 batch_size 设为 64
def train(model):
model.train()
epochs = 2
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 用Adam作为优化函数
for epoch in range(epochs):
for batch_id, data in enumerate(train_loader()):
x_data = data[0]
y_data = data[1]
predicts = model(x_data)
loss = F.cross_entropy(predicts, y_data)
# 计算损失
acc = paddle.metric.accuracy(predicts, y_data)
loss.backward()
if batch_id % 300 == 0:
print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
optim.step()
optim.clear_grad()
model = LeNet()
train(model)
运行结果
epoch: 0, batch_id: 0, loss is: [3.2611141], acc is: [0.078125]
epoch: 0, batch_id: 300, loss is: [0.24404016], acc is: [0.921875]
epoch: 0, batch_id: 600, loss is: [0.03953885], acc is: [1.]
epoch: 0, batch_id: 900, loss is: [0.03700985], acc is: [0.984375]
epoch: 1, batch_id: 0, loss is: [0.05806625], acc is: [0.96875]
epoch: 1, batch_id: 300, loss is: [0.06538856], acc is: [0.953125]
epoch: 1, batch_id: 600, loss is: [0.03884572], acc is: [0.984375]
epoch: 1, batch_id: 900, loss is: [0.01922364], acc is: [0.984375]
5.2 模型验证¶
训练完成后,需要验证模型的效果,此时,加载测试数据集,然后用训练好的模对测试集进行预测,计算损失与精度。
test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)
# 加载测试数据集
def test(model):
model.eval()
batch_size = 64
for batch_id, data in enumerate(test_loader()):
x_data = data[0]
y_data = data[1]
predicts = model(x_data)
# 获取预测结果
loss = F.cross_entropy(predicts, y_data)
acc = paddle.metric.accuracy(predicts, y_data)
if batch_id % 20 == 0:
print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id, loss.numpy(), acc.numpy()))
test(model)
运行结果
batch_id: 0, loss is: [0.01972857], acc is: [0.984375]
batch_id: 20, loss is: [0.19958115], acc is: [0.9375]
batch_id: 40, loss is: [0.23575728], acc is: [0.953125]
batch_id: 60, loss is: [0.07018849], acc is: [0.984375]
batch_id: 80, loss is: [0.02309197], acc is: [0.984375]
batch_id: 100, loss is: [0.00239462], acc is: [1.]
batch_id: 120, loss is: [0.01583934], acc is: [1.]
batch_id: 140, loss is: [0.00399609], acc is: [1.]
方式二结束¶
以上就是方式二,通过底层API,可以清楚的看到训练和测试中的每一步过程。但是,这种方式比较复杂。因此,我们提供了训练方式一,使用高层API来完成模型的训练与预测。对比底层API,高层API能够更加快速、高效的完成模型的训练与测试。
六、总结¶
以上就是用LeNet对手写数字数据及MNIST进行分类。本示例提供了两种训练模型的方式,一种可以快速完成模型的组建与预测,非常适合新手用户上手。另一种则需要多个步骤来完成模型的训练,适合进阶用户使用。
方法二的完整代码如下:
import os
import cv2
import numpy as np
from paddle.io import Dataset
import paddle.vision.transforms as T
import matplotlib.pyplot as plt
import paddle.nn.functional as F
import paddle
from paddle.metric import Accuracy
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],
std=[127.5],
data_format='CHW')])
# 使用transform对数据集做归一化
print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')
class LeNet(paddle.nn.Layer):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.max_pool2(x)
x = paddle.flatten(x, start_axis=1,stop_axis=-1)
x = self.linear1(x)
x = F.relu(x)
x = self.linear2(x)
x = F.relu(x)
x = self.linear3(x)
return x
#组网后,开始对模型进行训练,先构建train_loader,加载训练数据,然后定义train函数,设置好损失函数后,按batch加载数据,完成模型的训练。
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)
# 加载训练集 batch_size 设为 64
def train(model):
model.train()
epochs = 2
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 用Adam作为优化函数
for epoch in range(epochs):
for batch_id, data in enumerate(train_loader()):
x_data = data[0]
y_data = data[1]
predicts = model(x_data)
loss = F.cross_entropy(predicts, y_data)
# 计算损失
acc = paddle.metric.accuracy(predicts, y_data)
loss.backward()
if batch_id % 300 == 0:
print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
optim.step()
optim.clear_grad()
model = LeNet()
train(model)
边栏推荐
- Is it safe for tongdaxin to buy funds?
- 光流传感器初步测试:GL9306
- limit 与offset的用法(转载)
- Kubectl 好用的命令行工具:oh-my-zsh 技巧和窍门
- LinkedBlockingQueue源码分析-新增和删除
- 每日刷题记录 (十六)
- Basic learning of SQL Server -- creating databases and tables with the mouse
- Robomaster visual tutorial (0) Introduction
- Chisel tutorial - 00 Ex.scala metals plug-in (vs Code), SBT and coursier exchange endogenous
- Solutions to problems in sqlserver deleting data in tables
猜你喜欢

Connect diodes in series to improve voltage withstand

Automated testing: robot framework is a practical skill that 90% of people want to know

某马旅游网站开发(对servlet的优化)

Data Lake (XV): spark and iceberg integrate write operations

CoinDesk评波场去中心化进程:让人们看到互联网的未来

35岁真就成了职业危机?不,我的技术在积累,我还越吃越香了

ROS从入门到精通(九) 可视化仿真初体验之TurtleBot3

How to measure whether the product is "just needed, high frequency, pain points"

用語雀寫文章了,功能真心强大!

Stm32f1 and stm32cubeide programming example - rotary encoder drive
随机推荐
【leetcode】day1
2022.7.7-----leetcode.648
Chisel tutorial - 05 Sequential logic in chisel (including explicit multi clock, explicit synchronous reset and explicit asynchronous reset)
Enterprise application demand-oriented development of human resources department, employee attendance records and paid wages business process cases
全自动化处理每月缺卡数据,输出缺卡人员信息
Chisel tutorial - 02 Chisel environment configuration and implementation and testing of the first chisel module
Is 35 really a career crisis? No, my skills are accumulating, and the more I eat, the better
【编程题】【Scratch二级】2019.09 绘制雪花图案
Daily question brushing record (16)
一鍵免費翻譯300多頁的pdf文檔
Is it safe to buy funds online?
[basis of recommendation system] sampling and construction of positive and negative samples
[question de programmation] [scratch niveau 2] oiseaux volants en décembre 2019
Les mots ont été écrits, la fonction est vraiment puissante!
Cmake learning notes (1) compile single source programs with cmake
Fully automated processing of monthly card shortage data and output of card shortage personnel information
QT and OpenGL: load 3D models using the open asset import library (assimp)
How does starfish OS enable the value of SFO in the fourth phase of SFO destruction?
ROS from entry to mastery (IX) initial experience of visual simulation: turtlebot3
Codeworks 5 questions per day (average 1500) - day 8