当前位置:网站首页>基于人脸的常见表情识别——模型搭建、训练与测试
基于人脸的常见表情识别——模型搭建、训练与测试
2022-07-30 19:50:00 【GodGump】
模型搭建与训练
数据接口准备
# Per-split preprocessing pipelines.
# `Resize` / `RandomResizedCrop` are the current names of the deprecated
# `Scale` / `RandomSizedCrop` transforms, which recent torchvision releases
# no longer provide.
data_transforms = {
    # Training: resize, random 48x48 crop and random horizontal flip for
    # augmentation, then convert to a tensor and normalise each channel
    # with mean 0.5 and std 0.5.
    'train': transforms.Compose([
        transforms.Resize(64),
        transforms.RandomResizedCrop(48),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]),
    # Validation: deterministic resize + centre crop, same normalisation.
    'val': transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(48),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]),
}

# Root directory that contains one sub-directory per split ('train', 'val').
data_dir = './train_val_data/'

# One ImageFolder dataset per split, each using the matching pipeline.
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'val']}

# One DataLoader per split. Only the training split is shuffled; shuffling
# the validation split has no effect on the aggregated metrics.
dataloders = {
    x: torch.utils.data.DataLoader(image_datasets[x],
                                   batch_size=16,
                                   shuffle=(x == 'train'),
                                   num_workers=4)
    for x in ['train', 'val']}
模型定义
import torch.nn as nn
import torch.nn.functional as F
class simpleconv3(nn.Module):
    """Small CNN: three conv/batch-norm/ReLU stages and a three-layer dense head.

    Every convolution uses a 3x3 kernel with stride 2 and no padding. The
    dense head expects 48 * 5 * 5 flattened features, which is what a
    3-channel 48x48 input is reduced to by the three stages, and it returns
    4 raw scores per sample (no softmax is applied here).
    """

    def __init__(self):
        super(simpleconv3, self).__init__()
        # Stage 1: 3 -> 12 channels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12,
                               kernel_size=3, stride=2)
        self.bn1 = nn.BatchNorm2d(num_features=12)
        # Stage 2: 12 -> 24 channels.
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=24,
                               kernel_size=3, stride=2)
        self.bn2 = nn.BatchNorm2d(num_features=24)
        # Stage 3: 24 -> 48 channels.
        self.conv3 = nn.Conv2d(in_channels=24, out_channels=48,
                               kernel_size=3, stride=2)
        self.bn3 = nn.BatchNorm2d(num_features=48)
        # Dense head: 1200 -> 1200 -> 128 -> 4.
        self.fc1 = nn.Linear(in_features=48 * 5 * 5, out_features=1200)
        self.fc2 = nn.Linear(in_features=1200, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=4)

    def forward(self, x):
        """Return the 4 raw class scores for each sample in batch ``x``."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, bn in stages:
            x = F.relu(bn(conv(x)))
        # Flatten the feature maps into rows of 48 * 5 * 5 values.
        flat = x.view(-1, 48 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
模型训练
#coding:utf8
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import torchvision
from torchvision import datasets, models, transforms
import time
import os
from tensorboardX import SummaryWriter
import torch.nn.functional as F
import numpy as np
import warnings
# Silence every Python warning for the rest of the run. Note that this also
# hides deprecation warnings emitted by the libraries used below.
warnings.filterwarnings('ignore')
# Summary writer created with default arguments; train_model logs the
# per-epoch loss and accuracy through it and closes it when training ends.
writer = SummaryWriter()
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and evaluate it on the validation split every epoch.

    The function reads these module-level names: ``dataloders`` and
    ``dataset_sizes`` (dicts keyed by ``'train'`` / ``'val'``), ``use_gpu``
    (bool) and ``writer`` (summary writer used for scalar logging).

    Args:
        model: network to optimise; it is updated in place.
        criterion: loss callable taking ``(outputs, labels)``. It is assumed
            to return the mean loss over the batch (the default reduction of
            the ``torch.nn`` losses); the epoch loss is weighted accordingly.
        optimizer: optimiser that owns ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of passes over the training split.

    Returns:
        The same ``model`` object after training.
    """
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            # Switch batch-norm (and similar layers) between train/eval mode.
            model.train(is_train)

            running_loss = 0.0
            running_corrects = 0.0
            for inputs, labels in dataloders[phase]:
                if use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                optimizer.zero_grad()
                # Only build the autograd graph while training; validation
                # needs no gradients.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()
                _, preds = torch.max(outputs.detach(), 1)

                # `loss` is a per-batch mean, so weight it by the batch size
                # before dividing by the number of samples below.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            if is_train:
                # Step the schedule after the epoch's optimizer updates;
                # stepping it first skips the initial learning-rate value.
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            if is_train:
                writer.add_scalar('data/trainloss', epoch_loss, epoch)
                writer.add_scalar('data/trainacc', epoch_acc, epoch)
            else:
                writer.add_scalar('data/valloss', epoch_loss, epoch)
                writer.add_scalar('data/valacc', epoch_acc, epoch)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

    # Dump every logged scalar to JSON and release the writer.
    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
    return model
