当前位置:网站首页>深度学习100例——卷积神经网络(CNN)实现服装图像分类
深度学习100例——卷积神经网络(CNN)实现服装图像分类
2022-08-03 10:30:00 【Ding Jiaxiong】
活动地址:CSDN21天学习挑战赛
深度学习100例——卷积神经网络(CNN)服装图像分类
文章目录
1. 前期准备工作
我的环境:
1.1 设置GPU
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU") # tf.config.list_physical_devices# 获得当前主机上某种特定运算设备类型(如 GPU 或 CPU )的列表
if gpus:
gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
tf.config.set_visible_devices([gpu0],"GPU") # 设置可见设备列表
当前机器只有一个GPU
1.2 导入数据
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
(train_images, train_labels), (test_images, test_labels) = datasets.fashion_mnist.load_data()
查看一下数据
(train_images, train_labels) → 训练集
(test_images, test_labels) → 测试集
和mnist数据集类似,60000个训练样本,10000个测试样本,且70000张图片均是28 x 28的像素图。
每张图的像素值介于 0 - 255 之间,标签是整数数组,介于 0 - 9 之间。
标签对应表:
标签 | 类 | 标签 | 类 |
---|---|---|---|
0 | T恤 | 5 | 凉鞋 |
1 | 裤子 | 6 | 衬衫 |
2 | 套头衫 | 7 | 运动鞋 |
3 | 连衣裙 | 8 | 包 |
4 | 外套 | 9 | 短靴 |
1.3 数据归一化
→ 将像素的值标准化到 0 - 1的区间内。
train_images, test_images = train_images / 255.0, test_images / 255.0
train_images.shape,test_images.shape,train_labels.shape,test_labels.shape
归一化前:
归一化后:
1.4 调整图片格式
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))
train_images.shape,test_images.shape,train_labels.shape,test_labels.shape
数据维度转换,增加通道数
1.5 可视化查看数据图像
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
plt.figure(figsize=(20,10))
for i in range(20):
plt.subplot(5,10,i+1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(train_images[i], cmap=plt.cm.binary)
plt.xlabel(class_names[train_labels[i]])
plt.show()
同样只展示0-19索引的样本图像。
2. 构建CNN网络
2.1 构建网络
model = models.Sequential([
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)), #卷积层1,卷积核3 x 3
layers.MaxPooling2D((2, 2)), #池化层1,2 x 2采样
layers.Conv2D(64, (3, 3), activation='relu'), #卷积层2,卷积核3 x 3
layers.MaxPooling2D((2, 2)), #池化层2,2 x 2采样
layers.Conv2D(64, (3, 3), activation='relu'), #卷积层3,卷积核3 x 3
layers.Flatten(), #Flatten层,连接卷积层与全连接层
layers.Dense(64, activation='relu'), #全连接层,特征进一步提取
layers.Dense(10) #输出层,输出预期结果
])
model.summary() # 打印网络结构
2.2 模型说明
各层的作用:
- 输入层:将数据输入到训练网络
- 卷积层:使用卷积核提取图片的特征
- 池化层:进行下采样,用更高层的抽象表示图像特征
- Flatten层:将多维的输入一维化,常用在卷积层到全连接层的过渡
- 全连接层:特征提取器
- 输出层:输出结果
3. 编译模型
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
compile()方法用于设置训练时,使用的优化器optimizer、损失函数loss、准确率评测标准metrics
SparseCategoricalCrossentropy → 交叉熵损失函数,当from_logits参数为True时,会使用softmax将预测y转换为概率。
4.训练模型
history = model.fit(train_images, train_labels, epochs=10, validation_data=(test_images, test_labels))
5. 模型预测
模型的预测结果是一个包含10个数字的数组,代表模型对10种不同服装中每种时装的 置信度。
pre = model.predict(test_images)
6. 模型评估
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.5, 1])
plt.legend(loc='lower right')
plt.show()
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
边栏推荐
- 使用 Scrapy 框架对重复的 url 无法获取数据,dont_filter=True
- QT with OpenGL(HDR)
- GBase 8c分布式数据库,数据如何分布最优?
- error C2872: “flann”: 不明确的符号 解决方法
- OS层面包重组失败过高,数据库层面gc lost 频繁
- Mysql OCP 75题
- Interview Blitz 71: What's the difference between GET and POST?
- array of function pointers
- type="module" you know, but type="importmap" you know
- 成对连接点云分割
猜你喜欢
随机推荐
Leecode-SQL 1484. 按日期分组销售产品
按位取反怎么运算_按位取反运算
聊天app开发——防炸麦以及节省成本的内容鉴定方法
Scrapy + Selenium 实现模拟登录,获取页面动态加载数据
Pixel mobile phone system
Mysql OCP 27题
SQL教程之递归 CTE Common Table Expression
文旅部:进一步加强旅游景区暑期安全管理工作
罕见的数学天才,靠“假结婚”才得到追求事业的机会
训练双塔检索模型,可以不用query-doc样本了?明星机构联合发文
Regulation action for one hundred days during the summer, more than 700 traffic safety hidden dangers were thrown out
With strong network, China mobile to calculate excitation surging energy network construction
2022年起重机械指挥培训试题模拟考试平台操作
js中最简单base64图片流实现自动下载
go泛型使用方法
js函数防抖和函数节流及其使用场景。
决策树和随机森林
简述设计的意义是什么_定义和概念的最大区别
迅为IMX6开发板QT系统创建AP热点基于RTL8723交叉编译hostapd
安全研究员:大量Solana钱包被盗