HRNet-Facial-Landmark-Detection: Training on Your Own Dataset
2022-07-29 23:03:00 【J ..】
This is the official code of High-Resolution Representations for Facial Landmark Detection. We extend the high-resolution representation (HRNet) [1] by aggregating the (upsampled) representations from all the parallel convolutions, leading to stronger representations. The output representations are fed into the classifier. We evaluate our methods on four datasets: COFW, AFLW, WFLW and 300W.
GitHub: https://github.com/HRNet/HRNet-Facial-Landmark-Detection
Train
Prepare your training dataset (any data format is fine, as long as it is easy to read); the 300W dataset is a good reference.
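For illustration, here is a minimal sketch of one way to generate the annotation file. The field names image_path, bbox and keypoints are my own convention, chosen to match the own.py reader shown later, not a format the repo prescribes:

```python
import json

# One record per face: the image path (relative to DATASET.ROOT),
# the face box [x1, y1, x2, y2], and the flattened landmark coordinates.
annotations = [{
    "image_path": "0001.jpg",
    "bbox": [911, 1281, 1254, 1731],
    "keypoints": [1013.0, 1405.2, 1150.7, 1402.9],  # x0, y0, x1, y1, ... (all 37 points)
}]

with open("train.json", "w") as f:
    json.dump(annotations, f, indent=2)
```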
Create a config file under the experiments directory, e.g. experiments/own/own_hrnet_w18.yaml:
- In DATASET, set the ROOT, TRAINSET and TESTSET paths, and DATASET to your dataset name.
- In MODEL, set NUM_JOINTS to the number of landmarks in your training set.
```yaml
DATASET:
  DATASET: OWN
  ROOT: '../data/own/images'
  TRAINSET: '../data/own/train.json'
  TESTSET: '../data/own/val.json'
  FLIP: true
  SCALE_FACTOR: 0.25
  ROT_FACTOR: 30
MODEL:
  NAME: 'hrnet'
  NUM_JOINTS: 37  # the number of landmarks in your own dataset
  INIT_WEIGHTS: true
  PRETRAINED: 'hrnetv2_pretrained/hrnetv2_w18_imagenet_pretrained.pth'
```
Create own.py under lib/datasets to read data in your own format: copy the contents of face300w.py, then modify the class name and the __getitem__ method.
- center/scale formula (the repo's crop utilities treat scale as relative to a 200-pixel reference box, hence the divisor):
```python
scale = max(w, h) / 200
center_w = (x1 + x2) / 2
center_h = (y1 + y2) / 2
```
- Read the annotations in whatever format you chose; I generated JSON.
```python
def calCenterScale(self, bbox):
    w = bbox[2] - bbox[0]
    h = bbox[3] - bbox[1]
    center_w = (bbox[0] + bbox[2]) / 2.0
    center_h = (bbox[1] + bbox[3]) / 2.0
    scale = round(max(w, h) / 200.0, 2)
    return center_w, center_h, scale

def __getitem__(self, idx):
    image_path = os.path.join(self.data_root,
                              self.landmarks_frame[idx]["image_path"])
    bbox = self.landmarks_frame[idx]['bbox']
    center_w, center_h, scale = self.calCenterScale(bbox)
    center = torch.Tensor([center_w, center_h])
    pts = np.array(self.landmarks_frame[idx]["keypoints"])
    pts = pts.astype('float').reshape(-1, 2)
    ...
```
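A quick way to sanity-check the new reader, assuming Own keeps the same constructor signature and return values as Face300W (a sketch, not code from the repo):

```python
from lib.config import config
from lib.datasets.own import Own

# Assumes config has already been updated from experiments/own/own_hrnet_w18.yaml.
dataset = Own(config, is_train=True)
img, target, meta = dataset[0]   # Face300W-style datasets return (image, heatmaps, meta)
print(img.shape, target.shape, meta['pts'].shape)
```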
- Modify lib/datasets/__init__.py to register your dataset (the name set in the yaml):
```python
from .aflw import AFLW
from .cofw import COFW
from .face300w import Face300W
from .wflw import WFLW
from .own import Own

__all__ = ['AFLW', 'COFW', 'Face300W', 'WFLW', 'Own', 'get_dataset']


def get_dataset(config):
    if config.DATASET.DATASET == 'AFLW':
        return AFLW
    elif config.DATASET.DATASET == 'COFW':
        return COFW
    elif config.DATASET.DATASET == '300W':
        return Face300W
    elif config.DATASET.DATASET == 'WFLW':
        return WFLW
    elif config.DATASET.DATASET == 'OWN':
        return Own
    else:
        raise NotImplementedError()
```
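Note that get_dataset returns the dataset class itself; the training script then instantiates it, roughly like this (a sketch of how the repo's tools/train.py consumes it):

```python
import torch.utils.data
from lib.config import config
from lib.datasets import get_dataset

dataset_type = get_dataset(config)  # resolves to Own when DATASET.DATASET == 'OWN'
train_loader = torch.utils.data.DataLoader(
    dataset=dataset_type(config, is_train=True),
    batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
    shuffle=config.TRAIN.SHUFFLE,
)
```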
Modify the compute_nme method in lib/core/evaluation.py: add a branch for your own landmark count, normalizing by the distance between the two outer eye corners (indices 0 and 15 in my annotation):

```python
def compute_nme(preds, meta):
    targets = meta['pts']
    preds = preds.numpy()
    target = targets.cpu().numpy()

    N = preds.shape[0]
    L = preds.shape[1]
    rmse = np.zeros(N)

    for i in range(N):
        pts_pred, pts_gt = preds[i, ], target[i, ]
        if L == 19:  # aflw
            interocular = meta['box_size'][i]
        elif L == 29:  # cofw
            interocular = np.linalg.norm(pts_gt[8, ] - pts_gt[9, ])
        elif L == 68:  # 300w
            interocular = np.linalg.norm(pts_gt[36, ] - pts_gt[45, ])
        elif L == 98:  # wflw
            interocular = np.linalg.norm(pts_gt[60, ] - pts_gt[72, ])
        elif L == 37:  # own dataset: outer eye-corner indices
            interocular = np.linalg.norm(pts_gt[0, ] - pts_gt[15, ])
        else:
            raise ValueError('Number of landmarks is wrong')
        rmse[i] = np.sum(np.linalg.norm(pts_pred - pts_gt, axis=1)) / (interocular * L)

    return rmse
```
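To make the normalization concrete: the per-image NME sums the Euclidean error over all L landmarks and divides by interocular * L, so a prediction that is off by exactly one interocular distance at every point scores 1.0. A standalone check with made-up values:

```python
import numpy as np

pts_gt = np.random.rand(37, 2) * 100                 # fake ground truth, 37 landmarks
interocular = np.linalg.norm(pts_gt[0] - pts_gt[15])

# Shift every predicted point by one interocular distance along x.
pts_pred = pts_gt + np.array([interocular, 0.0])

L = pts_gt.shape[0]
nme = np.sum(np.linalg.norm(pts_pred - pts_gt, axis=1)) / (interocular * L)
print(nme)  # 1.0
```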
Modify the fliplr_joints method in lib/utils/transforms.py (no change is needed if FLIP is false). Use the matched-pair indices from your own annotation: if your landmarks are annotated starting from index 0, the -1 offset is not needed, same as the WFLW dataset (see the sketch below). Then point train.py's --cfg at your own yaml and start training.
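For reference, a sketch of what the extra branch might look like. The MATCHED_PARTS entry here is hypothetical (fill in your own left/right landmark pairs), and it assumes 0-indexed annotations like WFLW:

```python
# lib/utils/transforms.py, alongside the existing dataset entries
MATCHED_PARTS['OWN'] = ([0, 15], [1, 14], [2, 13])  # hypothetical mirror pairs


def fliplr_joints(x, width, dataset='aflw'):
    matched_parts = MATCHED_PARTS[dataset]
    x[:, 0] = width - x[:, 0]            # mirror the x coordinates
    if dataset in ('WFLW', 'OWN'):       # 0-indexed annotations: no -1 offset
        for pair in matched_parts:
            tmp = x[pair[0], :].copy()
            x[pair[0], :] = x[pair[1], :]
            x[pair[1], :] = tmp
    else:                                # 1-indexed annotations (e.g. 300W)
        for pair in matched_parts:
            tmp = x[pair[0] - 1, :].copy()
            x[pair[0] - 1, :] = x[pair[1] - 1, :]
            x[pair[1] - 1, :] = tmp
    return x
```

Training is then started as in the repo's README, pointing --cfg at the new file:

```bash
python tools/train.py --cfg experiments/own/own_hrnet_w18.yaml
```

Sample output from the first epoch: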
```
Epoch: [0][0/916]    Time 18.342s (18.342s) Speed 0.9 samples/s  Data 14.961s (14.961s) Loss 0.00214 (0.00214)
Epoch: [0][50/916]   Time 0.542s (0.880s)   Speed 29.5 samples/s Data 0.000s (0.294s)   Loss 0.00076 (0.00085)
Epoch: [0][100/916]  Time 0.537s (0.708s)   Speed 29.8 samples/s Data 0.000s (0.148s)   Loss 0.00074 (0.00080)
Epoch: [0][150/916]  Time 0.530s (0.650s)   Speed 30.2 samples/s Data 0.000s (0.099s)   Loss 0.00075 (0.00079)
Epoch: [0][200/916]  Time 0.531s (0.621s)   Speed 30.1 samples/s Data 0.001s (0.075s)   Loss 0.00074 (0.00077)
Epoch: [0][250/916]  Time 0.532s (0.603s)   Speed 30.1 samples/s Data 0.000s (0.060s)   Loss 0.00072 (0.00077)
Epoch: [0][300/916]  Time 0.525s (0.592s)   Speed 30.5 samples/s Data 0.000s (0.050s)   Loss 0.00073 (0.00076)
Epoch: [0][350/916]  Time 0.541s (0.583s)   Speed 29.6 samples/s Data 0.000s (0.043s)   Loss 0.00071 (0.00075)
Epoch: [0][400/916]  Time 0.536s (0.577s)   Speed 29.9 samples/s Data 0.000s (0.038s)   Loss 0.00067 (0.00074)
Epoch: [0][450/916]  Time 0.534s (0.572s)   Speed 30.0 samples/s Data 0.000s (0.034s)   Loss 0.00057 (0.00073)
Epoch: [0][500/916]  Time 0.534s (0.568s)   Speed 30.0 samples/s Data 0.000s (0.030s)   Loss 0.00056 (0.00072)
Epoch: [0][550/916]  Time 0.528s (0.565s)   Speed 30.3 samples/s Data 0.000s (0.027s)   Loss 0.00055 (0.00071)
Epoch: [0][600/916]  Time 0.533s (0.562s)   Speed 30.0 samples/s Data 0.001s (0.025s)   Loss 0.00053 (0.00069)
Epoch: [0][650/916]  Time 0.528s (0.560s)   Speed 30.3 samples/s Data 0.000s (0.023s)   Loss 0.00051 (0.00068)
Epoch: [0][700/916]  Time 0.535s (0.558s)   Speed 29.9 samples/s Data 0.000s (0.022s)   Loss 0.00050 (0.00067)
Epoch: [0][750/916]  Time 0.537s (0.556s)   Speed 29.8 samples/s Data 0.000s (0.020s)   Loss 0.00053 (0.00066)
Epoch: [0][800/916]  Time 0.532s (0.555s)   Speed 30.1 samples/s Data 0.000s (0.019s)   Loss 0.00047 (0.00065)
Epoch: [0][850/916]  Time 0.531s (0.554s)   Speed 30.1 samples/s Data 0.000s (0.018s)   Loss 0.00051 (0.00064)
Epoch: [0][900/916]  Time 0.526s (0.552s)   Speed 30.4 samples/s Data 0.000s (0.017s)   Loss 0.00054 (0.00063)
Train Epoch 0 time:0.5524 loss:0.0006 nme:0.3472
best: True
Test Epoch 0 time:0.3146 loss:0.0005 nme:0.1605 [008]:0.8482 [010]:0.5162
=> saving checkpoint to output\OWN\own_hrnet_w18
```
Validation
- For validation, you can adapt the script from this issue, which I found convenient for testing: https://github.com/HRNet/HRNet-Facial-Landmark-Detection/issues/21
```python
# ------------------------------------------------------------------------------
# Created by Gaofeng([email protected])
# ------------------------------------------------------------------------------
import os
import argparse
import sys

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from PIL import Image

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import lib.models as models
from lib.config import config, update_config
from lib.utils.transforms import crop
from lib.core.evaluation import decode_preds


def parse_args():
    parser = argparse.ArgumentParser(description='Train Face Alignment')
    parser.add_argument('--cfg', default='experiments/300w/face_alignment_300w_hrnet_w18.yaml',
                        help='experiment configuration filename', type=str)
    parser.add_argument('--model-file', help='model parameters',
                        default='HR18-300W.pth', type=str)
    parser.add_argument('--imagepath', help='Path of the image to be detected',
                        default='111.jpg', type=str)
    parser.add_argument('--face', nargs='+', type=float,
                        default=[911, 1281, 1254, 1731],
                        help='The coordinate [x1,y1,x2,y2] of a face')
    args = parser.parse_args()
    update_config(config, args)
    return args


def prepare_input(image, bbox, image_size):
    """
    :param image: the path of the image to be detected
    :param bbox: the bbox of the target face
    :param image_size: refers to config file
    """
    scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 200
    center_w = (bbox[0] + bbox[2]) / 2
    center_h = (bbox[1] + bbox[3]) / 2
    center = torch.Tensor([center_w, center_h])
    scale *= 1.25
    img = np.array(Image.open(image).convert('RGB'), dtype=np.float32)
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img = crop(img, center, scale, image_size, rot=0)
    img = img.astype(np.float32)
    img = (img / 255.0 - mean) / std
    img = img.transpose([2, 0, 1])
    img = torch.Tensor(img)
    img = img.unsqueeze(0)
    return img, center, scale


def main():
    args = parse_args()
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()

    model = models.get_face_alignment_net(config)
    if isinstance(config.GPUS, list):
        gpus = list(config.GPUS)
    else:
        gpus = [config.GPUS]
    model = nn.DataParallel(model, device_ids=gpus).cuda()

    # load model
    state_dict = torch.load(args.model_file)
    model.load_state_dict(state_dict)
    model.eval()

    inp, center, scale = prepare_input(args.imagepath, args.face, config.MODEL.IMAGE_SIZE)
    output = model(inp)
    score_map = output.data.cpu()
    preds = decode_preds(score_map, center, scale, [64, 64])
    preds = preds.numpy()

    cv2.namedWindow('test', 0)
    img_once = cv2.imread(args.imagepath)
    for i in preds[0, :, :]:
        cv2.circle(img_once, tuple(int(p) for p in i.tolist()), 2, (255, 255, 0), 1)
    cv2.imshow('test', img_once)
    if cv2.waitKey(0) == 27:
        cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
```
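To run it, pass your own config, trained weights, test image, and face box. The paths below are placeholders; I saved the script as tools/demo.py, but the name and location are up to you:

```bash
python tools/demo.py --cfg experiments/own/own_hrnet_w18.yaml \
    --model-file output/OWN/own_hrnet_w18/best.pth \
    --imagepath 111.jpg --face 911 1281 1254 1731
```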
END
- Feel free to use this as a reference; questions and corrections are welcome.