当前位置:网站首页>Face-based recognition of common facial expressions - model building, training and testing
Face-based recognition of common facial expressions - model building, training and testing
2022-07-30 20:01:00 【GodGump】
模型搭建与训练
数据接口准备
# Data-interface snippet: build the train/val datasets and loaders.
# It appears before the training script's imports, so it brings its own.
import os

import torch
from torchvision import datasets, transforms

# Train: resize the shorter side to 64, take a random resized 48x48 crop and a
# random horizontal flip for augmentation. Val: deterministic 48x48 centre crop.
# Both map pixels to [-1, 1] with per-channel mean 0.5 / std 0.5.
# Resize / RandomResizedCrop replace the deprecated Scale / RandomSizedCrop
# aliases, which newer torchvision releases no longer provide.
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(64),
        transforms.RandomResizedCrop(48),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]),
    'val': transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(48),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]),
}
# Expects one sub-folder per split ('train', 'val'), each laid out for ImageFolder.
data_dir = './train_val_data/'
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'val']}
# Only the training split is shuffled; evaluation order does not matter.
dataloders = {
    x: torch.utils.data.DataLoader(image_datasets[x],
                                   batch_size=16,
                                   shuffle=(x == 'train'),
                                   num_workers=4)
    for x in ['train', 'val']}
