Face-based Common Expression Recognition: Model Building, Training and Testing
2022-07-30 20:01:00 【GodGump】
Model Building and Training
Data Interface Preparation
import os

import torch
from torchvision import datasets, transforms

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(64),             # resize the shorter edge to 64 (transforms.Scale in older torchvision)
        transforms.RandomResizedCrop(48),  # random crop, then resize to 48x48 (RandomSizedCrop in older torchvision)
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ]),
    'val': transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(48),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ]),
}

data_dir = './train_val_data/'
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'val']}
dataloders = {
    x: torch.utils.data.DataLoader(image_datasets[x],
                                   batch_size=16,
                                   shuffle=True,
                                   num_workers=4) for x in ['train', 'val']}
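ImageFolder derives the labels from the sub-folder names under each of the train/ and val/ directories, one class per folder. As a quick sanity check (a minimal sketch; the layout shown in the comment is the expected convention, not taken from the original post), the discovered classes and sample counts can be printed:

# Quick sanity check of the ImageFolder datasets.
# Expected layout (one sub-folder per expression class):
#   train_val_data/train/<class_name>/*.jpg
#   train_val_data/val/<class_name>/*.jpg
print(image_datasets['train'].classes)       # class sub-folder names, sorted alphabetically
print(image_datasets['train'].class_to_idx)  # mapping from class name to label index
print(len(image_datasets['train']), len(image_datasets['val']))  # number of train / val images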
Model Definition
import torch.nn as nn
import torch.nn.functional as F

class simpleconv3(nn.Module):
    """A simple 3-layer CNN: three stride-2 convolutions followed by three fully connected layers."""
    def __init__(self):
        super(simpleconv3, self).__init__()
        self.conv1 = nn.Conv2d(3, 12, 3, 2)   # 48x48 -> 23x23
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(12, 24, 3, 2)  # 23x23 -> 11x11
        self.bn2 = nn.BatchNorm2d(24)
        self.conv3 = nn.Conv2d(24, 48, 3, 2)  # 11x11 -> 5x5
        self.bn3 = nn.BatchNorm2d(48)
        self.fc1 = nn.Linear(48 * 5 * 5, 1200)
        self.fc2 = nn.Linear(1200, 128)
        self.fc3 = nn.Linear(128, 4)          # 4 expression classes

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = x.view(-1, 48 * 5 * 5)            # flatten the 48x5x5 feature maps
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
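As a quick sanity check (a small sketch, not part of the original tutorial), running a random 48x48 RGB batch through the network confirms the 48 * 5 * 5 flatten size and the 4-class output:

import torch

net = simpleconv3()
dummy = torch.randn(2, 3, 48, 48)  # a batch of two fake 48x48 RGB crops
out = net(dummy)
print(out.shape)                   # expected: torch.Size([2, 4])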
Model Training
#coding:utf8
from __future__ import print_function, division

import os
import time
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms
from tensorboardX import SummaryWriter

warnings.filterwarnings('ignore')

writer = SummaryWriter()

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train(True)   # set model to training mode
            else:
                model.train(False)  # set model to evaluation mode

            running_loss = 0.0
            running_corrects = 0.0
            for data in dataloders[phase]:
                inputs, labels = data
                if use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                optimizer.zero_grad()
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                if phase == 'train':
                    loss.backward()
                    optimizer.step()

                # accumulate the summed loss (batch mean * batch size) and the number of correct predictions
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            if phase == 'train':
                # step the LR scheduler once per epoch, after the optimizer updates
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            if phase == 'train':
                writer.add_scalar('data/trainloss', epoch_loss, epoch)
                writer.add_scalar('data/trainacc', epoch_acc, epoch)
            else:
                writer.add_scalar('data/valloss', epoch_loss, epoch)
                writer.add_scalar('data/valacc', epoch_acc, epoch)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
    return model
if __name__ == '__main__':

    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(64),
            transforms.RandomResizedCrop(48),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ]),
        'val': transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(48),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ]),
    }

    data_dir = './Emotion_Recognition_File/train_val_data/'  # location of the dataset
    image_datasets = {
        x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
        for x in ['train', 'val']}
    dataloders = {
        x: torch.utils.data.DataLoader(image_datasets[x],
                                       batch_size=64,
                                       shuffle=True if x == "train" else False,
                                       num_workers=8) for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

    use_gpu = torch.cuda.is_available()
    print("Using GPU:", use_gpu)

    modelclc = simpleconv3()
    print(modelclc)
    if use_gpu:
        modelclc = modelclc.cuda()

    criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(modelclc.parameters(), lr=0.1, momentum=0.9)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=100, gamma=0.1)

    modelclc = train_model(model=modelclc,
                           criterion=criterion,
                           optimizer=optimizer_ft,
                           scheduler=exp_lr_scheduler,
                           num_epochs=10)  # the number of training epochs can be adjusted here

    if not os.path.exists("models"):
        os.mkdir('models')
    torch.save(modelclc.state_dict(), 'models/model.ckpt')
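Before moving on to the test script, it can be worth confirming that the saved checkpoint reloads cleanly. A minimal sketch (run as a separate snippet in the same session or file, so that simpleconv3 is in scope; the path matches the torch.save call above):

# Reload the trained weights into a fresh network for evaluation (CPU assumed here).
model_eval = simpleconv3()
model_eval.load_state_dict(torch.load('models/model.ckpt', map_location='cpu'))
model_eval.eval()  # switch BatchNorm layers to inference behaviour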
Model Testing
# coding:utf8
import os
import warnings

import cv2
import dlib
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
import matplotlib.pyplot as plt

warnings.filterwarnings('ignore')

# simpleconv3 is the network defined in the "Model Definition" section above;
# it must be defined (or imported) before this script runs.

PREDICTOR_PATH = "./Emotion_Recognition_File/face_detect_model/shape_predictor_68_face_landmarks.dat"
predictor = dlib.shape_predictor(PREDICTOR_PATH)
cascade_path = './Emotion_Recognition_File/face_detect_model/haarcascade_frontalface_default.xml'
cascade = cv2.CascadeClassifier(cascade_path)

if not os.path.exists("results"):
    os.mkdir("results")

def standardization(data):
    mu = np.mean(data, axis=0)
    sigma = np.std(data, axis=0)
    return (data - mu) / sigma

def get_landmarks(im):
    rects = cascade.detectMultiScale(im, 1.3, 5)
    x, y, w, h = rects[0]
    rect = dlib.rectangle(int(x), int(y), int(x + w), int(y + h))
    return np.matrix([[p.x, p.y] for p in predictor(im, rect).parts()])

def annotate_landmarks(im, landmarks):
    im = im.copy()
    for idx, point in enumerate(landmarks):
        pos = (point[0, 0], point[0, 1])
        cv2.putText(im,
                    str(idx),
                    pos,
                    fontFace=cv2.FONT_HERSHEY_SCRIPT_SIMPLEX,
                    fontScale=0.4,
                    color=(0, 0, 255))
        cv2.circle(im, pos, 3, color=(0, 255, 255))
    return im

testsize = 48  # size of the crop fed to the network

data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

net = simpleconv3()
net.eval()
modelpath = "./models/model.ckpt"  # path to the trained model
net.load_state_dict(
    torch.load(modelpath, map_location=lambda storage, loc: storage))

# test one image file at a time
img_path = "./Emotion_Recognition_File/find_face_img/"  # image folder
imagepaths = os.listdir(img_path)
for imagepath in imagepaths:
    im = cv2.imread(os.path.join(img_path, imagepath), 1)
    try:
        rects = cascade.detectMultiScale(im, 1.3, 5)
        x, y, w, h = rects[0]
        rect = dlib.rectangle(int(x), int(y), int(x + w), int(y + h))
        landmarks = np.matrix([[p.x, p.y]
                               for p in predictor(im, rect).parts()])
    except:
        continue  # no face detected

    # bounding box of the mouth region (landmarks 48-67 in the 68-point model)
    xmin = 10000
    xmax = 0
    ymin = 10000
    ymax = 0
    for i in range(48, 68):
        x = landmarks[i, 0]
        y = landmarks[i, 1]
        if x < xmin:
            xmin = x
        if x > xmax:
            xmax = x
        if y < ymin:
            ymin = y
        if y > ymax:
            ymax = y

    roiwidth = xmax - xmin
    roiheight = ymax - ymin
    roi = im[ymin:ymax, xmin:xmax, 0:3]

    # expand the ROI to a square whose side is 1.5x the longer edge, clamped to the image
    if roiwidth > roiheight:
        dstlen = 1.5 * roiwidth
    else:
        dstlen = 1.5 * roiheight
    diff_xlen = dstlen - roiwidth
    diff_ylen = dstlen - roiheight
    newx = xmin
    newy = ymin
    imagerows, imagecols, channel = im.shape
    if newx >= diff_xlen / 2 and newx + roiwidth + diff_xlen / 2 < imagecols:
        newx = newx - diff_xlen / 2
    elif newx < diff_xlen / 2:
        newx = 0
    else:
        newx = imagecols - dstlen
    if newy >= diff_ylen / 2 and newy + roiheight + diff_ylen / 2 < imagerows:
        newy = newy - diff_ylen / 2
    elif newy < diff_ylen / 2:
        newy = 0
    else:
        newy = imagerows - dstlen

    roi = im[int(newy):int(newy + dstlen), int(newx):int(newx + dstlen), 0:3]
    roi = cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)
    roiresized = cv2.resize(roi,
                            (testsize, testsize)).astype(np.float32) / 255.0
    imgblob = data_transforms(roiresized).unsqueeze(0)
    with torch.no_grad():
        predict = F.softmax(net(imgblob), dim=1)
    print(predict)
    index = np.argmax(predict.numpy())

    im_show = cv2.imread(os.path.join(img_path, imagepath), 1)
    pos_x = int(newx + dstlen)
    pos_y = int(newy + dstlen)
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.rectangle(im_show, (int(newx), int(newy)),
                  (int(newx + dstlen), int(newy + dstlen)), (0, 255, 255), 2)
    if index == 0:
        cv2.putText(im_show, 'none', (pos_x, pos_y), font, 1.5, (0, 0, 255), 2)
    if index == 1:
        cv2.putText(im_show, 'pout', (pos_x, pos_y), font, 1.5, (0, 0, 255), 2)
    if index == 2:
        cv2.putText(im_show, 'smile', (pos_x, pos_y), font, 1.5, (0, 0, 255), 2)
    if index == 3:
        cv2.putText(im_show, 'open', (pos_x, pos_y), font, 1.5, (0, 0, 255), 2)

    # cv2.namedWindow('result', 0)
    # cv2.imshow('result', im_show)
    cv2.imwrite(os.path.join('results', imagepath), im_show)
    plt.imshow(im_show[:, :, ::-1])  # swap channels: matplotlib expects RGB while OpenCV uses BGR
    plt.show()
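The hard-coded index-to-name mapping in the if-chain above ('none', 'pout', 'smile', 'open') has to agree with the class ordering that ImageFolder assigned during training, which follows the alphabetical order of the class sub-folders. A small sketch (assuming the same training data directory as in the training script; the actual folder names depend on how the dataset was prepared) shows how that mapping can be read straight from the dataset instead of being remembered:

from torchvision import datasets

# Derive the label-index -> class-name mapping from the training ImageFolder.
train_set = datasets.ImageFolder('./Emotion_Recognition_File/train_val_data/train')
idx_to_class = {v: k for k, v in train_set.class_to_idx.items()}
print(idx_to_class)  # mapping from label index to class folder name, in ImageFolder's (alphabetical) order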