当前位置:网站首页>paddle入门-使用LeNet在MNIST实现图像分类方法二
paddle入门-使用LeNet在MNIST实现图像分类方法二
2022-07-07 22:06:00 【Vertira】
使用LeNet在MNIST实现图像分类方法二
与方法一相比,方法二 有点麻烦,适合进阶
五、方式2:基于基础API,完成模型的训练与预测
5.1 模型训练¶
组网后,开始对模型进行训练,先构建train_loader
,加载训练数据,然后定义train
函数,设置好损失函数后,按batch加载数据,完成模型的训练。
import paddle.nn.functional as F
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)
# 加载训练集 batch_size 设为 64
def train(model):
model.train()
epochs = 2
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 用Adam作为优化函数
for epoch in range(epochs):
for batch_id, data in enumerate(train_loader()):
x_data = data[0]
y_data = data[1]
predicts = model(x_data)
loss = F.cross_entropy(predicts, y_data)
# 计算损失
acc = paddle.metric.accuracy(predicts, y_data)
loss.backward()
if batch_id % 300 == 0:
print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
optim.step()
optim.clear_grad()
model = LeNet()
train(model)
运行结果
epoch: 0, batch_id: 0, loss is: [3.2611141], acc is: [0.078125]
epoch: 0, batch_id: 300, loss is: [0.24404016], acc is: [0.921875]
epoch: 0, batch_id: 600, loss is: [0.03953885], acc is: [1.]
epoch: 0, batch_id: 900, loss is: [0.03700985], acc is: [0.984375]
epoch: 1, batch_id: 0, loss is: [0.05806625], acc is: [0.96875]
epoch: 1, batch_id: 300, loss is: [0.06538856], acc is: [0.953125]
epoch: 1, batch_id: 600, loss is: [0.03884572], acc is: [0.984375]
epoch: 1, batch_id: 900, loss is: [0.01922364], acc is: [0.984375]
5.2 模型验证¶
训练完成后,需要验证模型的效果,此时,加载测试数据集,然后用训练好的模对测试集进行预测,计算损失与精度。
test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=64)
# 加载测试数据集
def test(model):
model.eval()
batch_size = 64
for batch_id, data in enumerate(test_loader()):
x_data = data[0]
y_data = data[1]
predicts = model(x_data)
# 获取预测结果
loss = F.cross_entropy(predicts, y_data)
acc = paddle.metric.accuracy(predicts, y_data)
if batch_id % 20 == 0:
print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id, loss.numpy(), acc.numpy()))
test(model)
运行结果
batch_id: 0, loss is: [0.01972857], acc is: [0.984375]
batch_id: 20, loss is: [0.19958115], acc is: [0.9375]
batch_id: 40, loss is: [0.23575728], acc is: [0.953125]
batch_id: 60, loss is: [0.07018849], acc is: [0.984375]
batch_id: 80, loss is: [0.02309197], acc is: [0.984375]
batch_id: 100, loss is: [0.00239462], acc is: [1.]
batch_id: 120, loss is: [0.01583934], acc is: [1.]
batch_id: 140, loss is: [0.00399609], acc is: [1.]
方式二结束¶
以上就是方式二,通过底层API,可以清楚的看到训练和测试中的每一步过程。但是,这种方式比较复杂。因此,我们提供了训练方式一,使用高层API来完成模型的训练与预测。对比底层API,高层API能够更加快速、高效的完成模型的训练与测试。
六、总结¶
以上就是用LeNet对手写数字数据及MNIST进行分类。本示例提供了两种训练模型的方式,一种可以快速完成模型的组建与预测,非常适合新手用户上手。另一种则需要多个步骤来完成模型的训练,适合进阶用户使用。
方法二的完整代码如下:
import os
import cv2
import numpy as np
from paddle.io import Dataset
import paddle.vision.transforms as T
import matplotlib.pyplot as plt
import paddle.nn.functional as F
import paddle
from paddle.metric import Accuracy
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],
std=[127.5],
data_format='CHW')])
# 使用transform对数据集做归一化
print('download training data and load training data')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')
class LeNet(paddle.nn.Layer):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.max_pool2(x)
x = paddle.flatten(x, start_axis=1,stop_axis=-1)
x = self.linear1(x)
x = F.relu(x)
x = self.linear2(x)
x = F.relu(x)
x = self.linear3(x)
return x
#组网后,开始对模型进行训练,先构建train_loader,加载训练数据,然后定义train函数,设置好损失函数后,按batch加载数据,完成模型的训练。
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)
# 加载训练集 batch_size 设为 64
def train(model):
model.train()
epochs = 2
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
# 用Adam作为优化函数
for epoch in range(epochs):
for batch_id, data in enumerate(train_loader()):
x_data = data[0]
y_data = data[1]
predicts = model(x_data)
loss = F.cross_entropy(predicts, y_data)
# 计算损失
acc = paddle.metric.accuracy(predicts, y_data)
loss.backward()
if batch_id % 300 == 0:
print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
optim.step()
optim.clear_grad()
model = LeNet()
train(model)
边栏推荐
- 10 schemes to ensure interface data security
- 腾讯安全发布《BOT管理白皮书》|解读BOT攻击,探索防护之道
- 全自动化处理每月缺卡数据,输出缺卡人员信息
- 【编程题】【Scratch二级】2019.03 垃圾分类
- Is 35 really a career crisis? No, my skills are accumulating, and the more I eat, the better
- [programming problem] [scratch Level 2] 2019.09 make bat Challenge Game
- 串联二极管,提高耐压
- [basis of recommendation system] sampling and construction of positive and negative samples
- One click free translation of more than 300 pages of PDF documents
- ROS from entry to mastery (IX) initial experience of visual simulation: turtlebot3
猜你喜欢
Using Google test in QT
Aitm3.0005 smoke toxicity test
【史上最详细】信贷中逾期天数统计说明
10 schemes to ensure interface data security
智慧监管入场,美团等互联网服务平台何去何从
单机高并发模型设计
Preliminary test of optical flow sensor: gl9306
Fully automated processing of monthly card shortage data and output of card shortage personnel information
一鍵免費翻譯300多頁的pdf文檔
Pypharm uses, and the third-party library has errors due to version problems
随机推荐
Restricted linear table
Database query - what is the highest data?
One click installation with fishros in blue bridge ROS
【leetcode】day1
【史上最详细】信贷中逾期天数统计说明
Opengl3.3 mouse picking up objects
STM32F1與STM32CubeIDE編程實例-旋轉編碼器驅動
Pycharm basic settings latest version 2022
一鍵免費翻譯300多頁的pdf文檔
80% of the people answered incorrectly. Does the leaf on the apple logo face left or right?
Flash download setup
正畸注意事项(持续更新中)
Using Google test in QT
Go time package common functions
串联二极管,提高耐压
CoinDesk评波场去中心化进程:让人们看到互联网的未来
Uic564-2 Appendix 4 - flame retardant fire test: flame diffusion
2022.7.7-----leetcode. six hundred and forty-eight
Open display PDF file in web page
如果在构造函数中抛出异常,最好的做法是防止内存泄漏?