当前位置:网站首页>Keras深度学习实战——交通标志识别
Keras深度学习实战——交通标志识别
2022-08-01 18:41:00 【盼小辉丶】
Keras深度学习实战——交通标志识别
0. 前言
在道路交通场景中,交通标志识别作为驾驶辅助系统与无人驾驶车辆中不可缺少的技术,为车辆行驶中提供了安全保障。在道路上行驶的车辆,道路周围的环境包括许多重要的交通标志信息,根据交通标志信息在道路上做出正确的驾驶行为,通常能够避免发生交通事故。交通标志识别可以检测并识别当前行驶道路上的交通标志,然后得出有关道路的必要信息。
但交通标志会受到车辆的运动状态、光照以及遮挡等环境因素的影响,因此如何使车辆在道路交通中快速准确地帮助驾驶员识别交通标志已经成为智能交通领域的热点问题之一。鉴于交通标志识别在自动驾驶等应用中具有重要作用,在节中,我们将学习使用卷积神经网络实现交通标志识别。
1. 数据集与模型分析
1.1 数据集介绍
德国交通标志识别基准 (German Traffic Sign Recognition Benchmark, GTSRB) 是高级驾驶辅助系统和自动驾驶领域的交通标志图像分类基准。其中共包含 43 种不同类别的交通标志。可以在官方网页中下载相关数据集。
每张图片包含一个交通标志,图像包含实际交通标志周围的环境像素,大约为交通标志尺寸的 10% (至少为 5 个像素),图像以 PPM 格式存储,图像尺寸在 15x15 到 250x250 像素之间。
1.2 模型分析
为了使车辆在道路交通中快速准确地帮助驾驶员识别交通标志,我们将采用以下策略:
- 加载交通标志数据集
- 在输入图像上执行直方图归一化:
- 这是由于,数据集中的某些图像是在白天拍摄的,有些则可能是在夜晚拍摄的,拍摄照片时不同的照明条件会导致像素值发生变化
- 直方图归一化可以对像素值进行归一化,以便它们可以具有相似的分布
- 缩放输入图像,使其具有相同尺寸,并对输出向量进行独热编码
- 构建、编译并拟合卷积神经网络模型以减少分类交叉熵损失值
2. 交通标志识别
本节中,我们将实现在上一节中分析的交通标志识别模型。
2.1 数据集加载与预处理
首先,加载所需库,并将图像路径读入列表:
from skimage import io
import os
from glob import glob
import matplotlib.pyplot as plt
import cv2
root_dir = 'GTSRB_Final_Training_Images/GTSRB/Final_Training/Images/'
all_img_paths = glob(os.path.join(root_dir, '*/*.ppm'))
为了进一步了解数据集中,我们可以查看数据集中图像样本如下所示:
for i, img_path in enumerate(all_img_paths[:800:100]):
img = cv2.imread(img_path)
plt.subplot(2,4,i+1)
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.show()
如下图所示,可以看出数据集中的图像具有不同的形状尺寸,并且图像的明暗也并不相同。因此,我们必须对图像进行预处理,以使所有图像在不同光照条件都具有同样的分布,并对其进行缩放具有相同形状。

在输入数据集执行直方图归一化,如下所示:
import numpy as np
from skimage import color, exposure, transform
NUM_CLASSES = 43
IMG_SIZE = 48
def preprocess_img(img):
hsv = color.rgb2hsv(img)
hsv[:, :, 2] = exposure.equalize_hist(hsv[:, :, 2])
img = color.hsv2rgb(hsv)
img = transform.resize(img, (IMG_SIZE, IMG_SIZE))
return img
在以上代码中,我们首先将 RGB 格式的图像转换为 HSV (Hue Saturation Value) 格式。通过将图像从 RGB 转换为 HSV 格式,之后,我们对 HSV 格式的图像值进行归一化,通过使用 equalize_hist 方法令它们具有相同分布。将图像在 HSV 格式的最后一个通道(即明度通道,使图像具有相同的明度分布)中执行归一化后,将其转换回 RGB 格式。最后,我们将图像调整为同样的尺寸。
完成直方图归一化后检查图像,观察直方图归一化前后图像的变化:
for i, img_path in enumerate(all_img_paths[:800:100]):
img = cv2.imread(img_path)
plt.subplot(4,4,2*i+1)
plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
plt.title('Original')
img = preprocess_img(img)
plt.subplot(4,4,2*i+2)
plt.imshow(img[:, :, ::-1])
plt.title('Transformed')
plt.show()

