当前位置:网站首页>paddle入门-使用LeNet在MNIST实现图像分类方法二
paddle入门-使用LeNet在MNIST实现图像分类方法二
2022-07-07 22:06:00 【Vertira】
使用LeNet在MNIST实现图像分类方法二
与方法一相比,方法二 有点麻烦,适合进阶
五、方式2:基于基础API,完成模型的训练与预测
5.1 模型训练¶
组网后,开始对模型进行训练,先构建train_loader
,加载训练数据,然后定义train
函数,设置好损失函数后,按batch加载数据,完成模型的训练。
import paddle.nn.functional as F
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)
# 加载训练集 batch_size 设为 64
def train(model):
model.train()
epochs = 2
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 用Adam作为优化函数
for epoch in range(epochs):
for batch_id, data in enumerate(train_loader()):
x_data = data[0]
y_data = data[1]
predicts = model(x_data)
loss = F.cross_entropy(predicts, y_data)
# 计算损失
acc = paddle.metric.accuracy(predicts, y_data)
loss.backward()
if batch_id % 300 == 0:
print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
optim.step()
optim.clear_grad()
model = LeNet()
train(model)
运行结果
epoch: 0, batch_id: 0, loss is: [3.2611141], acc is: [0.078125]
epoch: 0, batch_id: 300, loss is: [0.24404016], acc is: [0.921875]
epoch: 0, batch_id: 600, loss is: [0.03953885], acc is: [1.]
epoch: 0, batch_id: 900, loss is: [0.03700985], acc is: [0.984375]
epoch: 1, batch_id: 0, loss is: [0.05806625], acc is: [0.96875]
epoch: 1, batch_id: 300, loss is: [0.06538856], acc is: [0.953125]
epoch: 1, batch_id: 600, loss is: [0.03884572], acc is: [0.984375]
epoch: 1, batch_id: 900, loss is: [0.01922364], acc is: [0.984375]
5.2 模型验证¶
训练完成后,需要验证模型的效果,此时,加载测试数据集,然后用训练好的模对测试集进行预测,计算损失与精度。
test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)
# 加载测试数据集
def test(model):
model.eval()
batch_size = 64
for batch_id, data in enumerate(test_loader()):
x_data = data[0]
y_data = data[1]
predicts = model(x_data)
# 获取预测结果
loss = F.cross_entropy(predicts, y_data)
acc = paddle.metric.accuracy(predicts, y_data)
if batch_id % 20 == 0:
print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id, loss.numpy(), acc.numpy()))
test(model)
运行结果
batch_id: 0, loss is: [0.01972857], acc is: [0.984375]
batch_id: 20, loss is: [0.19958115], acc is: [0.9375]
batch_id: 40, loss is: [0.23575728], acc is: [0.953125]
batch_id: 60, loss is: [0.07018849], acc is: [0.984375]
batch_id: 80, loss is: [0.02309197], acc is: [0.984375]
batch_id: 100, loss is: [0.00239462], acc is: [1.]
batch_id: 120, loss is: [0.01583934], acc is: [1.]
batch_id: 140, loss is: [0.00399609], acc is: [1.]
方式二结束¶
以上就是方式二,通过底层API,可以清楚的看到训练和测试中的每一步过程。但是,这种方式比较复杂。因此,我们提供了训练方式一,使用高层API来完成模型的训练与预测。对比底层API,高层API能够更加快速、高效的完成模型的训练与测试。
六、总结¶
以上就是用LeNet对手写数字数据及MNIST进行分类。本示例提供了两种训练模型的方式,一种可以快速完成模型的组建与预测,非常适合新手用户上手。另一种则需要多个步骤来完成模型的训练,适合进阶用户使用。
方法二的完整代码如下:
import os
import cv2
import numpy as np
from paddle.io import Dataset
import paddle.vision.transforms as T
import matplotlib.pyplot as plt
import paddle.nn.functional as F
import paddle
from paddle.metric import Accuracy
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],
std=[127.5],
data_format='CHW')])
# 使用transform对数据集做归一化
print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')
class LeNet(paddle.nn.Layer):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.max_pool2(x)
x = paddle.flatten(x, start_axis=1,stop_axis=-1)
x = self.linear1(x)
x = F.relu(x)
x = self.linear2(x)
x = F.relu(x)
x = self.linear3(x)
return x
#组网后,开始对模型进行训练,先构建train_loader,加载训练数据,然后定义train函数,设置好损失函数后,按batch加载数据,完成模型的训练。
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)
# 加载训练集 batch_size 设为 64
def train(model):
model.train()
epochs = 2
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 用Adam作为优化函数
for epoch in range(epochs):
for batch_id, data in enumerate(train_loader()):
x_data = data[0]
y_data = data[1]
predicts = model(x_data)
loss = F.cross_entropy(predicts, y_data)
# 计算损失
acc = paddle.metric.accuracy(predicts, y_data)
loss.backward()
if batch_id % 300 == 0:
print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
optim.step()
optim.clear_grad()
model = LeNet()
train(model)
边栏推荐
- When creating body middleware, express Is there any difference between setting extended to true and false in urlencoded?
- Uic564-2 Appendix 4 - flame retardant fire test: flame diffusion
- Linkedblockingqueue source code analysis - add and delete
- 正畸注意事项(持续更新中)
- Daily question brushing record (16)
- [path planning] use the vertical distance limit method and Bessel to optimize the path of a star
- Problems faced when connecting to sqlserver after downloading (I)
- new和delete的底层原理以及模板
- Using Google test in QT
- [the most detailed in history] statistical description of overdue days in credit
猜你喜欢
机器人(自动化)等专业课程创新的结果
[leetcode] 20. Valid brackets
Solutions to problems in sqlserver deleting data in tables
【史上最详细】信贷中逾期天数统计说明
全自动化处理每月缺卡数据,输出缺卡人员信息
52岁的周鸿祎,还年轻吗?
How did a fake offer steal $540million from "axie infinity"?
At the age of 35, I made a decision to face unemployment
[programming problem] [scratch Level 2] December 2019 flying birds
One click installation with fishros in blue bridge ROS
随机推荐
面试题详解:用Redis实现分布式锁的血泪史
手写一个模拟的ReentrantLock
Use filters to count URL request time
10 schemes to ensure interface data security
单机高并发模型设计
Go learning notes (2) basic types and statements (1)
Magic fast power
STM32F1与STM32CubeIDE编程实例-旋转编码器驱动
【編程題】【Scratch二級】2019.12 飛翔的小鳥
Orthodontic precautions (continuously updated)
全自动化处理每月缺卡数据,输出缺卡人员信息
Opengl3.3 mouse picking up objects
mysql8.0 ubuntu20.4
Is it safe to buy funds online?
Resolve the URL of token
[path planning] use the vertical distance limit method and Bessel to optimize the path of a star
Introduction to programming hardware
Cmake learning notes (1) compile single source programs with cmake
Chisel tutorial - 00 Ex.scala metals plug-in (vs Code), SBT and coursier exchange endogenous
Coindesk comments on the decentralization process of the wave field: let people see the future of the Internet