Easy PyTorch: Facial Expression Recognition with a Fully Convolutional Neural Network
2022-06-10 18:42:00 【小白学视觉】
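Reposted from: OpenCV学堂
It has been yet another week since I last updated this series, but I said I would keep it going. Today's article shows how to build a fully convolutional neural network from residual blocks for image classification. Yes, you read that right: facial expression images are recognized with a fully convolutional network. Part of the dataset comes from CK+, and most of it was labeled automatically by me using OpenVINO's expression recognition model, for a total of roughly 5,000 expression images.
Model structure
The model is built from residual blocks. Unlike the block in the previous article, this one supports downsampling. Its implementation is as follows: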
import torch
import torch.nn.functional as F


class ResidualBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int): Stride of the first 3x3 convolution and of the skip path.
        """
        super(ResidualBlock, self).__init__()

        # Identity shortcut by default
        self.skip = torch.nn.Sequential()

        # Use a 1x1 convolution on the shortcut when the shape changes
        if stride != 1 or in_channels != out_channels:
            self.skip = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, bias=False),
                torch.nn.BatchNorm2d(out_channels))

        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=stride, bias=False),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            torch.nn.BatchNorm2d(out_channels))

    def forward(self, x):
        out = self.block(x)
        identity = self.skip(x)
        out += identity
        out = F.relu(out)
        return out

When the stride parameter is 2, the block downsamples automatically, halving the height and width. When it is 1, the output keeps the same spatial size as the input.
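To make the downsampling concrete, here is a small sketch in plain Python (no PyTorch needed) that applies the standard convolution output-size formula to both paths of the block. The helper names `conv_out_size` and `residual_block_out_size` are my own, made up for illustration. The stride list is the one used by the model that follows.

```python
def conv_out_size(size, kernel, stride, padding):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def residual_block_out_size(size, stride):
    """Spatial size after one residual block, checked on both paths."""
    # Main path: 3x3 conv (padding 1) with the given stride, then 3x3 conv with stride 1.
    main = conv_out_size(size, kernel=3, stride=stride, padding=1)
    main = conv_out_size(main, kernel=3, stride=1, padding=1)
    # Skip path: a 1x1 conv with the same stride (an identity gives the same size when stride is 1).
    skip = conv_out_size(size, kernel=1, stride=stride, padding=0)
    assert main == skip, "both paths must agree before they are added"
    return main


# Stride sequence of the eight blocks in the model that follows.
strides = [1, 2, 2, 2, 2, 2, 2, 1]
size = 64
trace = [size]
for s in strides:
    size = residual_block_out_size(size, s)
    trace.append(size)

print(trace)  # [64, 64, 32, 16, 8, 4, 2, 1, 1]
```

Six stride-2 blocks are exactly what it takes to shrink a 64x64 input to 1x1, which is why no fully connected layer is needed at the end.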
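The model stacks several residual blocks. Its final output is Nx8x1x1, one score for each of 8 expressions, and classification is then completed with softmax. The model input is NCHW = Nx3x64x64. The structure is modeled on the Caffe version of the expression recognition model in the OpenVINO framework. The final model code is as follows: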
class EmotionsResNet(torch.nn.Module):
    def __init__(self):
        super(EmotionsResNet, self).__init__()
        self.cnn_layers = torch.nn.Sequential(
            # convolution layers (input is a 64x64x3 image)
            ResidualBlock(3, 32, 1),
            ResidualBlock(32, 64, 2),
            ResidualBlock(64, 64, 2),
            ResidualBlock(64, 128, 2),
            ResidualBlock(128, 128, 2),
            ResidualBlock(128, 256, 2),
            ResidualBlock(256, 256, 2),
            ResidualBlock(256, 8, 1),
        )

    def forward(self, x):
        # stack convolution layers
        x = self.cnn_layers(x)

        # Nx8x1x1 -> Nx8 (one score per expression class)
        B = x.size(0)
        out = x.view(B, -1)
        return out

Model training
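Note that the model's forward pass returns 8 raw scores per image and never calls softmax itself. That fits a cross-entropy loss, since PyTorch's `torch.nn.CrossEntropyLoss` expects raw scores and applies log-softmax internally. The sketch below works through the same arithmetic for a single sample in plain Python. The score values are hypothetical, chosen only for illustration.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities (shifted by the max for stability)."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Negative log-probability of the target class for one sample."""
    return -math.log(softmax(logits)[target])


# Hypothetical raw scores for one image over 8 expression classes.
logits = [0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0]
probs = softmax(logits)

loss_correct = cross_entropy(logits, target=2)  # label matches the highest score
loss_wrong = cross_entropy(logits, target=0)    # label disagrees with the scores
uniform = cross_entropy([0.0] * 8, target=5)    # an untrained, uniform guess

print(loss_correct < loss_wrong)  # True
print(uniform, math.log(8))
```

A model that guesses uniformly over 8 classes has a loss of log(8), roughly 2.08, which is a handy reference point when reading the training log.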
The model is trained with cross-entropy loss and saved after 15 epochs. The training code is as follows:
import torch
from torch.utils.data import DataLoader

# EmotionDataset is the author's dataset class; it is not shown in this article.
# Assumed definition: the original listing uses train_on_gpu without showing it.
train_on_gpu = torch.cuda.is_available()

if __name__ == "__main__":
    # create a complete CNN
    model = EmotionsResNet()
    print(model)

    # use the GPU
    if train_on_gpu:
        model.cuda()

    ds = EmotionDataset("D:/facedb/emotion_dataset")
    num_train_samples = ds.num_of_samples()
    bs = 16
    dataloader = DataLoader(ds, batch_size=bs, shuffle=True)

    # number of training epochs
    num_epochs = 15
    # optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    model.train()

    # loss function
    cross_loss = torch.nn.CrossEntropyLoss()
    index = 0
    for epoch in range(num_epochs):
        train_loss = 0.0
        for i_batch, sample_batched in enumerate(dataloader):
            images_batch, emotion_batch = \
                sample_batched['image'], sample_batched['emotion']
            if train_on_gpu:
                images_batch, emotion_batch = images_batch.cuda(), emotion_batch.cuda()
            optimizer.zero_grad()

            # forward pass: compute predicted outputs by passing inputs to the model
            m_emotion_out_ = model(images_batch)
            emotion_batch = emotion_batch.long()

            # calculate the batch loss
            loss = cross_loss(m_emotion_out_, emotion_batch)

            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()

            # perform a single optimization step (parameter update)
            optimizer.step()

            # update training loss; the batch loss is a mean, so weight it by
            # the batch size before dividing by the number of samples below
            train_loss += loss.item() * images_batch.size(0)
            if index % 100 == 0:
                print('step: {} \tTraining Loss: {:.6f} '.format(index, loss.item()))
            index += 1

        # compute the average loss per sample
        train_loss = train_loss / num_train_samples

        # show the training loss for this epoch
        print('Epoch: {} \tTraining Loss: {:.6f} '.format(epoch, train_loss))

    # save model
    model.eval()
    torch.save(model, 'face_emotions_model.pt')

Testing and demo
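The face ROI obtained from OpenCV face detection is fed into the trained expression recognition model to predict the expression, which gives real-time facial expression recognition. The demo code is as follows: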
import cv2 as cv
import numpy as np
import torch

# model_bin, config_text (face detector files) and emotion_labels (class names)
# are not defined in the original listing and must be supplied by the reader.
# On recent PyTorch versions, loading a whole pickled model may require weights_only=False.
cnn_model = torch.load("./face_emotions_model.pt")
print(cnn_model)
# capture = cv.VideoCapture(0)
capture = cv.VideoCapture("D:/images/video/example_dsh.mp4")

# load the TensorFlow face detection model
net = cv.dnn.readNetFromTensorflow(model_bin, config=config_text)
while True:
    ret, frame = capture.read()
    if ret is not True:
        break
    frame = cv.flip(frame, 1)
    h, w, c = frame.shape
    blobImage = cv.dnn.blobFromImage(frame, 1.0, (300, 300), (104.0, 177.0, 123.0), False, False)
    net.setInput(blobImage)
    cvOut = net.forward()
    # draw the detection rectangles
    for detection in cvOut[0, 0, :, :]:
        score = float(detection[2])
        if score > 0.5:
            # detections are normalized to [0, 1]; scale them to pixels
            left = detection[3] * w
            top = detection[4] * h
            right = detection[5] * w
            bottom = detection[6] * h

            # crop the face ROI and classify its expression
            roi = frame[np.int32(top):np.int32(bottom), np.int32(left):np.int32(right), :]
            img = cv.resize(roi, (64, 64))
            # scale pixel values from [0, 255] to [-1, 1]
            img = (np.float32(img) / 255.0 - 0.5) / 0.5
            # HWC -> CHW
            img = img.transpose((2, 0, 1))
            x_input = torch.from_numpy(img).view(1, 3, 64, 64)
            emotion_ = cnn_model(x_input.cuda())
            predict_ = torch.max(emotion_, 1)[1].cpu().detach().numpy()[0]
            emotion_txt = emotion_labels[predict_]
            # draw the label and the box
            cv.putText(frame, ("%s" % (emotion_txt)), (int(left), int(top) - 15), cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            cv.rectangle(frame, (int(left), int(top)), (int(right), int(bottom)), (255, 0, 0), thickness=2)
    cv.imshow("face detection + emotion", frame)
    c = cv.waitKey(10)
    if c == 27:
        break

cv.waitKey(0)
cv.destroyAllWindows()

The running result is as follows:
[Figure: screenshot of the demo output]
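One thing the demo does not guard against: the detector returns box coordinates as fractions of the frame size, and they can fall slightly outside [0, 1]. A negative index or an empty crop would then break the resize step. The sketch below, in plain Python, shows one possible way to clamp a box and to check the pixel normalization. The helper names `to_pixel_box`, `is_valid_box` and `normalize_pixel` are hypothetical and not part of the original code.

```python
def to_pixel_box(detection, frame_w, frame_h):
    """Scale a normalized (left, top, right, bottom) box to pixels, clamped to the frame."""
    left, top, right, bottom = detection
    x0 = max(0, min(frame_w, int(left * frame_w)))
    y0 = max(0, min(frame_h, int(top * frame_h)))
    x1 = max(0, min(frame_w, int(right * frame_w)))
    y1 = max(0, min(frame_h, int(bottom * frame_h)))
    return x0, y0, x1, y1


def is_valid_box(box):
    """A box is usable only if it has positive width and height."""
    x0, y0, x1, y1 = box
    return x1 > x0 and y1 > y0


def normalize_pixel(value):
    """Map an 8-bit value in [0, 255] to [-1, 1], as the demo does before inference."""
    return (value / 255.0 - 0.5) / 0.5


# A box that sticks out of a 640x480 frame on the left and at the bottom.
box = to_pixel_box((-0.25, 0.25, 0.5, 1.5), 640, 480)
print(box)  # (0, 120, 320, 480)

# A box that lies entirely outside the frame collapses to zero width.
outside = to_pixel_box((1.25, 0.25, 1.5, 0.5), 640, 480)
print(is_valid_box(box), is_valid_box(outside))  # True False

print(normalize_pixel(0), normalize_pixel(255))  # -1.0 1.0
```

In the demo loop, a check like `is_valid_box` before the crop would let invalid detections be skipped instead of crashing the video loop.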

Enough said. I hope for your continued support, and I will keep writing!