Paddle: classifying the CIFAR-10 dataset with a network of three convolution layers
2022-07-07 22:06:00 【Vertira】
The content is adapted from the official Paddle tutorial. The original code was not entirely complete; it has been partially modified here so that the full runnable code and the resulting plot are included.
Abstract: This tutorial demonstrates how to use a PaddlePaddle convolutional neural network for image classification. It is a fairly simple example: a network built from three convolution layers is used to classify images from the CIFAR-10 dataset.
1. Environment Setup
import paddle
import paddle.nn.functional as F
from paddle.vision.transforms import ToTensor
import numpy as np
import matplotlib.pyplot as plt
print(paddle.__version__)
How to install and configure Paddle is omitted here.
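(For reference, the CPU build can usually be installed with pip install paddlepaddle; see the official installation guide for GPU or platform-specific builds. A quick way to verify the install, assuming Paddle 2.x, is the built-in check below.)

import paddle
paddle.utils.run_check()  # prints a message confirming whether PaddlePaddle works on this machine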
2. Loading the Dataset
This example uses the APIs provided by Paddle to download the dataset and prepare the data iterators needed for training. CIFAR-10 consists of 60,000 color images of size 32 x 32, of which 50,000 form the training set and the remaining 10,000 form the test set. The images fall into 10 classes, and a model will be trained to classify them correctly.
transform = ToTensor()
cifar10_train = paddle.vision.datasets.Cifar10(mode='train', transform=transform)
cifar10_test = paddle.vision.datasets.Cifar10(mode='test', transform=transform)
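To sanity-check the download, a small sketch like the following (assuming the two datasets above were created successfully) prints the split sizes and the shape of one transformed sample:

print(len(cifar10_train), len(cifar10_test))  # expected: 50000 10000
image, label = cifar10_train[0]
print(image.shape, label)                     # ToTensor yields a CHW tensor: [3, 32, 32]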
3. Building the Network
Next, use Paddle to define a classification network built from three 2D convolutions (Conv2D), each followed by a relu activation, two 2D max-pooling layers (MaxPool2D), and two linear layers. It maps a 32 x 32 x 3 image to 10 outputs, one for each of the 10 classes.
class MyNet(paddle.nn.Layer):
    def __init__(self, num_classes=1):
        super(MyNet, self).__init__()
        # three 3x3 convolutions interleaved with 2x2 max pooling
        self.conv1 = paddle.nn.Conv2D(in_channels=3, out_channels=32, kernel_size=(3, 3))
        self.pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
        self.conv2 = paddle.nn.Conv2D(in_channels=32, out_channels=64, kernel_size=(3, 3))
        self.pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
        self.conv3 = paddle.nn.Conv2D(in_channels=64, out_channels=64, kernel_size=(3, 3))
        self.flatten = paddle.nn.Flatten()
        # after the last convolution the feature map is 64 channels of 4 x 4 = 1024 features
        self.linear1 = paddle.nn.Linear(in_features=1024, out_features=64)
        self.linear2 = paddle.nn.Linear(in_features=64, out_features=num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x
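Before moving on, it can help to confirm each layer's output shape, which also explains why linear1 uses in_features=1024. Assuming Paddle 2.x, paddle.summary prints this per-layer breakdown (check_model is just a throwaway instance for the check):

check_model = MyNet(num_classes=10)
paddle.summary(check_model, (1, 3, 32, 32))  # per-layer output shapes and parameter counts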
4. Model Training & Prediction
Next, a loop is used to train the model, which will:
- use the paddle.optimizer.Adam optimizer;
- use F.cross_entropy to compute the loss;
- use paddle.io.DataLoader to load the data and assemble batches.
epoch_num = 10
batch_size = 32
learning_rate = 0.001
val_acc_history = []
val_loss_history = []
def train(model):
    print('start training ... ')
    # turn into training mode
    model.train()
    opt = paddle.optimizer.Adam(learning_rate=learning_rate,
                                parameters=model.parameters())
    train_loader = paddle.io.DataLoader(cifar10_train,
                                        shuffle=True,
                                        batch_size=batch_size)
    valid_loader = paddle.io.DataLoader(cifar10_test, batch_size=batch_size)
    for epoch in range(epoch_num):
        for batch_id, data in enumerate(train_loader()):
            x_data = data[0]
            y_data = paddle.to_tensor(data[1])
            y_data = paddle.unsqueeze(y_data, 1)
            logits = model(x_data)
            loss = F.cross_entropy(logits, y_data)
            if batch_id % 1000 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, loss.numpy()))
            loss.backward()
            opt.step()
            opt.clear_grad()
        # evaluate model after one epoch
        model.eval()
        accuracies = []
        losses = []
        for batch_id, data in enumerate(valid_loader()):
            x_data = data[0]
            y_data = paddle.to_tensor(data[1])
            y_data = paddle.unsqueeze(y_data, 1)
            logits = model(x_data)
            loss = F.cross_entropy(logits, y_data)
            acc = paddle.metric.accuracy(logits, y_data)
            accuracies.append(acc.numpy())
            losses.append(loss.numpy())
        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
        print("[validation] accuracy/loss: {}/{}".format(avg_acc, avg_loss))
        val_acc_history.append(avg_acc)
        val_loss_history.append(avg_loss)
        model.train()
model = MyNet(num_classes=10)
train(model)
Training output:
start training ...
epoch: 0, batch_id: 0, loss is: [2.7433677]
epoch: 0, batch_id: 1000, loss is: [1.5053985]
[validation] accuracy/loss: 0.5752795338630676/1.1952502727508545
epoch: 1, batch_id: 0, loss is: [1.2686675]
epoch: 1, batch_id: 1000, loss is: [0.6766195]
[validation] accuracy/loss: 0.6521565318107605/0.9908956289291382
epoch: 2, batch_id: 0, loss is: [0.97449476]
epoch: 2, batch_id: 1000, loss is: [0.7748282]
[validation] accuracy/loss: 0.680111825466156/0.9200474619865417
epoch: 3, batch_id: 0, loss is: [0.7913307]
epoch: 3, batch_id: 1000, loss is: [1.0034081]
[validation] accuracy/loss: 0.6979832053184509/0.8721970915794373
epoch: 4, batch_id: 0, loss is: [0.6251695]
epoch: 4, batch_id: 1000, loss is: [0.6004331]
[validation] accuracy/loss: 0.6930910348892212/0.8982931971549988
epoch: 5, batch_id: 0, loss is: [0.6123275]
epoch: 5, batch_id: 1000, loss is: [0.8438066]
[validation] accuracy/loss: 0.710463285446167/0.8458449840545654
epoch: 6, batch_id: 0, loss is: [0.47533002]
epoch: 6, batch_id: 1000, loss is: [0.41863057]
[validation] accuracy/loss: 0.7125598788261414/0.8965839147567749
epoch: 7, batch_id: 0, loss is: [0.64983004]
epoch: 7, batch_id: 1000, loss is: [0.61536294]
[validation] accuracy/loss: 0.7009784579277039/0.9212258458137512
epoch: 8, batch_id: 0, loss is: [0.79953825]
epoch: 8, batch_id: 1000, loss is: [0.6168741]
[validation] accuracy/loss: 0.7134584784507751/0.8829751014709473
epoch: 9, batch_id: 0, loss is: [0.33510458]
epoch: 9, batch_id: 1000, loss is: [0.3573485]
[validation] accuracy/loss: 0.6938897967338562/0.9611227512359619
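Once training finishes, the learned parameters can be saved and reloaded for later use. A minimal sketch (the file name mynet.pdparams and the variable restored are just illustrative):

# save the trained weights (file name is illustrative)
paddle.save(model.state_dict(), 'mynet.pdparams')

# restore them into a fresh instance before running inference
restored = MyNet(num_classes=10)
restored.set_state_dict(paddle.load('mynet.pdparams'))
restored.eval()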
Code to plot the validation-accuracy curve:
plt.plot(val_acc_history, label='validation accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.5, 0.8])
plt.legend(loc='lower right')
plt.show()
The resulting curve is shown below.
[Figure: validation accuracy over the 10 training epochs]
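Since the section is titled "Training & Prediction", here is a small inference sketch on a single test image (not part of the original tutorial; the labels list assumes the standard CIFAR-10 class order and that the dataset returns an integer label):

labels = ['airplane', 'automobile', 'bird', 'cat', 'deer',
          'dog', 'frog', 'horse', 'ship', 'truck']   # standard CIFAR-10 class order
model.eval()
image, true_label = cifar10_test[0]
logits = model(paddle.unsqueeze(image, axis=0))      # add a batch dimension: [1, 3, 32, 32]
pred = int(paddle.argmax(logits, axis=1).numpy()[0])
print('predicted:', labels[pred], '| ground truth:', labels[int(true_label)])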