Keras deep learning practice: gender classification based on the VGG19 model
2022-07-02 17:45:00 【Hope Xiaohui】
0. Preface
In "Transfer learning", we learned that transfer learning makes it possible to train a well-performing model from only a small number of samples. We also used a pre-trained VGG16 model for a hands-on gender classification task, which deepened our understanding of how transfer learning works.
1. A brief introduction to the VGG19 architecture
In this article, we introduce another common network architecture, VGG19, and use a pre-trained VGG19 model for a hands-on gender classification task. VGG19 is an improved version of VGG16 with more convolutional layers (the number of pooling layers is unchanged). The architecture of the VGG19 model is as follows:
Model: "vgg19"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 256, 256, 3)] 0
_________________________________________________________________
block1_conv1 (Conv2D) (None, 256, 256, 64) 1792
_________________________________________________________________
block1_conv2 (Conv2D) (None, 256, 256, 64) 36928
_________________________________________________________________
block1_pool (MaxPooling2D) (None, 128, 128, 64) 0
_________________________________________________________________
block2_conv1 (Conv2D) (None, 128, 128, 128) 73856
_________________________________________________________________
block2_conv2 (Conv2D) (None, 128, 128, 128) 147584
_________________________________________________________________
block2_pool (MaxPooling2D) (None, 64, 64, 128) 0
_________________________________________________________________
block3_conv1 (Conv2D) (None, 64, 64, 256) 295168
_________________________________________________________________
block3_conv2 (Conv2D) (None, 64, 64, 256) 590080
_________________________________________________________________
block3_conv3 (Conv2D) (None, 64, 64, 256) 590080
_________________________________________________________________
block3_conv4 (Conv2D) (None, 64, 64, 256) 590080
_________________________________________________________________
block3_pool (MaxPooling2D) (None, 32, 32, 256) 0
_________________________________________________________________
block4_conv1 (Conv2D) (None, 32, 32, 512) 1180160
_________________________________________________________________
block4_conv2 (Conv2D) (None, 32, 32, 512) 2359808
_________________________________________________________________
block4_conv3 (Conv2D) (None, 32, 32, 512) 2359808
_________________________________________________________________
block4_conv4 (Conv2D) (None, 32, 32, 512) 2359808
_________________________________________________________________
block4_pool (MaxPooling2D) (None, 16, 16, 512) 0
_________________________________________________________________
block5_conv1 (Conv2D) (None, 16, 16, 512) 2359808
_________________________________________________________________
block5_conv2 (Conv2D) (None, 16, 16, 512) 2359808
_________________________________________________________________
block5_conv3 (Conv2D) (None, 16, 16, 512) 2359808
_________________________________________________________________
block5_conv4 (Conv2D) (None, 16, 16, 512) 2359808
_________________________________________________________________
block5_pool (MaxPooling2D) (None, 8, 8, 512) 0
=================================================================
Total params: 20,024,384
Trainable params: 20,024,384
Non-trainable params: 0
_________________________________________________________________
As you can see, the architecture shown above has more network layers and more parameters than VGG16. Note that the 16 and 19 in VGG16 and VGG19 denote the number of weight layers (convolutional plus fully connected) in these networks.
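The parameter counts in the summary can be checked by hand. The sketch below assumes the standard VGG19 layout (3 x 3 kernels, with 2, 2, 4, 4 and 4 convolutional layers per block); the helper names are made up for illustration. It counts 16 convolutional layers, because the summary was produced with include_top=False, which drops the 3 fully connected layers that bring the full network to 19.

```python
# Convolutional blocks of VGG19: (number of conv layers, output channels).
VGG19_BLOCKS = [(2, 64), (2, 128), (4, 256), (4, 512), (4, 512)]


def conv_params(in_channels, out_channels, kernel=3):
    """Weights plus biases of one Conv2D layer with a square kernel."""
    return (kernel * kernel * in_channels + 1) * out_channels


def count_vgg19_conv_params(input_channels=3):
    """Return (number of conv layers, total parameters) of the conv base."""
    layers, total, channels = 0, 0, input_channels
    for n_convs, out_channels in VGG19_BLOCKS:
        for _ in range(n_convs):
            total += conv_params(channels, out_channels)
            channels = out_channels
            layers += 1
    return layers, total


n_layers, n_params = count_vgg19_conv_params()
print(n_layers, n_params)  # 16 20024384
```

The total matches the "Total params: 20,024,384" line of the summary; pooling layers contribute no parameters.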
Passing each image through the VGG19 network yields an 8 x 8 x 512 output, which becomes the input to the fine-tuning model. From here on, creating the input and output datasets, then building, compiling, and fitting the model, follows the same process as gender classification with a pre-trained VGG16 model.
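To see where the 8 x 8 x 512 shape comes from: the 3 x 3 convolutions in VGG19 use "same" padding and keep the spatial size, so only the five 2 x 2 max-pooling layers shrink the image. A minimal sketch (the function name is hypothetical):

```python
def vgg_output_shape(height, width, n_pools=5, channels=512):
    """Feature-map shape after n_pools 2x2 max-pooling layers (stride 2)."""
    for _ in range(n_pools):
        height //= 2
        width //= 2
    return (height, width, channels)


print(vgg_output_shape(256, 256))  # (8, 8, 512)
```

Each pooling layer halves the height and width, so a 256 x 256 input ends up as 256 / 2^5 = 8 pixels on each side.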
2. Gender classification with a pre-trained VGG19 model
In this section, we use transfer learning with a pre-trained VGG19 model to perform gender classification.
2.1 Build input and output data
First, prepare the input and output data. We reuse the dataset and data-loading code from "Gender classification with a convolutional neural network":
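from glob import glob

import numpy as np
from keras.applications import VGG19
from keras.applications.vgg19 import preprocess_input
from skimage import io

# Convolutional base only (no classifier head), with ImageNet weights
model = VGG19(include_top=False, weights='imagenet', input_shape=(256, 256, 3))

x = []
y = []
# Label 0: images in a_resized ('M' in the label mapping used in section 2.3)
for i in glob('man_woman/a_resized/*.jpg')[:800]:
    try:
        image = io.imread(i)
        x.append(image)
        y.append(0)
    except Exception:
        # Skip files that cannot be read
        continue
# Label 1: images in b_resized ('F' in the label mapping used in section 2.3)
for i in glob('man_woman/b_resized/*.jpg')[:800]:
    try:
        image = io.imread(i)
        x.append(image)
        y.append(1)
    except Exception:
        continue

# Pass every image through VGG19 and keep the extracted features
x_vgg19 = []
for i in range(len(x)):
    img = x[i]
    img = preprocess_input(img.reshape((1, 256, 256, 3)))
    img_feature = model.predict(img)
    x_vgg19.append(img_feature)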
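Convert the inputs and outputs to arrays, then create the training and test datasets:
from sklearn.model_selection import train_test_split

# Each predict call returns shape (1, 8, 8, 512), so the stacked array has
# shape (N, 1, 8, 8, 512); the reshape drops the singleton axis.
x_vgg19 = np.array(x_vgg19)
x_vgg19 = x_vgg19.reshape(x_vgg19.shape[0], x_vgg19.shape[2], x_vgg19.shape[3], x_vgg19.shape[4])
y = np.array(y)
# A fixed random_state makes the split reproducible
x_train, x_test, y_train, y_test = train_test_split(x_vgg19, y, test_size=0.2, random_state=42)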
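The reshape step above can be checked on stand-in data without loading the network. In the sketch below, random arrays take the place of the per-image outputs of model.predict (this is an assumption for illustration, not real features), and the result is compared with np.concatenate, which does the same job in one call:

```python
import numpy as np

# Stand-ins for three per-image VGG19 outputs, each of shape (1, 8, 8, 512).
rng = np.random.default_rng(0)
features = [rng.random((1, 8, 8, 512)) for _ in range(3)]

stacked = np.array(features)  # shape (3, 1, 8, 8, 512)
reshaped = stacked.reshape(stacked.shape[0], stacked.shape[2],
                           stacked.shape[3], stacked.shape[4])

# Joining along the existing batch axis gives the same array directly.
concatenated = np.concatenate(features, axis=0)  # shape (3, 8, 8, 512)

print(stacked.shape, reshaped.shape)
print(np.array_equal(reshaped, concatenated))
```

Either form produces an array of shape (N, 8, 8, 512), which is what the Conv2D input layer of the next model expects.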
2.2 Model construction and training
Build the fine-tuning model:
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dropout, Dense

# Small classifier trained on top of the extracted VGG19 features
model_fine_tuning = Sequential()
model_fine_tuning.add(Conv2D(512,
                             kernel_size=(3, 3),
                             activation='relu',
                             input_shape=(x_train.shape[1], x_train.shape[2], x_train.shape[3])))
model_fine_tuning.add(MaxPooling2D(pool_size=(2, 2)))
model_fine_tuning.add(Flatten())
model_fine_tuning.add(Dense(1024, activation='relu'))
model_fine_tuning.add(Dropout(0.6))
# A single sigmoid unit outputs the probability of label 1
model_fine_tuning.add(Dense(1, activation='sigmoid'))
model_fine_tuning.summary()
The model summary output is as follows:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 6, 6, 512) 2359808
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 3, 3, 512) 0
_________________________________________________________________
flatten (Flatten) (None, 4608) 0
_________________________________________________________________
dense (Dense) (None, 1024) 4719616
_________________________________________________________________
dropout (Dropout) (None, 1024) 0
_________________________________________________________________
dense_1 (Dense) (None, 1) 1025
=================================================================
Total params: 7,080,449
Trainable params: 7,080,449
Non-trainable params: 0
_________________________________________________________________
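The shapes and parameter counts in this summary follow from the layer definitions. The sketch below recomputes them, assuming the Keras defaults used in the code above: "valid" padding for Conv2D and a 2 x 2 pooling window with stride 2. The function name is made up for illustration.

```python
def head_summary(in_size=8, in_channels=512, filters=512, hidden=1024):
    """Recompute the output sizes and parameter counts of the classifier."""
    conv_out = in_size - 3 + 1        # 3x3 conv, valid padding: 8 -> 6
    pooled = conv_out // 2            # 2x2 max pooling: 6 -> 3
    flat = pooled * pooled * filters  # Flatten: 3 * 3 * 512 = 4608

    conv_params = (3 * 3 * in_channels + 1) * filters
    dense_params = (flat + 1) * hidden
    output_params = hidden + 1
    total = conv_params + dense_params + output_params
    return {
        "conv_out": conv_out, "pooled": pooled, "flat": flat,
        "conv_params": conv_params, "dense_params": dense_params,
        "output_params": output_params, "total": total,
    }


print(head_summary()["total"])  # 7080449
```

Dropout and pooling have no weights, so the three trainable layers account for all 7,080,449 parameters.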
Next, compile and fit the model:
model_fine_tuning.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc'])
history = model_fine_tuning.fit(x_train, y_train,
                                batch_size=32,
                                epochs=20,
                                verbose=1,
                                validation_data=(x_test, y_test))
Finally, we plot the loss and accuracy of the model on the training and test datasets over the course of training. With the VGG19 architecture, the model reaches an accuracy of about 95%, similar to the result obtained with the VGG16 architecture.
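The plotting code is not shown in this article. A minimal sketch of how such curves could be drawn with matplotlib follows; the function name plot_history is hypothetical, and it assumes the history dictionary contains the keys loss, val_loss, acc and val_acc, which is what metrics=['acc'] together with validation_data normally produces.

```python
import matplotlib.pyplot as plt


def plot_history(history_dict, metric="acc"):
    """Plot training and test loss, plus one metric, per epoch."""
    epochs = range(1, len(history_dict["loss"]) + 1)
    fig, (ax_loss, ax_metric) = plt.subplots(1, 2, figsize=(10, 4))

    ax_loss.plot(epochs, history_dict["loss"], "bo", label="Training loss")
    ax_loss.plot(epochs, history_dict["val_loss"], "r", label="Test loss")
    ax_loss.set_xlabel("Epochs")
    ax_loss.set_ylabel("Loss")
    ax_loss.legend()

    ax_metric.plot(epochs, history_dict[metric], "bo", label="Training accuracy")
    ax_metric.plot(epochs, history_dict["val_" + metric], "r", label="Test accuracy")
    ax_metric.set_xlabel("Epochs")
    ax_metric.set_ylabel("Accuracy")
    ax_metric.legend()

    fig.tight_layout()
    return fig
```

After training, it could be called as plot_history(history.history) followed by plt.show(). If your Keras version names the metric differently (for example accuracy), pass that name as the metric argument.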
2.3 Examples of misclassified images
Some examples of misclassified images are as follows:
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

x = np.array(x)
# Re-split the raw images so that they can be displayed. This overwrites
# y_test; the split matches the one used for training only if both
# train_test_split calls use the same random_state.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)

# Extract VGG19 features for the test images
x_test_vgg19 = []
for i in range(len(x_test)):
    img = x_test[i]
    img = preprocess_input(img.reshape((1, 256, 256, 3)))
    img_feature = model.predict(img)
    x_test_vgg19.append(img_feature)
x_test_vgg19 = np.array(x_test_vgg19)
x_test_vgg19 = x_test_vgg19.reshape(x_test_vgg19.shape[0], x_test_vgg19.shape[2], x_test_vgg19.shape[3], x_test_vgg19.shape[4])

y_pred = model_fine_tuning.predict(x_test_vgg19)
# Sort test indices by absolute error; the last entries are the worst predictions
wrong = np.argsort(np.abs(y_pred.flatten() - y_test))
print(wrong)
y_test_char = np.where(y_test == 0, 'M', 'F')
y_pred_char = np.where(y_pred > 0.5, 'F', 'M')

# Show the four test images with the largest prediction error
for k in range(1, 5):
    idx = wrong[-k]
    plt.subplot(2, 2, k)
    plt.imshow(x_test[idx])
    plt.title('Actual: ' + str(y_test_char[idx]) + ', ' + 'Predicted: ' + str(y_pred_char[idx][0]))
plt.show()
From the figure, we can see that VGG19 behaves like VGG16: apart from misclassifications caused by the person occupying only a small part of the image, both models tend to judge whether a person is male or female by their hair.
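The examples in this section are selected by sorting the absolute difference between the predicted probability and the true label with np.argsort. The sketch below shows that ranking on a handful of made-up values (for illustration only, not results from the model):

```python
import numpy as np

# Made-up labels and sigmoid outputs, shaped like model.predict output (N, 1).
y_test = np.array([0, 1, 1, 0, 1])
y_pred = np.array([[0.10], [0.95], [0.20], [0.60], [0.55]])

errors = np.abs(y_pred.flatten() - y_test)
wrong = np.argsort(errors)  # ascending order: the largest error comes last

# Indices whose thresholded prediction disagrees with the label.
is_wrong = (y_pred.flatten() > 0.5).astype(int) != y_test

print(wrong)                     # [1 0 4 3 2]
print(np.flatnonzero(is_wrong))  # [2 3]
```

Note that the ranking orders every test image by error and does not filter out correct predictions. Here only indices 2 and 3 are misclassified, so wrong[-3] and wrong[-4] would point at images that were classified correctly, just with lower confidence.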
Related links
Keras deep learning practice (7): convolutional neural networks explained and implemented
Keras deep learning practice (9): the limitations of convolutional neural networks
Keras deep learning practice (10): transfer learning
Keras deep learning practice: gender classification with a convolutional neural network