Keras Deep Learning in Practice: Traffic Sign Recognition
0. Preface
In road traffic scenarios, traffic sign recognition is an indispensable part of driver-assistance systems and autonomous vehicles, providing a safety guarantee while the vehicle is moving. The environment around the road contains a great deal of important traffic sign information, and making the correct driving decisions based on that information is usually what prevents traffic accidents. Traffic sign recognition detects and identifies the signs on the current road and from them derives the necessary information about the road.
However, traffic signs are affected by environmental factors such as the vehicle's motion, illumination and occlusion, so enabling a vehicle to help the driver recognize traffic signs quickly and accurately has become one of the hot topics in intelligent transportation. Given the important role traffic sign recognition plays in applications such as autonomous driving, in this section we will learn to implement it with a convolutional neural network.
1. Dataset and Model Analysis
1.1 Dataset Introduction
The German Traffic Sign Recognition Benchmark (GTSRB) is a traffic sign image classification benchmark for advanced driver-assistance systems and autonomous driving. It contains 43 different classes of traffic signs, and the dataset can be downloaded from the official website.
Each image contains a single traffic sign together with a margin of surrounding environment pixels of roughly 10% of the sign's size (at least 5 pixels). The images are stored in PPM format, and their sizes range from 15x15 to 250x250 pixels.
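Before writing any training code, it is worth confirming that the dataset has been extracted correctly. The following is a minimal sketch (not part of the original code) that counts the classes and samples a few image sizes; it assumes the training archive was unpacked to the same GTSRB_Final_Training_Images/... path used in the code below, with one sub-directory per class.
# minimal sketch for checking the extracted GTSRB training set
import os
from glob import glob
from skimage import io
root_dir = 'GTSRB_Final_Training_Images/GTSRB/Final_Training/Images/'
# one sub-directory per class, named '00000' ... '00042'
class_dirs = sorted(d for d in os.listdir(root_dir)
                    if os.path.isdir(os.path.join(root_dir, d)))
print('number of classes:', len(class_dirs))    # expected: 43
img_paths = glob(os.path.join(root_dir, '*/*.ppm'))
print('number of images:', len(img_paths))
# sample a subset of images to get a rough idea of the size range
shapes = [io.imread(p).shape[:2] for p in img_paths[:500]]
print('smallest sampled size:', min(shapes), 'largest sampled size:', max(shapes))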
1.2 Model Analysis
To help the driver recognize traffic signs quickly and accurately in road traffic, we will adopt the following strategy:
- Load the traffic sign dataset
- Perform histogram normalization on the input images:
  - Some images in the dataset were taken during the day and others possibly at night; the different lighting conditions at capture time cause the pixel values to vary
  - Histogram normalization normalizes the pixel values so that they follow a similar distribution
- Resize the input images to a common size and one-hot encode the output vector
- Build, compile and fit a convolutional neural network model to minimize the categorical cross-entropy loss
2. Traffic Sign Recognition
In this section, we implement the traffic sign recognition model analyzed in the previous section.
2.1 Dataset Loading and Preprocessing
First, load the required libraries and read the image paths into a list:
from skimage import io
import os
from glob import glob
import matplotlib.pyplot as plt
import cv2
root_dir = 'GTSRB_Final_Training_Images/GTSRB/Final_Training/Images/'
all_img_paths = glob(os.path.join(root_dir, '*/*.ppm'))
To get a better feel for the dataset, we can display a few sample images:
for i, img_path in enumerate(all_img_paths[:800:100]):
    img = cv2.imread(img_path)
    plt.subplot(2, 4, i + 1)
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.show()
As shown in the figure below, the images in the dataset come in different shapes and sizes, and their brightness also varies. We therefore have to preprocess the images so that they all follow a similar distribution regardless of the lighting conditions, and rescale them to the same shape.
Perform histogram normalization on the input dataset as follows:
import numpy as np
from skimage import color, exposure, transform
NUM_CLASSES = 43
IMG_SIZE = 48
def preprocess_img(img):
    # convert to HSV and equalize the histogram of the V (brightness) channel
    hsv = color.rgb2hsv(img)
    hsv[:, :, 2] = exposure.equalize_hist(hsv[:, :, 2])
    img = color.hsv2rgb(hsv)
    # rescale every image to the same shape
    img = transform.resize(img, (IMG_SIZE, IMG_SIZE))
    return img
In the preceding code, we first convert the image from RGB to HSV (Hue, Saturation, Value) format. We then normalize the last channel of the HSV image, the value (brightness) channel, with the equalize_hist method so that all images share a similar brightness distribution, convert the image back to RGB, and finally resize it to a common size.
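To see what equalize_hist actually does to the pixel values, the brightness histogram of an image can be compared before and after equalization. The sketch below is an illustrative addition (not part of the original code) and reuses all_img_paths from the loading code above:
# compare the V-channel (brightness) histogram before and after equalization
from skimage import io, color, exposure
import matplotlib.pyplot as plt
img = io.imread(all_img_paths[0])
v_before = color.rgb2hsv(img)[:, :, 2]
v_after = exposure.equalize_hist(v_before)
plt.subplot(1, 2, 1)
plt.hist(v_before.ravel(), bins=50)
plt.title('V channel before')
plt.subplot(1, 2, 2)
plt.hist(v_after.ravel(), bins=50)
plt.title('V channel after')
plt.show()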
After the histogram normalization, inspect the images again to compare them before and after the transformation:
for i, img_path in enumerate(all_img_paths[:800:100]):
    img = cv2.imread(img_path)
    plt.subplot(4, 4, 2 * i + 1)
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    plt.title('Original')
    img = preprocess_img(img)
    plt.subplot(4, 4, 2 * i + 2)
    plt.imshow(img[:, :, ::-1])
    plt.title('Transformed')
plt.show()
From the resulting images we can see that after histogram normalization the brightness values change considerably, making the class of each traffic sign much easier to distinguish.
Next, prepare the input and output arrays for the model and split the dataset into training and test sets:
from keras.utils import to_categorical
imgs = []
labels = []
for img_path in all_img_paths:
    img = preprocess_img(io.imread(img_path))
    # the class label is the name of the directory containing the image, e.g. '00001'
    label = int(img_path.split('/')[-2])
    imgs.append(img)
    labels.append(label)
x = np.array(imgs)
y = to_categorical(labels, num_classes=NUM_CLASSES)
# split the dataset into training and test sets
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
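Note that the class frequencies in GTSRB are far from uniform. If you want the training and test splits to preserve the class proportions, the split can optionally be stratified on the labels; this variant is not used in the original code:
# optional: stratified split so every class keeps roughly the same proportion
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, stratify=y.argmax(axis=1), random_state=42)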
2.2 Model Building and Training
Build the classification model for traffic sign recognition:
from keras.models import Sequential
from keras.layers import Conv2D, MaxPool2D, Dropout, ReLU
from keras.layers import Flatten, Dense, BatchNormalization
model = Sequential()
model.add(Conv2D(32, (3,3), padding='same', input_shape=(IMG_SIZE, IMG_SIZE, 3)))
model.add(BatchNormalization())
model.add(ReLU())
model.add(Conv2D(32, (3,3)))
model.add(BatchNormalization())
model.add(ReLU())
model.add(MaxPool2D(pool_size=(2,2)))
model.add(Dropout(0.2))
model.add(Conv2D(64, (3,3), padding='same'))
model.add(BatchNormalization())
model.add(ReLU())
model.add(Conv2D(64, (3,3)))
model.add(BatchNormalization())
model.add(ReLU())
model.add(MaxPool2D(pool_size=(2,2)))
model.add(Dropout(0.2))
model.add(Conv2D(128, (3,3), padding='same'))
model.add(BatchNormalization())
model.add(ReLU())
model.add(Conv2D(128, (3,3)))
model.add(BatchNormalization())
model.add(ReLU())
model.add(Dropout(0.2))
model.add(Flatten())
model.add(Dense(512))
model.add(BatchNormalization())
model.add(ReLU())
model.add(Dropout(0.2))
model.add(Dense(NUM_CLASSES, activation='softmax'))
model.summary()
A summary of the model architecture is shown below:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 48, 48, 32) 896
_________________________________________________________________
batch_normalization (BatchNo (None, 48, 48, 32) 128
_________________________________________________________________
re_lu (ReLU) (None, 48, 48, 32) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 46, 46, 32) 9248
_________________________________________________________________
batch_normalization_1 (Batch (None, 46, 46, 32) 128
_________________________________________________________________
re_lu_1 (ReLU) (None, 46, 46, 32) 0
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 23, 23, 32) 0
_________________________________________________________________
dropout (Dropout) (None, 23, 23, 32) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 23, 23, 64) 18496
_________________________________________________________________
batch_normalization_2 (Batch (None, 23, 23, 64) 256
_________________________________________________________________
re_lu_2 (ReLU) (None, 23, 23, 64) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 21, 21, 64) 36928
_________________________________________________________________
batch_normalization_3 (Batch (None, 21, 21, 64) 256
_________________________________________________________________
re_lu_3 (ReLU) (None, 21, 21, 64) 0
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 10, 10, 64) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 10, 10, 64) 0
_________________________________________________________________
conv2d_4 (Conv2D) (None, 10, 10, 128) 73856
_________________________________________________________________
batch_normalization_4 (Batch (None, 10, 10, 128) 512
_________________________________________________________________
re_lu_4 (ReLU) (None, 10, 10, 128) 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 8, 8, 128) 147584
_________________________________________________________________
batch_normalization_5 (Batch (None, 8, 8, 128) 512
_________________________________________________________________
re_lu_5 (ReLU) (None, 8, 8, 128) 0
_________________________________________________________________
dropout_2 (Dropout) (None, 8, 8, 128) 0
_________________________________________________________________
flatten (Flatten) (None, 8192) 0
_________________________________________________________________
dense (Dense) (None, 512) 4194816
_________________________________________________________________
batch_normalization_6 (Batch (None, 512) 2048
_________________________________________________________________
re_lu_6 (ReLU) (None, 512) 0
_________________________________________________________________
dropout_3 (Dropout) (None, 512) 0
_________________________________________________________________
dense_1 (Dense) (None, 43) 22059
=================================================================
Total params: 4,507,723
Trainable params: 4,505,803
Non-trainable params: 1,920
_________________________________________________________________
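As a sanity check on the summary, the parameter counts can be reproduced by hand: the first convolutional layer has (3 x 3 x 3) x 32 weights plus 32 biases = 896 parameters, the second has (3 x 3 x 32) x 32 + 32 = 9,248, and the large dense layer has 8,192 x 512 + 512 = 4,194,816. Each BatchNormalization layer contributes 4 parameters per channel (gamma, beta, and the moving mean and variance), of which the two moving statistics are not trained; summed over all seven BatchNormalization layers this gives the 1,920 non-trainable parameters.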
Finally, compile and fit the model as follows:
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
history = model.fit(x_train, y_train,
                    batch_size=64,
                    epochs=10,
                    validation_data=(x_test, y_test),
                    verbose=1)
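Because fit returns a History object, the convergence of the model can also be inspected by plotting the recorded metrics. The following optional sketch assumes the history keys follow the name passed to metrics=['acc']; some Keras versions record 'accuracy'/'val_accuracy' instead:
# plot training and validation accuracy over the epochs
plt.plot(history.history['acc'], label='train accuracy')
plt.plot(history.history['val_acc'], label='validation accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend()
plt.show()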
After training, the model reaches an accuracy of about 99.8%.
For comparison, training exactly the same model without histogram normalization yields an accuracy of roughly 97%, so the histogram-normalization preprocessing noticeably improves the classification accuracy.
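To verify the reported figure on the held-out split, the trained model can be evaluated explicitly. This short sketch assumes the model and the test arrays from the code above are still in memory:
# evaluate on the test split and classify a single test image
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print('test accuracy: {:.4f}'.format(test_acc))
pred_class = model.predict(x_test[:1]).argmax(axis=1)[0]
true_class = y_test[:1].argmax(axis=1)[0]
print('predicted class:', pred_class, 'true class:', true_class)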
Related Links
Keras Deep Learning in Practice (1): Neural Network Fundamentals and the Model Training Process
Keras Deep Learning in Practice (2): Building Neural Networks with Keras
Keras Deep Learning in Practice (7): Convolutional Neural Networks Explained and Implemented