if __name__ == '__main__':
    # Per-split preprocessing. `Resize` / `RandomResizedCrop` are the current
    # names of the deprecated `Scale` / `RandomSizedCrop` transforms, which
    # recent torchvision releases no longer provide.
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(64),
            transforms.RandomResizedCrop(48),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]),
        'val': transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(48),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]),
    }

    # Location of the dataset; it holds one sub-directory per split.
    data_dir = './Emotion_Recognition_File/train_val_data/'
    image_datasets = {
        x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
        for x in ['train', 'val']}

    # `dataloders`, `dataset_sizes` and `use_gpu` are module-level names
    # that train_model reads directly, so they must keep these exact names.
    dataloders = {
        x: torch.utils.data.DataLoader(image_datasets[x],
                                       batch_size=64,
                                       shuffle=(x == "train"),
                                       num_workers=8)
        for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

    use_gpu = torch.cuda.is_available()
    print("是否使用 GPU", use_gpu)

    modelclc = simpleconv3()
    print(modelclc)
    if use_gpu:
        modelclc = modelclc.cuda()

    criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(modelclc.parameters(), lr=0.1, momentum=0.9)
    # Multiply the learning rate by 0.1 every 100 epochs.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft,
                                           step_size=100,
                                           gamma=0.1)

    modelclc = train_model(model=modelclc,
                           criterion=criterion,
                           optimizer=optimizer_ft,
                           scheduler=exp_lr_scheduler,
                           num_epochs=10)  # number of training epochs

    # Persist only the learned weights (state dict), not the whole module.
    if not os.path.exists("models"):
        os.mkdir('models')
    torch.save(modelclc.state_dict(), 'models/model.ckpt')
模型测试
# coding:utf8
import sys
import numpy as np
import cv2
import os
import dlib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import torchvision
from torchvision import datasets, models, transforms
import time
from PIL import Image
import torch.nn.functional as F
import matplotlib.pyplot as plt
import warnings
# Silence every Python warning for the rest of the run.
warnings.filterwarnings('ignore')

# dlib landmark predictor; the mouth region is later taken from its output.
PREDICTOR_PATH = "./Emotion_Recognition_File/face_detect_model/shape_predictor_68_face_landmarks.dat"
predictor = dlib.shape_predictor(PREDICTOR_PATH)

# OpenCV Haar cascade used to find the face rectangle.
cascade_path = './Emotion_Recognition_File/face_detect_model/haarcascade_frontalface_default.xml'
cascade = cv2.CascadeClassifier(cascade_path)
# CascadeClassifier does not raise when the file is missing or unreadable;
# it just stays empty and every later detection call fails. Fail fast here
# instead of silently producing no results.
if cascade.empty():
    raise IOError("failed to load Haar cascade: " + cascade_path)

# Annotated output images are written into this directory.
if not os.path.exists("results"):
    os.mkdir("results")
def standardization(data):
    """Standardise ``data`` to zero mean and unit variance along axis 0.

    Args:
        data: array-like; statistics are computed along the first axis.

    Returns:
        ``(data - mean) / std`` computed along axis 0. Where the standard
        deviation is zero (a constant column) the divisor is replaced by 1,
        so those entries come out as 0 instead of NaN/inf.
    """
    mu = np.mean(data, axis=0)
    sigma = np.std(data, axis=0)
    # Guard against division by zero for constant columns.
    safe_sigma = np.where(sigma == 0, np.ones_like(sigma), sigma)
    return (data - mu) / safe_sigma
