Keras Deep Learning in Practice —— Gender Classification Based on Inception v3
0. Preface
We have already implemented gender classification based on the VGG16 and VGG19 architectures. Beyond these, there are many other, more carefully designed deep neural network architectures; Inception, for example, greatly reduces the number of model parameters while preserving model quality. In this section, we first introduce the core idea of the Inception model and then implement gender classification using a pre-trained Inception network.
1. Inception structure
To better understand the core idea behind the Inception model, consider the following scenario: in a dataset, the object of interest occupies most of the image in some samples, but only a small part of the image in others. If we use convolution kernels of a single size in both cases, the model will struggle to learn to recognize both the smaller and the larger objects at the same time.
To solve this problem, we can use several convolution kernels of different sizes in the same layer. In this case the network essentially becomes wider rather than deeper, as shown below:

In the diagram above, convolution kernels of several different sizes are applied within a given layer. The Inception v1 network stacks nine such Inception modules linearly, as shown below:

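To make this concrete, the following is a minimal sketch of such a multi-branch block written with the Keras functional API; the filter counts are illustrative and do not correspond to the original GoogLeNet configuration:
from keras.layers import Input, Conv2D, MaxPooling2D, concatenate
from keras.models import Model

def inception_block(x, f1, f3, f5, fp):
    # Parallel branches with different receptive fields
    b1 = Conv2D(f1, (1, 1), padding='same', activation='relu')(x)
    b3 = Conv2D(f3, (3, 3), padding='same', activation='relu')(x)
    b5 = Conv2D(f5, (5, 5), padding='same', activation='relu')(x)
    bp = MaxPooling2D((3, 3), strides=(1, 1), padding='same')(x)
    bp = Conv2D(fp, (1, 1), padding='same', activation='relu')(bp)
    # Concatenate the branches along the channel axis
    return concatenate([b1, b3, b5, bp], axis=-1)

inputs = Input(shape=(256, 256, 3))
outputs = inception_block(inputs, f1=64, f3=128, f5=32, fp=32)
print(Model(inputs, outputs).output_shape)  # (None, 256, 256, 256)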
1.1 Inception v1 Loss function
As the Inception v1 architecture diagram shows, the network is both deep and wide, which makes vanishing gradients likely.
To combat vanishing gradients, Inception v1 adds two auxiliary classifiers that branch off intermediate Inception modules. Their losses are added to the total loss of the network, which is minimized during training, as shown below:
total_loss = real_loss + 0.3 * aux_loss_1 + 0.3 * aux_loss_2
Note that the auxiliary losses are only used during training; they are ignored when the model is used for inference.
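In Keras, this weighting scheme corresponds to a multi-output model compiled with loss_weights. The following minimal sketch (with stand-in layers and illustrative names, not the actual GoogLeNet definition) shows how the 0.3 weights from the formula above would be attached:
from keras.models import Model
from keras.layers import Input, Conv2D, GlobalAveragePooling2D, Dense

inputs = Input(shape=(256, 256, 3))
x = Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)   # stand-in for the early Inception stack
aux_1 = Dense(1, activation='sigmoid', name='aux_1')(GlobalAveragePooling2D()(x))
x = Conv2D(128, (3, 3), activation='relu', padding='same')(x)       # stand-in for deeper layers
aux_2 = Dense(1, activation='sigmoid', name='aux_2')(GlobalAveragePooling2D()(x))
main = Dense(1, activation='sigmoid', name='main')(GlobalAveragePooling2D()(x))

model = Model(inputs, [main, aux_1, aux_2])
# Weight the auxiliary losses by 0.3, matching the total_loss formula above
model.compile(optimizer='adam', loss='binary_crossentropy', loss_weights=[1.0, 0.3, 0.3])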
1.2 Inception v2 and Inception v3
Inception v2 and Inception v3 are improvements on the Inception v1 architecture: in Inception v2, the authors optimized the convolution operations so that images are processed faster; in Inception v3, the authors added 7 x 7 convolution kernels on top of the original kernels and connected them in series. In short, the contributions of the Inception family are as follows:
- Use Inception modules to capture the multi-scale details of an image
- Use 1 x 1 convolutions as bottleneck layers (a brief parameter-count sketch follows this list)
- Use an average pooling layer instead of fully connected layers to reduce the number of model parameters
- Use auxiliary branches to avoid vanishing gradients
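To illustrate the 1 x 1 bottleneck mentioned above, a quick back-of-the-envelope calculation (with purely illustrative channel counts) shows how much it reduces the number of weights:
# Direct 3 x 3 convolution mapping 256 input channels to 256 output channels
direct = 3 * 3 * 256 * 256                          # 589,824 weights
# Bottleneck: 1 x 1 reduces 256 -> 64 channels, then 3 x 3 expands 64 -> 256
bottleneck = 1 * 1 * 256 * 64 + 3 * 3 * 64 * 256    # 16,384 + 147,456 = 163,840 weights
print(direct / bottleneck)                          # roughly 3.6x fewer weights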
2. Implementing Gender Classification with a Pre-Trained Inception v3 Model
In 《Transfer Learning》, we saw that transfer learning lets us train a well-performing model with only a small number of samples, and we used a pre-trained VGG16 model to implement gender classification. In this section, we build a model on top of a pre-trained Inception v3 network to recognize the gender of the person in an image.
2.1 Model implementation
First, load the required libraries and the pre-trained Inception v3 model:
from keras.applications import InceptionV3
from keras.applications.inception_v3 import preprocess_input
from glob import glob
from skimage import io
import cv2
import numpy as np

# Load Inception v3 pre-trained on ImageNet, without its classification head
model = InceptionV3(include_top=False, weights='imagenet', input_shape=(256, 256, 3))
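Because include_top=False is used, the network returns a convolutional feature map rather than class probabilities; for 256 x 256 inputs its shape can be checked directly, and it is this feature map that the fine-tuning model below consumes:
print(model.output_shape)    # expected: (None, 6, 6, 2048) for 256 x 256 RGB inputs
print(model.count_params())  # number of pre-trained parameters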
Create the input and output datasets:
x = []
y = []
# Label 0 (male): images in the first folder
for i in glob('man_woman/a_resized/*.jpg')[:8000]:
    try:
        image = io.imread(i)
        x.append(image)
        y.append(0)
    except:
        continue
# Label 1 (female): images in the second folder
for i in glob('man_woman/b_resized/*.jpg')[:8000]:
    try:
        image = io.imread(i)
        x.append(image)
        y.append(1)
    except:
        continue
# Pass each image through the pre-trained network to extract features
x_inception_v3 = []
for i in range(len(x)):
    img = x[i]
    img = preprocess_input(img.reshape((1, 256, 256, 3)))
    img_feature = model.predict(img)
    x_inception_v3.append(img_feature)
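Calling model.predict on one image at a time works but is slow. Assuming every image has already been resized to 256 x 256, the same features can be extracted in batches; the reshape in the next step is then unnecessary, since the output already has shape (N, 6, 6, 2048):
# Optional, faster variant: preprocess the whole set and predict in batches
x_arr = preprocess_input(np.array(x, dtype='float32'))           # shape: (N, 256, 256, 3)
x_inception_v3 = model.predict(x_arr, batch_size=32, verbose=1)  # shape: (N, 6, 6, 2048)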
Convert the input and output to NumPy arrays and split the dataset into training and test sets:
x_inception_v3 = np.array(x_inception_v3)
# Drop the singleton batch dimension: (N, 1, 6, 6, 2048) -> (N, 6, 6, 2048)
x_inception_v3 = x_inception_v3.reshape(x_inception_v3.shape[0], x_inception_v3.shape[2],
                                        x_inception_v3.shape[3], x_inception_v3.shape[4])
y = np.array(y)
from sklearn.model_selection import train_test_split
# Fix the random seed so that the raw-image split in section 2.2 matches this split
x_train, x_test, y_train, y_test = train_test_split(x_inception_v3, y, test_size=0.2, random_state=42)
Build a fine-tuning model on top of the output of the pre-trained model:
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dropout, Dense

model_fine_tuning = Sequential()
model_fine_tuning.add(Conv2D(2048,
                             kernel_size=(3, 3),
                             activation='relu',
                             input_shape=(x_train.shape[1], x_train.shape[2], x_train.shape[3])))
model_fine_tuning.add(MaxPooling2D(pool_size=(2, 2)))
model_fine_tuning.add(Flatten())
model_fine_tuning.add(Dense(1024, activation='relu'))
model_fine_tuning.add(Dropout(0.5))
# Single sigmoid unit for binary (male/female) classification
model_fine_tuning.add(Dense(1, activation='sigmoid'))
model_fine_tuning.summary()
The summary of the fine-tuning model is shown below:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_94 (Conv2D) (None, 4, 4, 2048) 37750784
_________________________________________________________________
max_pooling2d_4 (MaxPooling2 (None, 2, 2, 2048) 0
_________________________________________________________________
flatten (Flatten) (None, 8192) 0
_________________________________________________________________
dense (Dense) (None, 1024) 8389632
_________________________________________________________________
dropout (Dropout) (None, 1024) 0
_________________________________________________________________
dense_1 (Dense) (None, 1) 1025
=================================================================
Total params: 46,141,441
Trainable params: 46,141,441
Non-trainable params: 0
_________________________________________________________________
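As a sanity check, these parameter counts follow directly from the layer shapes: the Conv2D layer has 3 x 3 x 2048 x 2048 + 2048 = 37,750,784 parameters (weights plus biases), the first Dense layer has 8192 x 1024 + 1024 = 8,389,632, and the output Dense layer has 1024 + 1 = 1,025, giving the total of 46,141,441 reported above.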
Finally, compile and fit the model:
model_fine_tuning.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])
history = model_fine_tuning.fit(x_train, y_train,
                                batch_size=32,
                                epochs=20,
                                verbose=1,
                                validation_data=(x_test, y_test))
During training, the accuracy and loss of the model on the training and test datasets change as follows:

As can be seen, the gender classification model based on the pre-trained Inception v3 reaches an accuracy of about 95%.
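The accuracy and loss curves above can be reproduced from the history object returned by fit; a minimal sketch, assuming matplotlib is installed (the history keys match the metric name 'acc' passed to compile):
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.plot(history.history['acc'], label='Training accuracy')
plt.plot(history.history['val_acc'], label='Test accuracy')
plt.title('Accuracy')
plt.legend()
plt.subplot(122)
plt.plot(history.history['loss'], label='Training loss')
plt.plot(history.history['val_loss'], label='Test loss')
plt.title('Loss')
plt.legend()
plt.show()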
2.2 Examples of Misclassified Images
The misclassified images can be inspected with the following code:
import matplotlib.pyplot as plt

x = np.array(x)
from sklearn.model_selection import train_test_split
# Use the same random_state as before so this split matches the feature split
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

# Extract Inception v3 features for the raw test images
x_test_inception_v3 = []
for i in range(len(x_test)):
    img = x_test[i]
    img = preprocess_input(img.reshape((1, 256, 256, 3)))
    img_feature = model.predict(img)
    x_test_inception_v3.append(img_feature)
x_test_inception_v3 = np.array(x_test_inception_v3)
x_test_inception_v3 = x_test_inception_v3.reshape(x_test_inception_v3.shape[0], x_test_inception_v3.shape[2],
                                                  x_test_inception_v3.shape[3], x_test_inception_v3.shape[4])

# Predict, then sort test samples by how far the prediction is from the true label
y_pred = model_fine_tuning.predict(x_test_inception_v3)
wrong = np.argsort(np.abs(y_pred.flatten() - y_test))
print(wrong)

# Map numeric labels back to characters: 0 -> 'M', 1 -> 'F'
y_test_char = np.where(y_test == 0, 'M', 'F')
y_pred_char = np.where(y_pred > 0.5, 'F', 'M')

# Plot the four most confidently misclassified images
plt.subplot(221)
plt.imshow(x_test[wrong[-1]])
plt.title('Actual: ' + str(y_test_char[wrong[-1]]) + ', Predicted: ' + str(y_pred_char[wrong[-1]][0]))
plt.subplot(222)
plt.imshow(x_test[wrong[-2]])
plt.title('Actual: ' + str(y_test_char[wrong[-2]]) + ', Predicted: ' + str(y_pred_char[wrong[-2]][0]))
plt.subplot(223)
plt.imshow(x_test[wrong[-3]])
plt.title('Actual: ' + str(y_test_char[wrong[-3]]) + ', Predicted: ' + str(y_pred_char[wrong[-3]][0]))
plt.subplot(224)
plt.imshow(x_test[wrong[-4]])
plt.title('Actual: ' + str(y_test_char[wrong[-4]]) + ', Predicted: ' + str(y_pred_char[wrong[-4]][0]))
plt.show()

Related links
Keras Deep Learning in Practice (7) —— Convolutional Neural Networks Explained and Implemented
Keras Deep Learning in Practice (9) —— The Limitations of Convolutional Neural Networks
Keras Deep Learning in Practice (10) —— Transfer Learning
Keras Deep Learning in Practice —— Gender Classification with Convolutional Neural Networks
Keras Deep Learning in Practice —— Gender Classification Based on the VGG19 Model