当前位置:网站首页>CNN鲜花分类
CNN鲜花分类
2022-08-02 14:23:00 【别团等shy哥发育】
CNN鲜花分类
1、数据集介绍
总共5种花,按照文件夹区分花朵的类别。
下载下来的是个压缩包,需要将其解压。
数据集下载地址:https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
2、代码实战
2.1 导入依赖
import PIL
import numpy as np
import matplotlib.pyplot as plt
import pathlib
import lib
import tensorflow as tf
from tensorflow.keras import layers,models
2.2 下载数据
# 下载数据集到本地
data_url='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz'
data_dir=tf.keras.utils.get_file('flower_photos',origin=data_url,untar=True)#untar=True 下载后解压
data_dir=pathlib.Path(data_dir)
2.3 统计数据集
# 统计数据集大小
dataset_size=len(list(data_dir.glob('*/*.jpg')))
dataset_size
总共3670张照片,比上次小狗分类那个少多了。
# 显示部分图片
imgs=list(data_dir.glob('*/*.jpg'))
imgs
查看下第1张图片
img1=imgs[0] #第一张图片
img1
str(img1)
PIL.Image.open(str(img1)) #读取并显示
再查看下第2张图片
img2=imgs[1] #第2张图片
PIL.Image.open(str(img2))
2.4 创建dataset
训练集:
# 3 创建dataset
BATCH_SIZE=32
HEIGHT=180
WIDTH=180
#80%是训练集,20%是验证集
train_ds=tf.keras.preprocessing.image_dataset_from_directory(directory=data_dir,
batch_size=BATCH_SIZE,
validation_split=0.2,
subset='training',
seed=666,
image_size=(HEIGHT,WIDTH))
train_ds
class_names=train_ds.class_names #数据集类别
class_names
验证集:
val_ds=tf.keras.preprocessing.image_dataset_from_directory(directory=data_dir,
batch_size=BATCH_SIZE,
validation_split=0.2,
subset='validation',
seed=666,
image_size=(HEIGHT,WIDTH))
val_ds
2.5 可视化一个batch_size
# 可视化一个batch_size的数据
for images,labels in train_ds.take(1):
for i in range(9): # 一个batch_size有32张,这里只显示9张
plt.subplot(3,3,i+1)
plt.imshow(images[i].numpy().astype('uint8'))
plt.title(class_names[labels[i]])
plt.axis('off')
2.6 将数据集缓存到内存中,加速读取
#将数据集缓存到内存中,加速读取
AUTOTUNE=tf.data.AUTOTUNE
train_ds=train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds=val_ds.cache().prefetch(buffer_size=AUTOTUNE)
2.7 搭建模型
这里仅作测试,并没有使用预训练模型
#搭建模型
model=models.Sequential([
layers.experimental.preprocessing.Rescaling(1./255,input_shape=(HEIGHT,WIDTH,3)),# 数据归一化
layers.Conv2D(16,3,padding='same',activation='relu'),
layers.MaxPool2D(),
layers.Conv2D(32,3,padding='same',activation='relu'),
layers.MaxPool2D(),
layers.Conv2D(64,3,padding='same',activation='relu'),
layers.MaxPool2D(),
layers.Flatten(),
layers.Dense(128,activation='relu'),
layers.Dense(5)
])
model.summary()
2.8 编译模型
#编译模型
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
这里使用的SparseCategoricalCrossentropy会自动帮我们
2.9 模型训练
#模型训练
EPOCHS=10
history=model.fit(train_ds,validation_data=val_ds,epochs=EPOCHS)
这里由于设备太拉跨,略微出手已是显卡极限,所以就只设置了10个epoch
2.10 可视化训练结果
# 可视化训练结果
ranges=range(EPOCHS)
train_acc=history.history['accuracy']
val_acc=history.history['val_accuracy']
train_loss=history.history['loss']
val_loss=history.history['val_loss']
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.plot(ranges,train_acc,label='train_acc')
plt.plot(ranges,val_acc,label='val_acc')
plt.title('Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.subplot(1,2,2)
plt.plot(ranges,train_loss,label='train_loss')
plt.plot(ranges,val_loss,label='val_loss')
plt.title('Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()
过拟合非常严重,下面对模型进行优化
3、模型优化
3.1 数据增强设置
# 数据增强参数设置
data_argumentation=tf.keras.Sequential([
# 随机水平翻转
layers.experimental.preprocessing.RandomFlip('horizontal',input_shape=(HEIGHT,WIDTH,3)),
# 随机旋转
layers.experimental.preprocessing.RandomRotation(0.1), # 旋转
# 随机缩放
layers.experimental.preprocessing.RandomZoom(0.1), #
])
这块的API太多了,多去查查官网。
3.2 显示数据增强后的效果
# 显示数据增强后的效果
for images,labels in train_ds.take(1):
for i in range(9): # 一个batch_size有32张,这里只显示9张
plt.subplot(3,3,i+1)
argumeng_images=data_argumentation(images) #数据增强
plt.imshow(argumeng_images[i].numpy().astype('uint8')) # 显示
plt.title(class_names[labels[i]])
plt.axis('off')
3.3 搭建新的模型
#搭建新的模型
model_2=models.Sequential([
data_argumentation, # 数据增强
layers.experimental.preprocessing.Rescaling(1./255),# 数据归一化
layers.Conv2D(16,3,padding='same',activation='relu'),
layers.MaxPool2D(),
layers.Conv2D(32,3,padding='same',activation='relu'),
layers.MaxPool2D(),
layers.Conv2D(64,3,padding='same',activation='relu'),
layers.MaxPool2D(),
layers.Dropout(0.2),
layers.Flatten(),
layers.Dense(128,activation='relu'),
layers.Dense(5)
])
model_2.summary()
3.4 编译模型
#编译模型
model_2.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
3.5 模型训练
#模型训练
history=model_2.fit(train_ds,validation_data=val_ds,epochs=EPOCHS)
3.6 可视化训练结果
# 可视化训练结果
ranges=range(EPOCHS)
train_acc=history.history['accuracy']
val_acc=history.history['val_accuracy']
train_loss=history.history['loss']
val_loss=history.history['val_loss']
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.plot(ranges,train_acc,label='train_acc')
plt.plot(ranges,val_acc,label='val_acc')
plt.title('Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.subplot(1,2,2)
plt.plot(ranges,train_loss,label='train_loss')
plt.plot(ranges,val_loss,label='val_loss')
plt.title('Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()
现在这个效果比优化之前的好多了。
3.7 模型预测
# 模型预测
test_img=tf.keras.preprocessing.image.load_img('sunfloor.jpg',target_size=(HEIGHT,WIDTH))
test_img
这里我们自己在网上下载一张向日葵的图片进行预测
test_img=tf.keras.preprocessing.image.img_to_array(test_img) # 类型变换
test_img.shape
将数据扩充一维,因为第一个维度是batchsize
test_img=tf.expand_dims(test_img,0) #扩充一维
test_img.shape
预测:
preds=model_2.predict(test_img) #预测
preds.shape
得分:
preds #得分
得分转换成概率:
scores=tf.nn.softmax(preds[0])# 得分转换成概率
scores
print('模型预测可能性最大的类别是:{},概率值为:{}'.format(class_names[np.argmax(scores)],np.max(scores)))
这里最后一个全连接层可以直接加上个softmax激活函数,这样预测后就不用再转化了。
边栏推荐
猜你喜欢
2021年华为杯数学建模竞赛E题——信号干扰下的超宽带(UWB)精确定位问题
【Anaconda】一行语句解决Spyder启动问题
[Fault Diagnosis] Weak Fault Diagnosis of Fan Bearing Based on PSO_VMD_MCKD Method
使用 docker 搭建 redis-cluster 集群
DOM - Element Box Model
Impulse response invariant method and bilinear transformation method for IIR filter design
DOM — 元素的增删改查
【QMT】给QMT量化交易软件安装和调用第三方库(举例通达信pytdx,MyTT,含代码)
2022-7-12 第五组 瞒春 学习报告
C语言中国象棋源码以及图片
随机推荐
2022-07-16 第五小组 瞒春 学习笔记
异常简单总结
【SVM回归预测】基于LibSVM实现多特征数据的预测
DOM - Event Mechanism and Event Chain
如何使用Swiper外部插件写一个轮播图
2022-07-25 第六小组 瞒春 学习笔记
李开复花上千万投的缝纫机器人,团队出自大疆
一文让你快速手写C语言-三子棋游戏
IIR滤波器设计之冲激响应不变法与双线性变换法
DOM —— 事件代理
idea使用jdbc对数据库进行增删改查,以及使用懒汉方式实现单例模式
JSP技术
【web渗透】文件包含漏洞入门级超详细讲解
Redis最新6.27安装配置笔记及安装和常用命令快速上手复习指南
C语言的基本程序结构详细讲解
nodemon : 无法加载文件 D:\Program Files\nodejs\node_global\nodemon.ps1
职工管理系统(SSM整合)
第四章-4.1-最大子数组问题
ELK日志分析系统
有效的括号【暴力、分支判断、哈希表】