def get_landmarks(im):
    """Return the facial landmarks of the first face detected in ``im``.

    Uses the module-level Haar ``cascade`` to find face rectangles and the
    module-level dlib ``predictor`` to locate the landmarks inside the first
    one.

    Args:
        im: image array as accepted by ``cascade.detectMultiScale``.

    Returns:
        ``np.matrix`` with one ``[x, y]`` row per landmark.

    Raises:
        IndexError: if no face is detected (same exception type the bare
            ``rects[0]`` lookup raised before, now with a clear message).
    """
    # Positional arguments are scaleFactor=1.3 and minNeighbors=5.
    rects = cascade.detectMultiScale(im, 1.3, 5)
    if len(rects) == 0:
        raise IndexError("no face detected in image")
    x, y, w, h = rects[0]
    # dlib wants (left, top, right, bottom); OpenCV returns (x, y, w, h).
    rect = dlib.rectangle(int(x), int(y), int(x + w), int(y + h))
    return np.matrix([[p.x, p.y] for p in predictor(im, rect).parts()])
def annotate_landmarks(im, landmarks):
    """Return a copy of ``im`` with every landmark drawn and numbered.

    Args:
        im: image array; it is copied, never modified.
        landmarks: (N, 2) collection of ``[x, y]`` points. Both ``np.matrix``
            (as returned by ``get_landmarks``) and plain arrays or nested
            lists are accepted.

    Returns:
        The annotated copy of ``im``.
    """
    annotated = im.copy()
    # np.asarray turns an np.matrix into a regular 2-D array, so each row
    # unpacks into two scalars regardless of the input container.
    points = np.asarray(landmarks).reshape(-1, 2)
    for idx, (px, py) in enumerate(points):
        # Pass plain Python ints to OpenCV rather than NumPy scalars.
        pos = (int(px), int(py))
        cv2.putText(annotated,
                    str(idx),
                    pos,
                    fontFace=cv2.FONT_HERSHEY_SCRIPT_SIMPLEX,
                    fontScale=0.4,
                    color=(0, 0, 255))
        cv2.circle(annotated, pos, 3, color=(0, 255, 255))
    return annotated
testsize = 48  # side length, in pixels, of the square crop fed to the network

# Text drawn for each predicted class index (index -> label).
EXPRESSION_LABELS = ['none', 'pout', 'smile', 'open']

# The crop is scaled to [0, 1] by hand below, so only tensor conversion and
# the mean 0.5 / std 0.5 normalisation are needed here.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

net = simpleconv3()
modelpath = "./models/model.ckpt"  # path of the trained weights
# Map every stored tensor to the CPU so a GPU-trained checkpoint also loads.
net.load_state_dict(
    torch.load(modelpath, map_location=lambda storage, loc: storage))
net.eval()


def _expand_axis(start, extent, dstlen, limit):
    """Return the origin of a ``dstlen``-long window around an interval.

    The window is centred on ``[start, start + extent)`` when it fits inside
    ``[0, limit)``; otherwise it is pushed against the nearer border. The
    result is never negative, even when ``dstlen`` exceeds ``limit``.
    """
    margin = (dstlen - extent) / 2
    if start >= margin and start + extent + margin < limit:
        return start - margin
    if start < margin:
        return 0
    return max(0, limit - dstlen)