模型定义
import torch.nn as nn
import torch.nn.functional as F
class simpleconv3(nn.Module):
    """Small three-convolution CNN classifier for 48x48 RGB crops.

    Each convolution (3x3 kernel, stride 2, no padding) is followed by
    BatchNorm and ReLU. The spatial size shrinks 48 -> 23 -> 11 -> 5, so the
    flattened feature vector has 48 * 5 * 5 values. Three fully connected
    layers map it to class logits.

    Args:
        num_classes: number of output logits (default 4, the original value).
    """

    def __init__(self, num_classes=4):
        super(simpleconv3, self).__init__()
        self.conv1 = nn.Conv2d(3, 12, 3, 2)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(12, 24, 3, 2)
        self.bn2 = nn.BatchNorm2d(24)
        self.conv3 = nn.Conv2d(24, 48, 3, 2)
        self.bn3 = nn.BatchNorm2d(48)
        self.fc1 = nn.Linear(48 * 5 * 5, 1200)
        self.fc2 = nn.Linear(1200, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw (un-softmaxed) logits of shape (batch, num_classes).

        ``x`` must be (batch, 3, 48, 48); other spatial sizes raise a shape
        error at ``fc1``.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Keep the batch dimension explicit: view(-1, 48*5*5) would silently
        # regroup a wrongly sized input into a different number of "samples"
        # instead of failing.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
模型训练
#coding:utf8
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import torchvision
from torchvision import datasets, models, transforms
import time
import os
from tensorboardX import SummaryWriter
import torch.nn.functional as F
import numpy as np
import warnings
# Silence every warning category for the rest of the run. Note that this also
# hides PyTorch/torchvision deprecation and ordering warnings.
warnings.simplefilter('ignore')
# TensorBoard logger; train_model writes scalars to it, then exports and closes it.
writer = SummaryWriter()
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train ``model`` and evaluate it on the validation split every epoch.

    Reads the module-level globals ``dataloders``, ``dataset_sizes``,
    ``use_gpu`` and ``writer`` (set up by the ``__main__`` section). The
    per-phase mean loss and accuracy are printed and logged to ``writer``. At
    the end the scalars are exported to ./all_scalars.json and the writer is
    closed.

    Args:
        model: network to train (updated in place).
        criterion: loss, called as ``criterion(outputs, labels)``; assumed to
            average over the batch (as nn.CrossEntropyLoss does by default).
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch after the train phase.
        num_epochs: number of epochs to run.

    Returns:
        The trained ``model``.
    """
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            # Training mode in 'train'; evaluation mode (BatchNorm running stats) in 'val'.
            model.train(is_train)
            running_loss = 0.0
            running_corrects = 0.0
            for inputs, labels in dataloders[phase]:
                if use_gpu:
                    inputs = inputs.cuda()
                    labels = labels.cuda()
                optimizer.zero_grad()
                # Build the autograd graph only when it will be back-propagated.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)
                if is_train:
                    loss.backward()
                    optimizer.step()
                # loss is a batch mean; weight it by the batch size so that the
                # division by the dataset size below gives a per-sample mean.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()
            if is_train:
                # Since PyTorch 1.1 the scheduler must step after optimizer.step().
                scheduler.step()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            if is_train:
                writer.add_scalar('data/trainloss', epoch_loss, epoch)
                writer.add_scalar('data/trainacc', epoch_acc, epoch)
            else:
                writer.add_scalar('data/valloss', epoch_loss, epoch)
                writer.add_scalar('data/valacc', epoch_acc, epoch)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
    return model
if __name__ == '__main__':
    # Train: resize the shorter side to 64, take a random resized 48x48 crop and a
    # random horizontal flip. Val: deterministic 48x48 centre crop. Both map
    # pixels to [-1, 1]. Resize / RandomResizedCrop replace the deprecated
    # Scale / RandomSizedCrop aliases that newer torchvision no longer provides.
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(64),
            transforms.RandomResizedCrop(48),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]),
        'val': transforms.Compose([
            transforms.Resize(64),
            transforms.CenterCrop(48),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]),
    }
    data_dir = './Emotion_Recognition_File/train_val_data/'  # location of the dataset
    image_datasets = {
        x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
        for x in ['train', 'val']}
    # Only the training split is shuffled.
    dataloders = {
        x: torch.utils.data.DataLoader(image_datasets[x],
                                       batch_size=64,
                                       shuffle=(x == 'train'),
                                       num_workers=8)
        for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    use_gpu = torch.cuda.is_available()
    print("是否使用 GPU", use_gpu)
    modelclc = simpleconv3()
    print(modelclc)
    if use_gpu:
        modelclc = modelclc.cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(modelclc.parameters(), lr=0.1, momentum=0.9)
    # Multiply the LR by 0.1 every 100 scheduler steps. train_model steps it
    # once per epoch, so with num_epochs=10 the decay never triggers.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=100, gamma=0.1)
    modelclc = train_model(model=modelclc,
                           criterion=criterion,
                           optimizer=optimizer_ft,
                           scheduler=exp_lr_scheduler,
                           num_epochs=10)  # adjust the number of training epochs here
    if not os.path.exists("models"):
        os.mkdir('models')
    torch.save(modelclc.state_dict(), 'models/model.ckpt')
模型测试
# coding:utf8
import sys
import numpy as np
import cv2
import os
import dlib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import torchvision
from torchvision import datasets, models, transforms
import time
from PIL import Image
import torch.nn.functional as F
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
# dlib landmark predictor and OpenCV Haar-cascade face detector model files.
PREDICTOR_PATH = "./Emotion_Recognition_File/face_detect_model/shape_predictor_68_face_landmarks.dat"
predictor = dlib.shape_predictor(PREDICTOR_PATH)
cascade_path = './Emotion_Recognition_File/face_detect_model/haarcascade_frontalface_default.xml'
cascade = cv2.CascadeClassifier(cascade_path)
# CascadeClassifier does not raise on a missing/invalid file; it yields an
# empty classifier whose detectMultiScale() then fails on every image. Fail
# loudly here instead of silently skipping every test image later.
if cascade.empty():
    raise IOError('Could not load Haar cascade from ' + cascade_path)
# Annotated result images are written to this folder.
if not os.path.exists("results"):
    os.mkdir("results")
def standardization(data):
    """Return ``data`` z-scored along axis 0: ``(data - mean) / std``.

    Columns whose standard deviation is 0 (constant columns) are only
    mean-centred, i.e. divided by 1, instead of producing NaN/inf from a
    division by zero.
    """
    mu = np.mean(data, axis=0)
    sigma = np.std(data, axis=0)
    sigma = np.where(sigma == 0, 1, sigma)
    return (data - mu) / sigma
def get_landmarks(im):
    """Detect faces in ``im`` and return the landmarks of the first detection.

    Faces are found with the module-level Haar ``cascade``; landmarks are
    placed inside the first detected box by the module-level dlib ``predictor``.

    Returns:
        np.matrix with one ``[x, y]`` row per predicted landmark.

    Raises:
        IndexError: when no face is detected (the detection list is empty).
    """
    detections = cascade.detectMultiScale(im, 1.3, 5)
    left, top, width, height = detections[0]
    box = dlib.rectangle(int(left), int(top), int(left + width), int(top + height))
    shape = predictor(im, box)
    points = [[part.x, part.y] for part in shape.parts()]
    return np.matrix(points)
def annotate_landmarks(im, landmarks):
    """Return a copy of ``im`` with each landmark numbered and circled.

    For every row of ``landmarks`` the row index is written at that point
    (colour (0, 0, 255), script font, scale 0.4) and a radius-3 circle
    (colour (0, 255, 255)) is drawn there. ``im`` itself is not modified.

    Rows are read as ``row[0, 0]`` / ``row[0, 1]``, so ``landmarks`` is
    expected to be a 2-column ``np.matrix`` such as ``get_landmarks`` returns.
    """
    canvas = im.copy()
    for number, row in enumerate(landmarks):
        center = (row[0, 0], row[0, 1])
        label = str(number)
        cv2.putText(canvas, label, center,
                    fontFace=cv2.FONT_HERSHEY_SCRIPT_SIMPLEX,
                    fontScale=0.4,
                    color=(0, 0, 255))
        cv2.circle(canvas, center, 3, color=(0, 255, 255))
    return canvas
# Side length in pixels of the square crop fed to the network.
testsize = 48
# Tensor conversion plus per-channel Normalize(mean=0.5, std=0.5), matching the
# ToTensor/Normalize steps of the training transforms.
data_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
modelpath = "./models/model.ckpt"  # checkpoint saved by the training script
net = simpleconv3()
# torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location='cpu' keeps all tensors on the CPU even if saved from a GPU.
net.load_state_dict(torch.load(modelpath, map_location='cpu'))
net.eval()
# Network output index -> expression name (same order as the original if-chain).
# NOTE(review): must match the class-index order of the training ImageFolder
# datasets — confirm against the train_val_data folder names.
EXPRESSION_LABELS = ['none', 'pout', 'smile', 'open']


def _crop_origin(roi_start, roi_len, crop_len, image_len):
    """Return the start of a ``crop_len`` window centred on [roi_start, roi_start + roi_len).

    The window is shifted to stay inside [0, image_len). If it is longer than
    the image it starts at 0 (the slice is then clipped by the image edge)
    rather than at a negative offset, which numpy slicing would wrap around.
    """
    centred = roi_start - (crop_len - roi_len) / 2.0
    return max(0, min(centred, image_len - crop_len))


# Test one image file at a time.
img_path = "./Emotion_Recognition_File/find_face_img/"
imagepaths = os.listdir(img_path)  # folder of test images
for imagepath in imagepaths:
    im = cv2.imread(os.path.join(img_path, imagepath), 1)
    if im is None:
        continue  # not a readable image file
    rects = cascade.detectMultiScale(im, 1.3, 5)
    if len(rects) == 0:
        continue  # no face detected
    x, y, w, h = rects[0]  # only the first detected face is used
    rect = dlib.rectangle(int(x), int(y), int(x + w), int(y + h))
    landmarks = np.matrix([[p.x, p.y] for p in predictor(im, rect).parts()])

    # Bounding box of landmarks 48..67 inclusive (the mouth in the 68-point layout).
    mouth = np.asarray(landmarks[48:68])
    xmin, ymin = mouth.min(axis=0)
    xmax, ymax = mouth.max(axis=0)
    roiwidth = xmax - xmin
    roiheight = ymax - ymin

    # Square crop with side 1.5x the longer side of the mouth box, centred on
    # the box and kept inside the image.
    dstlen = 1.5 * max(roiwidth, roiheight)
    imagerows, imagecols = im.shape[:2]
    newx = _crop_origin(xmin, roiwidth, dstlen, imagecols)
    newy = _crop_origin(ymin, roiheight, dstlen, imagerows)
    x0, y0 = int(newx), int(newy)
    x1, y1 = int(newx + dstlen), int(newy + dstlen)

    roi = cv2.cvtColor(im[y0:y1, x0:x1, 0:3], cv2.COLOR_BGR2RGB)
    # Scale pixels to [0, 1] float32; ToTensor does not rescale float arrays.
    roiresized = cv2.resize(roi, (testsize, testsize)).astype(np.float32) / 255.0
    imgblob = data_transforms(roiresized).unsqueeze(0)  # add the batch dimension
    # Inference only: disable autograd for the forward pass.
    with torch.no_grad():
        predict = F.softmax(net(imgblob), dim=1)
    print(predict)
    index = int(predict.argmax(dim=1))

    # Draw the crop box and predicted label on a copy of the original BGR image.
    im_show = im.copy()
    cv2.rectangle(im_show, (x0, y0), (x1, y1), (0, 255, 255), 2)
    if index < len(EXPRESSION_LABELS):
        cv2.putText(im_show, EXPRESSION_LABELS[index], (x1, y1),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 0, 255), 2)
    cv2.imwrite(os.path.join('results', imagepath), im_show)
    # matplotlib expects RGB while OpenCV images are BGR, so reverse the channel axis.
    plt.imshow(im_show[:, :, ::-1])
    plt.show()
边栏推荐
- Linux下最新版MySQL 8.0的下载与安装(详细步骤)
- Linux download and install mysql5.7 version tutorial the most complete and detailed explanation
- 推荐系统-排序层:排序层架构【用户、物品特征处理步骤】
- HarmonyOS笔记-----------(三)
- Maxwell 一款简单易上手的实时抓取Mysql数据的软件
- Interviewer Ali: Describe to me the phenomenon of cache breakdown, and talk about your solution?
- PHP低代码开发平台 V5.0.7新版发布
- .eslintrc.js for musicApp
- 从离线到实时对客,湖仓一体释放全量数据价值
- 阿里面试这些微服务还不会?那还是别去了,基本等通知
猜你喜欢
Typora设置标题自动标号
Database Tuning - Database Tuning

MySQL数据库主从配置

ERROR 1045 (28000) Access denied for user ‘root‘@‘localhost‘解决方法

MySQL slow query optimization

【无标题】多集嵌套集合使不再有MultipleBagFetchException

MySQL database master-slave configuration

推荐系统-排序层-模型(一):Embedding + MLP(多层感知机)模型【Deep Crossing模型:经典的Embedding+MLP模型结构】

MySQL performance optimization (hardware, system configuration, table structure, SQL statements)

数据库索引:索引并不是万能药
随机推荐
ELK log analysis system
湖仓一体电商项目(四):项目数据种类与采集
MySQL database master-slave configuration
VBA runtime error '-2147217900 (80040e14): Automation error
Linux download and install mysql5.7 version tutorial the most complete and detailed explanation
Zabbix 5.0 Monitoring Tutorial (1)
MySQL大批量造数据
Download and installation of the latest version of MySQL 8.0 under Linux (detailed steps)
时间复杂度与空间复杂度
推荐系统:实时性【特征实时性:客户端实时特征(秒级,实时)、流处理平台(分钟级,近实时)、分布式批处理平台(小时/天级,非实时)】【模型实时性:在线学习、增量更新、全量更新】
MindSpore:【Resolve node failed】解析节点失败的问题
推荐系统-排序层:排序层架构【用户、物品特征处理步骤】
mysql8 installation under linux
普通的int main(){}没有写return 0;会怎么样?
LeetCode 0952. Calculate Maximum Component Size by Common Factor: Mapping / Union Search
M3SDA: Moment matching for multi-source domain adaptation
[hbuilder] cannot run some projects, open the terminal and cannot enter commands
KEIL问题:【keil Error: failed to execute ‘C:\Keil\ARM\ARMCC‘】
Encapsulates a console file selector based on inquirer
MySQL performance optimization (hardware, system configuration, table structure, SQL statements)