Easy PyTorch: Facial Expression Recognition with a Fully Convolutional Neural Network
2022-06-10 19:45:00 【Xiaobai learns vision】
Reprinted from: OpenCV School
It has been another week since I last updated this series, but I promised to keep it going. Today's article shows how to build a fully convolutional neural network from residual blocks for image classification. Specifically, it recognizes facial expressions in images. Part of the dataset comes from CK+; the larger part I labeled automatically using an OpenVINO expression recognition model. In total there are roughly 5,000 expression images.
Model structure
The model is built from residual blocks. Unlike the block in the previous article, this one supports downsampling. Its implementation is as follows:
import torch
import torch.nn.functional as F


class ResidualBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int): Stride of the first convolution; 2 halves the spatial size.
        """
        super(ResidualBlock, self).__init__()

        # Identity shortcut by default (an empty Sequential returns its input).
        self.skip = torch.nn.Sequential()

        # Use a 1x1 projection when the output shape differs from the input shape.
        if stride != 1 or in_channels != out_channels:
            self.skip = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, bias=False),
                torch.nn.BatchNorm2d(out_channels))

        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=stride, bias=False),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            torch.nn.BatchNorm2d(out_channels))

    def forward(self, x):
        out = self.block(x)
        identity = self.skip(x)
        out += identity
        out = F.relu(out)
        return out

When the stride parameter is 2, the block downsamples automatically; when it is 1, the output keeps the same spatial size as the input.
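To see why a stride of 2 halves the feature map while keeping the two branches compatible, it may help to work through the standard convolution output-size formula, floor((n + 2p - k) / s) + 1. The sketch below is plain Python rather than PyTorch, and the helper names conv_out_size and block_out_sizes are hypothetical ones introduced here for illustration. It checks that the 3x3 main path (padding 1) and the 1x1 skip path (padding 0) give the same spatial size, which is what allows the two branches to be added.

```python
def conv_out_size(n, kernel_size, stride, padding):
    """Spatial output size of a convolution (dilation 1), per the standard formula."""
    return (n + 2 * padding - kernel_size) // stride + 1


def block_out_sizes(n, stride):
    """Return (main_path_size, skip_path_size) for one residual block."""
    # Main path: 3x3 conv with the given stride, then a 3x3 conv with stride 1.
    main = conv_out_size(n, kernel_size=3, stride=stride, padding=1)
    main = conv_out_size(main, kernel_size=3, stride=1, padding=1)
    # Skip path: 1x1 conv with the same stride and no padding.
    skip = conv_out_size(n, kernel_size=1, stride=stride, padding=0)
    return main, skip


if __name__ == "__main__":
    for stride in (1, 2):
        print(stride, block_out_sizes(64, stride))

    # Six stride-2 blocks applied in sequence to a 64x64 input.
    size = 64
    for _ in range(6):
        size, _skip = block_out_sizes(size, stride=2)
    print(size)
```

For example, a 64x64 input passed through six stride-2 blocks shrinks as 64, 32, 16, 8, 4, 2, 1, so the feature map ends up 1x1.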
The model stacks several residual blocks and produces a final output of shape Nx8x1x1, representing 8 expression classes; classification is then completed with softmax. The model input is NCHW = Nx3x64x64. The structure follows the Caffe version of the expression recognition model in the OpenVINO framework. The final model implementation is as follows:
class EmotionsResNet(torch.nn.Module):
    def __init__(self):
        super(EmotionsResNet, self).__init__()
        self.cnn_layers = torch.nn.Sequential(
            # Input: 3x64x64 image; each stride-2 block halves the spatial size
            ResidualBlock(3, 32, 1),
            ResidualBlock(32, 64, 2),
            ResidualBlock(64, 64, 2),
            ResidualBlock(64, 128, 2),
            ResidualBlock(128, 128, 2),
            ResidualBlock(128, 256, 2),
            ResidualBlock(256, 256, 2),
            ResidualBlock(256, 8, 1),
        )

    def forward(self, x):
        # stack convolution layers
        x = self.cnn_layers(x)

        # Nx8x1x1 -> Nx8
        B, C, H, W = x.size()
        out = x.view(B, -1)
        return out

Model training
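As background for the training in this section, the cross-entropy loss can be sketched in plain Python. The example assumes the loss is computed the way PyTorch's `CrossEntropyLoss` does by default: softmax over the raw class scores, then the negative log-probability of the true class, averaged over the batch. The function names and the toy scores are made up for illustration and are not taken from the article's dataset.

```python
import math


def softmax(scores):
    """Convert raw class scores into probabilities that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(batch_scores, labels):
    """Mean negative log-probability of the true class over a batch."""
    losses = []
    for scores, label in zip(batch_scores, labels):
        probs = softmax(scores)
        losses.append(-math.log(probs[label]))
    return sum(losses) / len(losses)


if __name__ == "__main__":
    # Toy batch: 2 samples, 8 expression classes (values are made up).
    batch_scores = [
        [0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    ]
    labels = [2, 5]
    print(cross_entropy(batch_scores, labels))
    # Uniform scores over 8 classes give a loss of log(8) for that sample.
    print(cross_entropy([batch_scores[1]], [5]), math.log(8))
```

A model that has learned nothing and scores all 8 classes equally has a loss of log(8), which is a handy baseline when reading the training log: the printed loss should fall below that value as training progresses.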
The model is trained with cross-entropy loss. After 15 epochs of training, the model is saved. The training code is as follows:
import torch
from torch.utils.data import DataLoader

# EmotionDataset is the author's dataset class; its definition is not shown in this article.
# train_on_gpu is not defined in the original snippet; it is assumed here to be a CUDA availability check.
train_on_gpu = torch.cuda.is_available()

if __name__ == "__main__":
    # create a complete CNN
    model = EmotionsResNet()
    print(model)

    # use the GPU if available
    if train_on_gpu:
        model.cuda()

    ds = EmotionDataset("D:/facedb/emotion_dataset")
    num_train_samples = ds.num_of_samples()
    bs = 16
    dataloader = DataLoader(ds, batch_size=bs, shuffle=True)

    # number of training epochs
    num_epochs = 15
    # optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    model.train()

    # loss function
    cross_loss = torch.nn.CrossEntropyLoss()
    index = 0
    for epoch in range(num_epochs):
        train_loss = 0.0
        for i_batch, sample_batched in enumerate(dataloader):
            images_batch, emotion_batch = \
                sample_batched['image'], sample_batched['emotion']
            if train_on_gpu:
                images_batch, emotion_batch = images_batch.cuda(), emotion_batch.cuda()
            optimizer.zero_grad()

            # forward pass: compute predicted outputs by passing inputs to the model
            m_emotion_out_ = model(images_batch)
            emotion_batch = emotion_batch.long()

            # calculate the batch loss
            loss = cross_loss(m_emotion_out_, emotion_batch)

            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()

            # perform a single optimization step (parameter update)
            optimizer.step()

            # update training loss; loss.item() is the batch mean, so weight it by the batch size
            train_loss += loss.item() * images_batch.size(0)
            if index % 100 == 0:
                print('step: {} \tTraining Loss: {:.6f} '.format(index, loss.item()))
            index += 1

        # average loss per sample over the epoch
        train_loss = train_loss / num_train_samples

        # show the training loss for this epoch
        print('Epoch: {} \tTraining Loss: {:.6f} '.format(epoch, train_loss))

    # save model
    model.eval()
    torch.save(model, 'face_emotions_model.pt')

Testing and demonstration
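The demo in this section depends on two small pieces of arithmetic: scaling the detector's normalized box coordinates to pixel coordinates, and mapping 8-bit pixel values into [-1, 1] with (v / 255 - 0.5) / 0.5. The sketch below shows both in plain Python. The helper names are hypothetical, and the clamping of the box to the frame bounds is an added assumption that the demo code itself does not perform.

```python
def to_pixel_box(norm_box, frame_w, frame_h):
    """Scale a normalized (left, top, right, bottom) box to integer pixel coordinates.

    Clamping to the frame is an added safeguard (an assumption, not part of
    the demo code) so that the cropped ROI has valid bounds.
    """
    left, top, right, bottom = norm_box
    x1 = min(max(int(left * frame_w), 0), frame_w)
    y1 = min(max(int(top * frame_h), 0), frame_h)
    x2 = min(max(int(right * frame_w), 0), frame_w)
    y2 = min(max(int(bottom * frame_h), 0), frame_h)
    return x1, y1, x2, y2


def normalize_pixel(value):
    """Map an 8-bit pixel value in [0, 255] to [-1.0, 1.0]."""
    return (value / 255.0 - 0.5) / 0.5


if __name__ == "__main__":
    # A box that extends past the bottom edge of a 640x480 frame.
    print(to_pixel_box((0.25, 0.5, 0.75, 1.2), 640, 480))
    print(normalize_pixel(0), normalize_pixel(255))
```

A box whose clamped width or height comes out as zero would typically be skipped before resizing, since an empty crop cannot be resized to 64x64.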
Faces are detected with OpenCV, and each face ROI is fed into the trained expression recognition model to predict the expression, giving real-time facial expression recognition. The demo code is as follows:
import cv2 as cv
import numpy as np
import torch

# model_bin and config_text (face detector files) and emotion_labels (class names)
# are defined elsewhere in the author's code and are not shown in this article.

# Load the fully pickled model saved by the training script.
# On recent PyTorch versions this may require passing weights_only=False.
cnn_model = torch.load("./face_emotions_model.pt")
print(cnn_model)
# capture = cv.VideoCapture(0)
capture = cv.VideoCapture("D:/images/video/example_dsh.mp4")

# load the TensorFlow face detection model
net = cv.dnn.readNetFromTensorflow(model_bin, config=config_text)
while True:
    ret, frame = capture.read()
    if ret is not True:
        break
    frame = cv.flip(frame, 1)
    h, w, c = frame.shape
    blobImage = cv.dnn.blobFromImage(frame, 1.0, (300, 300), (104.0, 177.0, 123.0), False, False)
    net.setInput(blobImage)
    cvOut = net.forward()
    # draw detection rectangles
    for detection in cvOut[0,0,:,:]:
        score = float(detection[2])
        if score > 0.5:
            left = detection[3]*w
            top = detection[4]*h
            right = detection[5]*w
            bottom = detection[6]*h

            # crop the face ROI and run expression recognition on it
            roi = frame[np.int32(top):np.int32(bottom),np.int32(left):np.int32(right),:]
            img = cv.resize(roi, (64, 64))
            # scale pixel values to [-1, 1] and convert HWC to CHW
            img = (np.float32(img) / 255.0 - 0.5) / 0.5
            img = img.transpose((2, 0, 1))
            x_input = torch.from_numpy(img).view(1, 3, 64, 64)
            emotion_ = cnn_model(x_input.cuda())
            # index of the highest-scoring class
            predict_ = torch.max(emotion_, 1)[1].cpu().detach().numpy()[0]
            emotion_txt = emotion_labels[predict_]
            # draw
            cv.putText(frame, ("%s"%(emotion_txt)), (int(left), int(top)-15), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv.rectangle(frame, (int(left), int(top)), (int(right), int(bottom)), (255, 0, 0), thickness=2)
    cv.imshow("face detection + emotion", frame)
    # press ESC to quit
    key = cv.waitKey(10)
    if key == 27:
        break

cv.waitKey(0)
cv.destroyAllWindows()

The results of running the demo are as follows:

That's all for now. I hope you'll keep supporting this series, and I'll keep writing!