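Train Your Own Classifier | [Guaranteed to Work, Data Already Prepared]
2022-07-28 07:56:00 [51CTO]
WeChat official account: AI算法与图像处理 (AI Algorithms and Image Processing)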
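Preface:
When job hunting, I was often asked whether I had done any projects of my own. I answered that I had run handwritten digit recognition, and got a polite smirk in return...
Fair enough: that task is the machine learning equivalent of "hello world" in programming.
So today I'll walk you through building a classifier of your own. It really isn't that hard.
Classifiers also see wide industrial use, for example in product defect detection (identifying different defect types).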
1. Prepare the data
Flower image dataset: http://download.tensorflow.org/example_images/flower_photos.tgz
(If you are up for it, you can also crawl other images from the web.)
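The training script further down reads a training folder and a validation folder, each holding one subfolder per class. The sketch below shows one way to produce that layout from an extracted archive. It assumes the archive unpacks into one subfolder per flower class; the function name `split_dataset`, the 20% validation ratio, and the example paths are illustrative choices, not part of the original post.

```python
import os
import random
import shutil


def split_dataset(src_dir, train_dir, val_dir, val_ratio=0.2, seed=0):
    """Copy src_dir/<class>/* into train_dir/<class>/ and val_dir/<class>/.

    Returns {class_name: (n_train, n_val)}.
    """
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    counts = {}
    for class_name in sorted(os.listdir(src_dir)):
        class_path = os.path.join(src_dir, class_name)
        if not os.path.isdir(class_path):
            continue  # skip stray files such as a license text
        files = sorted(os.listdir(class_path))
        rng.shuffle(files)
        n_val = int(len(files) * val_ratio)
        for i, name in enumerate(files):
            # The first n_val shuffled files go to validation, the rest to training
            target_root = val_dir if i < n_val else train_dir
            target = os.path.join(target_root, class_name)
            os.makedirs(target, exist_ok=True)
            shutil.copy2(os.path.join(class_path, name),
                         os.path.join(target, name))
        counts[class_name] = (len(files) - n_val, n_val)
    return counts


# Hypothetical usage; adjust the paths to wherever the archive was extracted:
# print(split_dataset("flower_photos", "data/train", "data/validate"))
```

The resulting folders can then be used as `train_dir` and `val_dir` in the training script.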

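2. Understand the network architecture: DenseNet
DenseNet is a convolutional neural network with dense connections. Within a dense block, every pair of layers is directly connected: the input to each layer is the concatenation of the outputs of all preceding layers, and the feature maps that layer learns are passed directly to all subsequent layers as input. The figure below is a schematic of DenseNet.

This is a seriously impressive architecture; it won the CVPR 2017 Best Paper Award.
Advantages:
Alleviates vanishing gradients
Strengthens feature propagation
Makes efficient reuse of features
Reduces the number of parameters
Paper: https://arxiv.org/pdf/1608.06993.pdf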
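To make the dense connectivity concrete, here is a toy sketch that tracks only channel labels, not real tensors. Each layer sees the concatenation of everything produced before it and adds a fixed number of new channels (the "growth rate" in the paper). The function `dense_block` and the label strings are made up for illustration; this is not Keras code.

```python
def dense_block(x0, num_layers, growth_rate):
    """Simulate a dense block on a list of channel labels.

    Each layer receives the concatenation of all earlier outputs and
    contributes `growth_rate` new channels.
    """
    features = list(x0)   # channels entering the block
    inputs_seen = []      # number of input channels seen by each layer
    for layer in range(1, num_layers + 1):
        inputs_seen.append(len(features))
        new_channels = ["L{}c{}".format(layer, c) for c in range(growth_rate)]
        features = features + new_channels  # concatenate, do not add
    return inputs_seen, features


inputs_seen, out = dense_block(["in0", "in1", "in2", "in3"],
                               num_layers=3, growth_rate=2)
print(inputs_seen)  # [4, 6, 8]
print(len(out))     # 10
```

With 4 input channels and a growth rate of 2, the layers see 4, 6 and 8 input channels. Each layer adds only a few channels of its own and reuses all the earlier ones.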
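3. Code implementation
Environment: python3.6, tensorflow-gpu 1.5.0, keras2.2.4
The code below is commented in detail, so I won't ramble. If anything is unclear, refer to the official TensorFlow tutorials: https://tensorflow.google.cn/tutorials/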
# -*- coding: utf-8 -*-
'''
Changes you need to make:
1. Update the training and validation set paths (absolute paths are used here), and the paths where the model and logs are saved.
2. Adjust batch_size, the number of epochs, the learning rate, etc. to suit your needs.
'''
import os
import sys
import glob
import matplotlib.pyplot as plt
from keras import __version__
from keras.applications.densenet import DenseNet201, preprocess_input
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD
from keras.callbacks import ModelCheckpoint, TensorBoard

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def get_nb_files(directory):
    """Get number of files by searching directory recursively"""
    if not os.path.exists(directory):
        return 0
    cnt = 0
    # Count the entries inside every subdirectory (one subdirectory per class)
    for r, dirs, files in os.walk(directory):
        for dr in dirs:
            cnt += len(glob.glob(os.path.join(r, dr + "/*")))
    return cnt

# Data preparation
IM_WIDTH, IM_HEIGHT = 224, 224  # input image size expected by DenseNet
# FC_SIZE = 1024  # number of units in the fully connected layer
# NB_IV3_LAYERS_TO_FREEZE = 172  # number of layers to freeze
train_dir = '/media/pzw/0E50196C0E50196C/job/triode1/train'  # training set
val_dir = '/media/pzw/0E50196C0E50196C/job/triode1/validate'  # validation set
nb_classes = 5
nb_epoch = 5
batch_size = 8

nb_train_samples = get_nb_files(train_dir)  # number of training samples
nb_classes = len(glob.glob(train_dir + "/*"))  # number of classes
nb_val_samples = get_nb_files(val_dir)  # number of validation samples
nb_epoch = int(nb_epoch)  # number of epochs
batch_size = int(batch_size)

# Image generators
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=5,
    width_shift_range=0.02,
    height_shift_range=0.02,
    shear_range=0.02,
    zoom_range=0.02,
    horizontal_flip=True
)
# Note: the validation generator applies the same augmentation as the training one
test_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=5,
    width_shift_range=0.02,
    height_shift_range=0.02,
    shear_range=0.02,
    zoom_range=0.02,
    horizontal_flip=True
)

# Training data and test (validation) data
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(IM_WIDTH, IM_HEIGHT),
    batch_size=batch_size,
    class_mode='categorical')
validation_generator = test_datagen.flow_from_directory(
    val_dir,
    target_size=(IM_WIDTH, IM_HEIGHT),
    batch_size=batch_size,
    class_mode='categorical')

# Add new layers
def add_new_last_layer(base_model, nb_classes):
    """
    Add the final layers.
    Input: base_model and the number of classes
    Output: a new Keras model
    """
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    # x = Dense(FC_SIZE, activation='relu')(x)  # new FC layer, random init
    predictions = Dense(nb_classes, activation='softmax')(x)  # new softmax layer
    # Keras 2 uses the keyword names `inputs` and `outputs`
    model = Model(inputs=base_model.input, outputs=predictions)
    return model

# Build the model
model = DenseNet201(include_top=False)
model = add_new_last_layer(model, nb_classes)
# To resume training from a previously saved checkpoint, uncomment the line below
# model.load_weights('model/checkpoint-02e-val_acc_0.82.hdf5')
model.compile(optimizer=SGD(lr=0.001, momentum=0.9, decay=0.0001, nesterov=True),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# A better way to save the model: check after every epoch, keep only the best val_acc
output_model_file = '/media/pzw/0E50196C0E50196C/job/triode1/denseNet/model/checkpoint-{epoch:02d}e-val_acc_{val_acc:.2f}.hdf5'
# keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)
checkpoint = ModelCheckpoint(output_model_file, monitor='val_acc', verbose=1, save_best_only=True)

# TensorBoard visualization
RUN = RUN + 1 if 'RUN' in locals() else 1  # locals() returns a dict of all local variables in the current scope
LOG_DIR = '/media/pzw/0E50196C0E50196C/job/triode1/denseNet/training_logs/run{}'.format(RUN)
tensorboard = TensorBoard(log_dir=LOG_DIR, write_images=True)

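# Start training
# Keras 2 counts batches per epoch, not samples: steps = ceil(samples / batch_size)
history_ft = model.fit_generator(
    train_generator,
    steps_per_epoch=(nb_train_samples + batch_size - 1) // batch_size,
    epochs=nb_epoch,
    callbacks=[tensorboard, checkpoint],
    validation_data=validation_generator,
    validation_steps=(nb_val_samples + batch_size - 1) // batch_size)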
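
Two details of the training script are easy to check by hand. In Keras 2, `fit_generator` counts batches per epoch rather than samples. And, as far as I know, `ModelCheckpoint` fills the filename template with `epoch + 1` and the logged metrics. The helpers `steps_for` and `checkpoint_name` below are hypothetical stand-ins that reproduce this arithmetic in plain Python; the sample counts and metric values are invented.

```python
import math


def steps_for(num_samples, batch_size):
    """Number of batches needed to see every sample once."""
    return int(math.ceil(num_samples / float(batch_size)))


def checkpoint_name(template, epoch_index, logs):
    """Fill a checkpoint template the way ModelCheckpoint is assumed to:
    epoch is 0-based internally, so the file name shows epoch_index + 1."""
    return template.format(epoch=epoch_index + 1, **logs)


print(steps_for(1000, 8))  # 125
print(steps_for(1001, 8))  # 126

template = 'checkpoint-{epoch:02d}e-val_acc_{val_acc:.2f}.hdf5'
logs = {'val_acc': 0.8234, 'val_loss': 0.51}  # invented metric values
print(checkpoint_name(template, 1, logs))  # checkpoint-02e-val_acc_0.82.hdf5
```

The last line has the same shape as the file name in the commented-out `load_weights` call in the training script.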
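This model converges really fast during training, which is very impressive.
Once training is done, you can use TensorBoard to inspect how the loss and accuracy changed.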
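One caveat on the log directory: the `RUN = RUN + 1 if 'RUN' in locals() else 1` trick in the training script only increments in an interactive session where `RUN` survives between runs. A script started fresh always writes to `run1`. A minimal sketch of an alternative is to scan the log folder for existing `runN` directories. The function name `next_run_dir` is made up for illustration.

```python
import os
import re


def next_run_dir(base_dir):
    """Return base_dir/runN, where N is one more than the highest existing run."""
    highest = 0
    if os.path.isdir(base_dir):
        for name in os.listdir(base_dir):
            match = re.match(r"^run(\d+)$", name)
            if match:
                highest = max(highest, int(match.group(1)))
    return os.path.join(base_dir, "run{}".format(highest + 1))


# Hypothetical usage in the training script:
# LOG_DIR = next_run_dir('/path/to/training_logs')
```

The logs can then be viewed by pointing TensorBoard at the parent folder with `tensorboard --logdir <training_logs folder>`, so all runs appear side by side.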
Of course we also want to see how well the model does at test time, so let's look at the prediction for a single image.
# -*- coding: utf-8 -*-
'''
1. Update the model path and the image path.
2. If your number of classes differs, update the entries in labels and the lists passed to plt.
'''
# Imports
import sys
import argparse
import numpy as np
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
from keras.preprocessing import image
from keras.models import load_model
from keras.applications.densenet import preprocess_input

# Target image size
target_size = (224, 224)

# Prediction function
# Input: model, image, target size
# Output: prediction
def predict(model, img, target_size):
    """Run model prediction on image
    Args:
        model: keras model
        img: PIL format image
        target_size: (w, h) tuple
    Returns:
        array of predicted probabilities, one per class
    """
    if img.size != target_size:
        img = img.resize(target_size)
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)  # add the batch dimension
    x = preprocess_input(x)
    preds = model.predict(x)
    return preds[0]

# Plotting function
# Plot after predicting; the labels here are the five flower classes, change them as needed
labels = ("daisy", "dandelion", "roses", "sunflowers", "tulips")

def plot_preds(image, preds, labels):
    """Displays image and the predicted probabilities in a bar graph
    Args:
        image: PIL image
        preds: array of predicted probabilities, one per class
        labels: class names, in the same order as preds
    """
    plt.imshow(image)
    plt.axis('off')
    plt.figure()
    plt.barh([0, 1, 2, 3, 4], preds, alpha=0.5)
    plt.yticks([0, 1, 2, 3, 4], labels)
    plt.xlabel('Probability')
    plt.xlim(0, 1.01)
    plt.tight_layout()
    plt.show()

# Load the model
model = load_model('/media/pzw/0E50196C0E50196C/job/triode1/DenseNet/model/checkpoint-08e-val_acc_0.92.hdf5')
# Local image
img = Image.open('20180820111351_1_1_Q90.jpg')
preds = predict(model, img, target_size)
plot_preds(img, preds, labels)
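
Besides the bar chart, it can be handy to print the most likely classes as text. A minimal sketch follows, assuming `preds` is the per-class probability array returned by `predict` and `labels` lists the classes in the same order. The helper `top_k` and the example probabilities are invented for illustration.

```python
def top_k(preds, labels, k=3):
    """Return the k (label, probability) pairs with the highest probability."""
    pairs = sorted(zip(labels, (float(p) for p in preds)),
                   key=lambda lp: lp[1], reverse=True)
    return pairs[:k]


labels = ("daisy", "dandelion", "roses", "sunflowers", "tulips")
example_preds = [0.02, 0.05, 0.80, 0.03, 0.10]  # made-up probabilities
for label, prob in top_k(example_preds, labels, k=2):
    print("{}: {:.2f}".format(label, prob))
# roses: 0.80
# tulips: 0.10
```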
Result:

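Summary
1. This post covered how to apply the model. For interviews you also need to understand the basic architecture and the authors' clever ideas. That is not only about landing a job: it helps a great deal with your future work and with how you approach problems, and it can impress your manager.
2. One more remark: when we study other people's work, the point is not to admire or worship how brilliant they are, but to learn how they solved the problems they ran into. Everyone hits problems they cannot solve at first. How they think when facing those problems is what really sets these people apart. Building something from nothing: that is impressive.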

END

