A network composed of three convolution layers completes the image classification task of cifar10 data set
2022-07-08 00:52:00 【Vertira】
Using Paddle: a network of three convolution layers for the CIFAR-10 image classification task
The content is based on the official Paddle website. The code there was not fully complete, so parts of it have been modified to give complete, runnable code together with the resulting plot.
Abstract: This tutorial demonstrates how to use a convolutional neural network in PaddlePaddle to perform image classification. It is a relatively simple example: a network consisting of three convolution layers is used to classify the images of the CIFAR-10 dataset.
1. Environment configuration
import paddle
import paddle.nn.functional as F
from paddle.vision.transforms import ToTensor
import numpy as np
import matplotlib.pyplot as plt
print(paddle.__version__)
How to install and configure Paddle is omitted here.
2. Load the dataset
This example uses Paddle's built-in dataset API to download the dataset and prepare the data iterator for the training steps that follow. CIFAR-10 consists of 60,000 color images of size 32 * 32: 50,000 of them form the training set and the other 10,000 form the test set. The images are divided into 10 classes, and a model will be trained to classify them correctly.
transform = ToTensor()
cifar10_train = paddle.vision.datasets.Cifar10(mode='train',
                                               transform=transform)
cifar10_test = paddle.vision.datasets.Cifar10(mode='test',
                                              transform=transform)
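To make the role of the ToTensor transform concrete, here is a minimal pure-Python sketch of the conversion it performs by default on a uint8 image: the height x width x channel layout becomes channel x height x width, and values in [0, 255] are scaled to [0.0, 1.0]. The helper name to_chw_float and the tiny 2 x 2 image are hypothetical, chosen only for illustration; the real transform returns a paddle.Tensor rather than nested lists.

```python
def to_chw_float(image_hwc):
    """Reorder an H x W x C nested list to C x H x W and scale 0..255 to 0..1."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    return [
        [[image_hwc[h][w][c] / 255.0 for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]


# Hypothetical 2 x 2 RGB image (values are made up for the example).
tiny_image = [
    [[255, 0, 0], [0, 255, 0]],
    [[0, 0, 255], [255, 255, 255]],
]

chw = to_chw_float(tiny_image)
print(len(chw), len(chw[0]), len(chw[0][0]))  # channels, height, width
print(chw[0])  # the red channel as a 2 x 2 grid
```

For a CIFAR-10 image the same reordering turns a 32 x 32 x 3 picture into a 3 x 32 x 32 input, which is the layout Conv2D expects by default.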
3. Build the network
Next, use PaddlePaddle to define a network with three 2D convolutions (Conv2D), each followed by a relu activation function, two 2D max-pooling layers (MaxPool2D), and two linear layers. It maps an image of shape (32, 32, 3), which the ToTensor transform supplies as a 3 x 32 x 32 tensor, to 10 outputs, one for each of the 10 classes.
class MyNet(paddle.nn.Layer):
    def __init__(self, num_classes=1):
        super(MyNet, self).__init__()

        self.conv1 = paddle.nn.Conv2D(in_channels=3, out_channels=32, kernel_size=(3, 3))
        self.pool1 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)

        self.conv2 = paddle.nn.Conv2D(in_channels=32, out_channels=64, kernel_size=(3, 3))
        self.pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)

        self.conv3 = paddle.nn.Conv2D(in_channels=64, out_channels=64, kernel_size=(3, 3))

        self.flatten = paddle.nn.Flatten()

        # conv3 outputs 64 channels of size 4 x 4, so flatten yields 64 * 4 * 4 = 1024 features
        self.linear1 = paddle.nn.Linear(in_features=1024, out_features=64)
        self.linear2 = paddle.nn.Linear(in_features=64, out_features=num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = F.relu(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = F.relu(x)

        x = self.flatten(x)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x
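The value in_features=1024 of the first linear layer follows from the spatial sizes produced by the layers before it. The sketch below traces one spatial dimension through the network in plain Python, assuming the default stride of 1 and padding of 0 for Conv2D and no padding for MaxPool2D, as in the definition above. The helper name conv_out is hypothetical and not part of Paddle.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of one spatial dimension after a convolution or pooling."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                   # CIFAR-10 images are 32 x 32
size = conv_out(size, kernel=3)             # conv1: 32 -> 30
size = conv_out(size, kernel=2, stride=2)   # pool1: 30 -> 15
size = conv_out(size, kernel=3)             # conv2: 15 -> 13
size = conv_out(size, kernel=2, stride=2)   # pool2: 13 -> 6
size = conv_out(size, kernel=3)             # conv3: 6 -> 4

channels = 64                               # out_channels of conv3
flat_features = channels * size * size
print(size, flat_features)                  # 4 1024
```

If the kernel sizes, padding, or input resolution change, the same calculation gives the new in_features value for linear1.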
4. Model training and prediction
Next, train the model in a loop that:
- uses the paddle.optimizer.Adam optimizer to update the parameters;
- uses F.cross_entropy to compute the loss;
- uses paddle.io.DataLoader to load the data and build batches.
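As background for the loss used in the training loop, the following is a minimal pure-Python sketch of what a softmax cross-entropy computes for a single sample: the logits are turned into probabilities with softmax, and the loss is the negative log of the probability assigned to the true class. The function name cross_entropy and the example logits are illustrative assumptions; F.cross_entropy itself works on batched tensors and, by default, averages the per-sample losses.

```python
import math


def cross_entropy(logits, label):
    """Softmax followed by negative log-likelihood for a single sample."""
    peak = max(logits)                       # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    prob_of_label = exps[label] / total
    return -math.log(prob_of_label)


# A 10-class model that outputs equal logits is maximally unsure.
uniform_loss = cross_entropy([0.0] * 10, label=3)
print(round(uniform_loss, 4))                # ln(10), about 2.3026

# A confident, correct prediction gives a loss close to zero.
confident_loss = cross_entropy([0.0] * 9 + [10.0], label=9)
print(confident_loss < 0.001)
```

A loss near ln(10) is therefore what a 10-class classifier produces before it has learned anything useful, and the value should fall as training progresses.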
epoch_num = 10
batch_size = 32
learning_rate = 0.001

val_acc_history = []
val_loss_history = []

def train(model):
    print('start training ... ')
    # switch to training mode
    model.train()

    opt = paddle.optimizer.Adam(learning_rate=learning_rate,
                                parameters=model.parameters())

    train_loader = paddle.io.DataLoader(cifar10_train,
                                        shuffle=True,
                                        batch_size=batch_size)

    valid_loader = paddle.io.DataLoader(cifar10_test, batch_size=batch_size)

    for epoch in range(epoch_num):
        for batch_id, data in enumerate(train_loader()):
            x_data = data[0]
            y_data = paddle.to_tensor(data[1])
            y_data = paddle.unsqueeze(y_data, 1)

            logits = model(x_data)
            loss = F.cross_entropy(logits, y_data)

            if batch_id % 1000 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, loss.numpy()))
            loss.backward()
            opt.step()
            opt.clear_grad()

        # evaluate model after one epoch
        model.eval()
        accuracies = []
        losses = []
        for batch_id, data in enumerate(valid_loader()):
            x_data = data[0]
            y_data = paddle.to_tensor(data[1])
            y_data = paddle.unsqueeze(y_data, 1)

            logits = model(x_data)
            loss = F.cross_entropy(logits, y_data)
            acc = paddle.metric.accuracy(logits, y_data)
            accuracies.append(acc.numpy())
            losses.append(loss.numpy())

        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
        print("[validation] accuracy/loss: {}/{}".format(avg_acc, avg_loss))
        val_acc_history.append(avg_acc)
        val_loss_history.append(avg_loss)
        model.train()

model = MyNet(num_classes=10)
train(model)
Output of the run:
start training ...
epoch: 0, batch_id: 0, loss is: [2.7433677]
epoch: 0, batch_id: 1000, loss is: [1.5053985]
[validation] accuracy/loss: 0.5752795338630676/1.1952502727508545
epoch: 1, batch_id: 0, loss is: [1.2686675]
epoch: 1, batch_id: 1000, loss is: [0.6766195]
[validation] accuracy/loss: 0.6521565318107605/0.9908956289291382
epoch: 2, batch_id: 0, loss is: [0.97449476]
epoch: 2, batch_id: 1000, loss is: [0.7748282]
[validation] accuracy/loss: 0.680111825466156/0.9200474619865417
epoch: 3, batch_id: 0, loss is: [0.7913307]
epoch: 3, batch_id: 1000, loss is: [1.0034081]
[validation] accuracy/loss: 0.6979832053184509/0.8721970915794373
epoch: 4, batch_id: 0, loss is: [0.6251695]
epoch: 4, batch_id: 1000, loss is: [0.6004331]
[validation] accuracy/loss: 0.6930910348892212/0.8982931971549988
epoch: 5, batch_id: 0, loss is: [0.6123275]
epoch: 5, batch_id: 1000, loss is: [0.8438066]
[validation] accuracy/loss: 0.710463285446167/0.8458449840545654
epoch: 6, batch_id: 0, loss is: [0.47533002]
epoch: 6, batch_id: 1000, loss is: [0.41863057]
[validation] accuracy/loss: 0.7125598788261414/0.8965839147567749
epoch: 7, batch_id: 0, loss is: [0.64983004]
epoch: 7, batch_id: 1000, loss is: [0.61536294]
[validation] accuracy/loss: 0.7009784579277039/0.9212258458137512
epoch: 8, batch_id: 0, loss is: [0.79953825]
epoch: 8, batch_id: 1000, loss is: [0.6168741]
[validation] accuracy/loss: 0.7134584784507751/0.8829751014709473
epoch: 9, batch_id: 0, loss is: [0.33510458]
epoch: 9, batch_id: 1000, loss is: [0.3573485]
[validation] accuracy/loss: 0.6938897967338562/0.9611227512359619
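Each epoch in the log prints a training loss only for batch_id 0 and 1000. The arithmetic below sketches why, assuming the DataLoader keeps the final partial batch (drop_last left at its default of False); the variable names are illustrative.

```python
import math

train_size, test_size = 50000, 10000   # CIFAR-10 split sizes stated earlier
batch_size = 32                        # same value as in the training script

train_batches = math.ceil(train_size / batch_size)
test_batches = math.ceil(test_size / batch_size)
last_test_batch = test_size - (test_batches - 1) * batch_size

# Batch ids that satisfy batch_id % 1000 == 0 within one epoch.
logged_ids = [i for i in range(train_batches) if i % 1000 == 0]

print(train_batches, test_batches, last_test_batch, logged_ids)
```

Under these assumptions an epoch has 1563 training batches, so only two of them hit the logging condition. The validation set splits into 313 batches, the last holding 16 images, so taking np.mean over per-batch accuracies weights that small batch the same as a full one and the reported accuracy is a close approximation of the exact fraction of correct predictions.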
Code to plot the validation accuracy:
plt.plot(val_acc_history, label = 'validation accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.5, 0.8])
plt.legend(loc='lower right')
plt.show()
The result is a plot of validation accuracy against epoch.
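The same history can also be read numerically. The sketch below copies the validation metrics from the log above (rounded to four decimals) and finds the epochs with the highest accuracy and the lowest loss; the list and variable names are illustrative and stand in for val_acc_history and val_loss_history.

```python
# Validation metrics copied from the log above, rounded to four decimals.
val_acc = [0.5753, 0.6522, 0.6801, 0.6980, 0.6931,
           0.7105, 0.7126, 0.7010, 0.7135, 0.6939]
val_loss = [1.1953, 0.9909, 0.9200, 0.8722, 0.8983,
            0.8458, 0.8966, 0.9212, 0.8830, 0.9611]

best_acc_epoch = max(range(len(val_acc)), key=val_acc.__getitem__)
best_loss_epoch = min(range(len(val_loss)), key=val_loss.__getitem__)

print(best_acc_epoch, val_acc[best_acc_epoch])
print(best_loss_epoch, val_loss[best_loss_epoch])
```

In this run the validation loss is lowest at epoch 5 and drifts upward afterwards while accuracy stays around 0.70, which may indicate that the model starts to overfit; stopping earlier or adding regularization would be reasonable things to try.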