从以上结果图片中,可以看到,在直方图归一化后,图像的亮度值发生了很大的变化,更容易分辨出交通标志的类别。
接下来,为模型准备输入和输出数组,并将数据集划分为训练和测试数据集:
from keras.utils import to_categorical
count = 0
imgs = []
labels = []
for img_path in all_img_paths:
img = preprocess_img(io.imread(img_path))
label = img_path.split('/')[-2]
imgs.append(img)
labels.append(label)
x = np.array(imgs)
y = to_categorical(labels, num_classes=NUM_CLASSES)
# 将数据集划分为训练和测试数据集
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
2.2 模型构建与训练
构建用于交通标志识别的分类模型:
from keras.models import Sequential
from keras.layers import Conv2D, MaxPool2D, Dropout, ReLU
from keras.layers import Flatten, Dense, BatchNormalization
model = Sequential()
model.add(Conv2D(32, (3,3), padding='same', input_shape=(IMG_SIZE, IMG_SIZE, 3)))
model.add(BatchNormalization())
model.add(ReLU())
model.add(Conv2D(32, (3,3)))
model.add(BatchNormalization())
model.add(ReLU())
model.add(MaxPool2D(pool_size=(2,2)))
model.add(Dropout(0.2))
model.add(Conv2D(64, (3,3), padding='same'))
model.add(BatchNormalization())
model.add(ReLU())
model.add(Conv2D(64, (3,3)))
model.add(BatchNormalization())
model.add(ReLU())
model.add(MaxPool2D(pool_size=(2,2)))
model.add(Dropout(0.2))
model.add(Conv2D(128, (3,3), padding='same'))
model.add(BatchNormalization())
model.add(ReLU())
model.add(Conv2D(128, (3,3)))
model.add(BatchNormalization())
model.add(ReLU())
model.add(Dropout(0.2))
model.add(Flatten())
model.add(Dense(512))
model.add(BatchNormalization())
model.add(ReLU())
model.add(Dropout(0.2))
model.add(Dense(NUM_CLASSES, activation='softmax'))
model.summary()
该模型的简要结构信息输入如下:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 48, 48, 32) 896
_________________________________________________________________
batch_normalization (BatchNo (None, 48, 48, 32) 128
_________________________________________________________________
re_lu (ReLU) (None, 48, 48, 32) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 46, 46, 32) 9248
_________________________________________________________________
batch_normalization_1 (Batch (None, 46, 46, 32) 128
_________________________________________________________________
re_lu_1 (ReLU) (None, 46, 46, 32) 0
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 23, 23, 32) 0
_________________________________________________________________
dropout (Dropout) (None, 23, 23, 32) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 23, 23, 64) 18496
_________________________________________________________________
batch_normalization_2 (Batch (None, 23, 23, 64) 256
_________________________________________________________________
re_lu_2 (ReLU) (None, 23, 23, 64) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 21, 21, 64) 36928
_________________________________________________________________
batch_normalization_3 (Batch (None, 21, 21, 64) 256
_________________________________________________________________
re_lu_3 (ReLU) (None, 21, 21, 64) 0
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 10, 10, 64) 0
_________________________________________________________________
dropout_1 (Dropout) (None, 10, 10, 64) 0
_________________________________________________________________
conv2d_4 (Conv2D) (None, 10, 10, 128) 73856
_________________________________________________________________
batch_normalization_4 (Batch (None, 10, 10, 128) 512
_________________________________________________________________
re_lu_4 (ReLU) (None, 10, 10, 128) 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 8, 8, 128) 147584
_________________________________________________________________
batch_normalization_5 (Batch (None, 8, 8, 128) 512
_________________________________________________________________
re_lu_5 (ReLU) (None, 8, 8, 128) 0
_________________________________________________________________
dropout_2 (Dropout) (None, 8, 8, 128) 0
_________________________________________________________________
flatten (Flatten) (None, 8192) 0
_________________________________________________________________
dense (Dense) (None, 512) 4194816
_________________________________________________________________
batch_normalization_6 (Batch (None, 512) 2048
_________________________________________________________________
re_lu_6 (ReLU) (None, 512) 0
_________________________________________________________________
dropout_3 (Dropout) (None, 512) 0
_________________________________________________________________
dense_1 (Dense) (None, 43) 22059
=================================================================
Total params: 4,507,723
Trainable params: 4,505,803
Non-trainable params: 1,920
_________________________________________________________________
最后,编译并拟合模型,如下所示:
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
history = model.fit(x_train, y_train,
batch_size=64,
epochs=10,
validation_data=(x_test, y_test),
verbose=1)
训练完成的模型准确率约为 99.8%:

此外,如果我们使用完全相同的模型,但不进行直方图归一化时,则模型的准确率大约为 97%,可以看到使用直方图归一化进行预处理后可以显著提高模型的分类准确率。
相关链接
Keras深度学习实战(1)——神经网络基础与模型训练过程详解
Keras深度学习实战(2)——使用Keras构建神经网络
Keras深度学习实战(7)——卷积神经网络详解与实现
边栏推荐
- 国标GB28181协议EasyGBS平台兼容老版本收流端口的功能实现
- 【Day_09 0427】走方格的方案数
- [Neural Network] This article will take you to easily analyze the neural network (with an example of spoofing your girlfriend)
- COS User Practice Call for Papers
- Leetcode74. 搜索二维矩阵
- 如何让固定点的监控设备在EasyCVR平台GIS电子地图上显示地理位置?
- Goldfish Brother RHCA Memoirs: CL210 manages OPENSTACK network -- network configuration options
- Stop using MySQL online DDL
- What is the JVM runtime data area and the JMM memory model
- How opencv implements image skew correction
猜你喜欢

粒子滤波 particle filter —从贝叶斯滤波到粒子滤波——Part-I(贝叶斯滤波)

483-82 (23, 239, 450, 113)

【综述专栏】IJCAI 2022 | 图结构学习最新综述:研究进展与未来展望

OpenCV installation, QT, VS configuration project settings

C#/VB.NET:从 PDF 文档中提取所有表格

B005 - STC8 based single chip microcomputer intelligent street light control system

面试必问的HashCode技术内幕

The XML configuration

Solve the problem that MySQL cannot insert Chinese data

No need to crack, install Visual Studio 2013 Community Edition on the official website
随机推荐
【无标题】setInterval和setTimeout详解
el-form-item prop属性动态绑定不生效如何解决
Industry Salon Phase II丨How to enable chemical companies to reduce costs and increase efficiency through supply chain digital business collaboration?
阿里云的域名和ip绑定
COS 用户实践征文
【Day_12 0507】查找组成一个偶数最接近的两个素数
XML配置
Go GORM事务实例分析
C#/VB.NET 从PDF中提取表格
BITS Pilani|SAC-AP:基于 Soft Actor Critic 的深度强化学习用于警报优先级
2022,程序员应该如何找工作
成都理工大学&电子科技大学|用于强化学习的域自适应状态表示对齐
shell脚本专题(07):文件由cfs到bos
Multi-Party Threshold Private Set Intersection with Sublinear Communication-2021:解读
SQL函数 TO_DATE(二)
【pyqt5】自定义控件 实现能够保持长宽比地缩放子控件
7月30号|来一场手把手助您打造智能视觉新爆款的技术动手实验
odoo+物联网
三维空间中点的插值
[Neural Network] This article will take you to easily analyze the neural network (with an example of spoofing your girlfriend)