# Process the image folder one file at a time.
img_path = "./Emotion_Recognition_File/find_face_img/"
imagepaths = os.listdir(img_path)
for imagepath in imagepaths:
    im = cv2.imread(os.path.join(img_path, imagepath), 1)
    if im is None:
        continue  # not a readable image file

    rects = cascade.detectMultiScale(im, 1.3, 5)
    if len(rects) == 0:
        continue  # no face detected
    x, y, w, h = rects[0]
    rect = dlib.rectangle(int(x), int(y), int(x + w), int(y + h))
    landmarks = np.array([[p.x, p.y] for p in predictor(im, rect).parts()])

    # Bounding box of the mouth landmarks, indices 48..67 inclusive in the
    # 68-point scheme.
    mouth = landmarks[48:68]
    xmin, ymin = int(mouth[:, 0].min()), int(mouth[:, 1].min())
    xmax, ymax = int(mouth[:, 0].max()), int(mouth[:, 1].max())
    roiwidth = xmax - xmin
    roiheight = ymax - ymin

    # Square window 1.5x the longer side of the mouth box, kept in the image.
    dstlen = 1.5 * max(roiwidth, roiheight)
    imagerows, imagecols, channel = im.shape
    newx = _expand_axis(xmin, roiwidth, dstlen, imagecols)
    newy = _expand_axis(ymin, roiheight, dstlen, imagerows)
    x0, y0 = int(newx), int(newy)
    x1, y1 = int(newx + dstlen), int(newy + dstlen)

    roi = im[y0:y1, x0:x1, 0:3]
    roi = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)
    roiresized = cv2.resize(roi,
                            (testsize, testsize)).astype(np.float32) / 255.0
    imgblob = data_transforms(roiresized).unsqueeze(0)

    # no_grad only has an effect as a context manager; softmax is taken
    # over the class dimension.
    with torch.no_grad():
        predict = F.softmax(net(imgblob), dim=1)
    print(predict)
    index = int(torch.argmax(predict, dim=1).item())

    # Draw the crop window and the predicted label on an untouched copy.
    im_show = im.copy()
    cv2.rectangle(im_show, (x0, y0), (x1, y1), (0, 255, 255), 2)
    cv2.putText(im_show, EXPRESSION_LABELS[index], (x1, y1),
                cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 0, 255), 2)
    cv2.imwrite(os.path.join('results', imagepath), im_show)

    # Reverse the channel axis: matplotlib expects RGB, OpenCV images are BGR.
    plt.imshow(im_show[:, :, ::-1])
    plt.show()
边栏推荐
- MySQL数据库————视图和索引
- Alibaba Cloud Martial Arts Headline Event Sharing
- 【私人系列】日常PHP遇到的各种稀奇古怪的问题
- 刷题记录----字符串
- Download and installation of the latest version of MySQL 8.0 under Linux (detailed steps)
- VBA batch import Excel data into Access database
- 【Node实现数据加密】
- jOOQ是如何设计事务API(详细指南)
- MindSpore:【resnet_thor模型】尝试运行resnet_thor时报Could not convert to
- Witness the magical awakening of the mini world in HUAWEI CLOUD
猜你喜欢
ELK日志分析系统
Install Mysql5.7 under Linux, super detailed and complete tutorial, and cloud mysql connection
“数字化重构系统,搞定 CEO 是第一步”
Database Tuning - Database Tuning
Perfectly Clear QuickDesk & QuickServer图像校正优化工具
Talking about Contrastive Learning (Contrastive Learning) the first bullet
【无标题】多集嵌套集合使不再有MultipleBagFetchException
MySQL数据库主从配置
Golang logging library zerolog use record
HCIP --- 企业网的三层架构
随机推荐
Linux下安装Mysql5.7,超详细完整教程,以及云mysql连接
Niuke.com - Huawei Question Bank (100~108)
Start foreground Activity
VS Code connects to SQL Server
Linux下安装MySQL教程
linux下mysql8安装
在jOOQ中获取数据的多种不同方式
How to copy table structure and table data in MySQL
055 c# print
MindSpore:【Resolve node failed】解析节点失败的问题
el-input 只能输入整数(包括正数、负数、0)或者只能输入整数(包括正数、负数、0)和小数
推荐系统:AB测试(AB Test)
MindSpore:对image作normalize的目的是什么?
【私人系列】日常PHP遇到的各种稀奇古怪的问题
Another company interview
MindSpore:【JupyterLab】按照新手教程训练时报错
MySQL数据库之JDBC编程
MySQL database - DQL data query language
PHP低代码开发平台 V5.0.7新版发布
【MindSpore】用coco2017训练Model_zoo上的 yolov4,迭代了两千多batch_size之后报错,大佬们帮忙看看。