当前位置:网站首页>基于人脸的常见表情识别——模型搭建、训练与测试
基于人脸的常见表情识别——模型搭建、训练与测试
2022-07-30 19:50:00 【GodGump】
模型搭建与训练
数据接口准备
# Preprocessing for the two dataset splits. Both resize the shorter image
# side to 64 px and produce 48x48 crops normalised from [0, 1] to [-1, 1];
# training adds random resized crops and horizontal flips for augmentation.
# (Resize / RandomResizedCrop replace the deprecated Scale / RandomSizedCrop
# aliases, which recent torchvision releases no longer provide.)
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(64),
        transforms.RandomResizedCrop(48),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]),
    'val': transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(48),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]),
}
# ImageFolder layout: <data_dir>/{train,val}/<class_name>/<image files>.
data_dir = './train_val_data/'
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'val']}
# Only the training split is shuffled; evaluation order does not matter.
dataloders = {
    x: torch.utils.data.DataLoader(image_datasets[x],
                                   batch_size=16,
                                   shuffle=(x == 'train'),
                                   num_workers=4)
    for x in ['train', 'val']}
模型定义
import torch.nn as nn
import torch.nn.functional as F
class simpleconv3(nn.Module):
    """Small three-stage CNN classifier for 48x48 RGB crops.

    Each stage is an unpadded 3x3, stride-2 convolution followed by
    BatchNorm and ReLU, shrinking the side length 48 -> 23 -> 11 -> 5.
    The resulting 48x5x5 feature map feeds three fully connected layers.

    Args:
        num_classes: number of output logits (default 4, the original
            fixed output size).
    """

    def __init__(self, num_classes=4):
        super(simpleconv3, self).__init__()
        self.conv1 = nn.Conv2d(3, 12, 3, 2)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(12, 24, 3, 2)
        self.bn2 = nn.BatchNorm2d(24)
        self.conv3 = nn.Conv2d(24, 48, 3, 2)
        self.bn3 = nn.BatchNorm2d(48)
        self.fc1 = nn.Linear(48 * 5 * 5, 1200)
        self.fc2 = nn.Linear(1200, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw logits (batch, num_classes) for x of shape (batch, 3, 48, 48)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Flatten per sample. view(-1, 48*5*5) silently re-batched inputs of
        # the wrong spatial size; keeping the batch dimension makes fc1
        # reject them with a shape error instead.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
模型训练
#coding:utf8
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import torchvision
from torchvision import datasets, models, transforms
import time
import os
from tensorboardX import SummaryWriter
import torch.nn.functional as F
import numpy as np
import warnings
# Suppress every warning for the rest of the run. NOTE(review): this also
# hides library deprecation warnings from torch/torchvision.
warnings.filterwarnings('ignore')
# TensorBoard writer shared with train_model(), which logs to it and closes it;
# created with no arguments, so it writes to tensorboardX's default log directory.
writer = SummaryWriter()
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and evaluate it on the validation split every epoch.

    Reads module-level globals created by the ``__main__`` block:
    ``dataloders`` (dict of 'train'/'val' DataLoaders), ``dataset_sizes``
    (samples per split), ``use_gpu`` (bool) and the TensorBoard ``writer``.
    Per-epoch loss and accuracy are printed and logged to ``writer``, which
    is exported to ``./all_scalars.json`` and closed when training ends.

    Args:
        model: network to train; updated in place.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of train/val epochs.

    Returns:
        The trained ``model``.
    """
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train mode vs. evaluate mode
            running_loss = 0.0
            running_corrects = 0.0
            for inputs, labels in dataloders[phase]:
                if use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()
                optimizer.zero_grad()
                # Only build the autograd graph when it will be backpropagated.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                preds = outputs.argmax(dim=1)
                if is_train:
                    loss.backward()
                    optimizer.step()
                # The loss is assumed to be a batch mean (nn.CrossEntropyLoss
                # default in __main__); weight it by batch size so dividing by
                # the dataset size below gives a per-sample average.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()
            if is_train:
                # Since PyTorch 1.1 the scheduler steps after optimizer.step().
                scheduler.step()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            # Tags are data/trainloss, data/trainacc, data/valloss, data/valacc.
            writer.add_scalar('data/{}loss'.format(phase), epoch_loss, epoch)
            writer.add_scalar('data/{}acc'.format(phase), epoch_acc, epoch)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
    return model
if __name__ == '__main__':
    # Per-split preprocessing: shorter side resized to 64, 48x48 crops,
    # normalised to [-1, 1]; training adds random resized crops and flips.
    # Resize / RandomResizedCrop replace the deprecated Scale / RandomSizedCrop.
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(64),
            transforms.RandomResizedCrop(48),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]),
        'val': transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(48),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]),
    }
    data_dir = './Emotion_Recognition_File/train_val_data/'  # dataset location
    image_datasets = {
        x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
        for x in ['train', 'val']}
    # dataloders, dataset_sizes and use_gpu are module-level globals read by
    # train_model(); keep these names.
    dataloders = {
        x: torch.utils.data.DataLoader(image_datasets[x],
                                       batch_size=64,
                                       shuffle=(x == 'train'),
                                       num_workers=8)
        for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    use_gpu = torch.cuda.is_available()
    print("是否使用 GPU", use_gpu)
    modelclc = simpleconv3()
    print(modelclc)
    if use_gpu:
        modelclc = modelclc.cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(modelclc.parameters(), lr=0.1, momentum=0.9)
    # StepLR multiplies the LR by gamma every step_size scheduler steps
    # (train_model steps it once per epoch). With step_size=100 and
    # num_epochs=10 below, the LR never actually decays.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=100, gamma=0.1)
    modelclc = train_model(model=modelclc,
                           criterion=criterion,
                           optimizer=optimizer_ft,
                           scheduler=exp_lr_scheduler,
                           num_epochs=10)  # number of training epochs; tune here
    os.makedirs('models', exist_ok=True)
    torch.save(modelclc.state_dict(), 'models/model.ckpt')
模型测试
# coding:utf8
import sys
import numpy as np
import cv2
import os
import dlib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import torchvision
from torchvision import datasets, models, transforms
import time
from PIL import Image
import torch.nn.functional as F
import matplotlib.pyplot as plt
import warnings
# Silence all warnings for the rest of the script.
warnings.filterwarnings('ignore')
# dlib 68-point facial landmark model, used to locate the mouth.
PREDICTOR_PATH = "./Emotion_Recognition_File/face_detect_model/shape_predictor_68_face_landmarks.dat"
predictor = dlib.shape_predictor(PREDICTOR_PATH)
# OpenCV Haar cascade, used to find face boxes.
cascade_path = './Emotion_Recognition_File/face_detect_model/haarcascade_frontalface_default.xml'
cascade = cv2.CascadeClassifier(cascade_path)
# CascadeClassifier does not raise on a missing or invalid file; it just comes
# back empty, and detectMultiScale then errors on every image. Fail early.
if cascade.empty():
    raise IOError('could not load Haar cascade from ' + cascade_path)
os.makedirs("results", exist_ok=True)  # annotated images are written here
def standardization(data):
    """Return ``data`` standardised to zero mean and unit variance along axis 0.

    Each column has its mean subtracted and is divided by its population
    standard deviation (``np.std`` default, ddof=0). Columns with zero
    spread are only centred, so they become all zeros instead of
    producing NaN/inf from a division by zero.
    """
    mu = np.mean(data, axis=0)
    sigma = np.std(data, axis=0)
    # A constant column has nothing to scale; dividing by 1 avoids 0/0.
    sigma = np.where(sigma == 0, 1, sigma)
    return (data - mu) / sigma
def get_landmarks(im):
    """Detect a face in ``im`` and return its dlib landmark points.

    Runs the module-level Haar ``cascade`` (scaleFactor 1.3, minNeighbors 5),
    takes the first detected box, and runs the module-level dlib
    ``predictor`` inside it.

    Returns:
        np.matrix with one ``[x, y]`` row per landmark point.

    Raises:
        IndexError: when no face is found, because the first detection is
            indexed unconditionally.
    """
    faces = cascade.detectMultiScale(im, 1.3, 5)
    left, top, width, height = faces[0]
    box = dlib.rectangle(int(left), int(top), int(left + width), int(top + height))
    shape = predictor(im, box)
    points = [[pt.x, pt.y] for pt in shape.parts()]
    return np.matrix(points)
def annotate_landmarks(im, landmarks):
    """Return a copy of ``im`` with every landmark circled and numbered.

    Args:
        im: image array to draw on; it is not modified (a copy is annotated).
        landmarks: landmark coordinates, e.g. the ``np.matrix`` of ``[x, y]``
            rows returned by ``get_landmarks``. Any array-like reshapeable to
            ``(N, 2)`` is accepted.

    Returns:
        The annotated copy of ``im``.
    """
    im = im.copy()
    points = np.asarray(landmarks).reshape(-1, 2)
    for idx, (px, py) in enumerate(points):
        # Plain Python ints so OpenCV's bindings accept the point regardless
        # of the numpy scalar type of the coordinates.
        pos = (int(px), int(py))
        cv2.putText(im,
                    str(idx),
                    pos,
                    fontFace=cv2.FONT_HERSHEY_SCRIPT_SIMPLEX,
                    fontScale=0.4,
                    color=(0, 0, 255))
        cv2.circle(im, pos, 3, color=(0, 255, 255))
    return im
# Side length (pixels) of the square crop fed to the network; the training
# transforms also produce 48x48 crops.
testsize = 48
# Training normalisation without resize/crop: the caller supplies an image
# that has already been resized to testsize x testsize.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
modelpath = "./models/model.ckpt"  # path to the trained weights
net = simpleconv3()
# Map every saved tensor onto CPU storage, whatever device it was saved from.
state = torch.load(modelpath, map_location=lambda storage, loc: storage)
net.load_state_dict(state)
net.eval()
# Classify the mouth expression in every image of the folder, one file at a time.
img_path = "./Emotion_Recognition_File/find_face_img/"
imagepaths = os.listdir(img_path)  # image folder
# Output index -> label drawn on the image (order 0..3 as in the original code).
EXPRESSION_LABELS = ['none', 'pout', 'smile', 'open']
font = cv2.FONT_HERSHEY_SIMPLEX


def _crop_origin(start, extent, pad, limit, size):
    """Return the low-edge coordinate of a square crop along one image axis.

    Centres the crop on ``[start, start + extent)`` by moving back ``pad / 2``.
    If that would cross 0, the crop starts at 0; if it would pass ``limit``,
    the crop is pushed back to end at ``limit``. The result is clamped at 0 so
    a crop larger than the image never gets a negative (wrap-around) index.
    """
    if start >= pad / 2 and start + extent + pad / 2 < limit:
        origin = start - pad / 2
    elif start < pad / 2:
        origin = 0
    else:
        origin = limit - size
    return max(origin, 0)


for imagepath in imagepaths:
    im = cv2.imread(os.path.join(img_path, imagepath), 1)
    if im is None:
        continue  # not a readable image file
    rects = cascade.detectMultiScale(im, 1.3, 5)
    if len(rects) == 0:
        continue  # no face detected
    x, y, w, h = rects[0]
    rect = dlib.rectangle(int(x), int(y), int(x + w), int(y + h))
    landmarks = np.matrix([[p.x, p.y]
                           for p in predictor(im, rect).parts()])

    # Bounding box of the mouth: points 48-67 of the 68-point model
    # (the original range(48, 67) left out point 67).
    mouth = np.asarray(landmarks[48:68])
    xmin, ymin = mouth.min(axis=0)
    xmax, ymax = mouth.max(axis=0)
    roiwidth = xmax - xmin
    roiheight = ymax - ymin

    # Square crop 1.5x the longer mouth side, centred where the borders allow.
    dstlen = 1.5 * max(roiwidth, roiheight)
    imagerows, imagecols = im.shape[:2]
    newx = _crop_origin(xmin, roiwidth, dstlen - roiwidth, imagecols, dstlen)
    newy = _crop_origin(ymin, roiheight, dstlen - roiheight, imagerows, dstlen)
    roi = im[int(newy):int(newy + dstlen), int(newx):int(newx + dstlen), 0:3]
    if roi.size == 0:
        continue  # degenerate crop; nothing to classify

    # RGB, testsize x testsize, float in [0, 1], then the network normalisation.
    roi = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)
    roiresized = cv2.resize(roi,
                            (testsize, testsize)).astype(np.float32) / 255.0
    imgblob = data_transforms(roiresized).unsqueeze(0)
    # torch.no_grad() only takes effect as a context manager.
    with torch.no_grad():
        predict = F.softmax(net(imgblob), dim=1)
    print(predict)
    index = predict.argmax(dim=1).item()

    # Draw on a copy of the already-loaded image instead of re-reading the file.
    im_show = im.copy()
    x0, y0 = int(newx), int(newy)
    x1, y1 = int(newx + dstlen), int(newy + dstlen)
    cv2.rectangle(im_show, (x0, y0), (x1, y1), (0, 255, 255), 2)
    if index < len(EXPRESSION_LABELS):
        cv2.putText(im_show, EXPRESSION_LABELS[index], (x1, y1),
                    font, 1.5, (0, 0, 255), 2)
    cv2.imwrite(os.path.join('results', imagepath), im_show)
    # OpenCV images are BGR while matplotlib expects RGB, so flip the channels.
    plt.imshow(im_show[:, :, ::-1])
    plt.show()
边栏推荐
- el-input 只能输入整数(包括正数、负数、0)或者只能输入整数(包括正数、负数、0)和小数
- 利用go制作微信机器人
- Start foreground Activity
- 来了!东方甄选为龙江农产品直播带货
- Centos7 install mysql8
- MySQL kills 10 questions, how many questions can you stick to?
- MySQL database - views and indexes
- 青蛙跳台阶(递归和非递归)-------小乐乐走台阶
- 【无标题】多集嵌套集合使不再有MultipleBagFetchException
- Database indexes: indexes are not a panacea
猜你喜欢
MindSpore:【模型训练】【mindinsight】timeline的时间和实际用时相差很远
Interviewer Ali: Describe to me the phenomenon of cache breakdown, and talk about your solution?
How to install and use PostgreSQL 14.4
MindSpore: CV.Rescale(rescale,shift)中参数rescale和shift的含义?
M3SDA:用于多源域自适应的矩匹配
iPhone真是十三香?两代产品完全对比,或许上一代更值得买
PHP低代码开发平台 V5.0.7新版发布
ELK log analysis system
HCIP --- 企业网的三层架构
VBA batch import Excel data into Access database
随机推荐
MySQL six-pulse sword, SQL customs clearance summary
Google's AlphaFold claims to have predicted almost every protein structure on Earth
看完《二舅》,我更内耗了
Mapped Statements collection does not contain value for的解决方法
jOOQ是如何设计事务API(详细指南)
MySQL kills 10 questions, how many questions can you stick to?
MindSpore:npu 多卡训练自定义数据集如何给不同npu传递不同数据
MindSpore:【Resolve node failed】解析节点失败的问题
VS Code connects to SQL Server
PHP低代码开发平台 V5.0.7新版发布
MySQL mass production of data
MySQL database --- Addition, deletion, modification and query of MySQL tables (advanced)
What is the difference between a cloud database and an on-premises database?
MySQL大批量造数据
Trial writing C language sanbang
Witness the magical awakening of the mini world in HUAWEI CLOUD
Cesium loads offline maps and offline terrain
推荐系统:AB测试(AB Test)
推荐系统-排序层:排序层架构【用户、物品特征处理步骤】
MySQL数据库主从配置