HRNet-Facial-Landmark-Detection: Training on Your Own Dataset

This is the official code of High-Resolution Representations for Facial Landmark Detection. We extend the high-resolution representation (HRNet) [1] by aggregating the (upsampled) representations from all of the parallel convolutions, leading to stronger representations. The output representations are fed into a classifier. We evaluate our method on four datasets: COFW, AFLW, WFLW, and 300W.

GitHub: https://github.com/HRNet/HRNet-Facial-Landmark-Detection
Train
Prepare your training dataset (any data format is fine, as long as it is easy to read); the 300W dataset can serve as a reference.
Create a new config file under the experiments directory, e.g. experiments/own/own_hrnet_w18.yaml:
- In DATASET, set the ROOT, TRAINSET, and TESTSET paths, and set DATASET to your dataset's name.
- In MODEL, set NUM_JOINTS to the number of landmarks in your training set.
```yaml
DATASET:
  DATASET: OWN
  ROOT: '../data/own/images'
  TRAINSET: '../data/own/train.json'
  TESTSET: '../data/own/val.json'
  FLIP: true
  SCALE_FACTOR: 0.25
  ROT_FACTOR: 30
MODEL:
  NAME: 'hrnet'
  NUM_JOINTS: 37  # set to the number of landmarks in your own dataset
  INIT_WEIGHTS: true
  PRETRAINED: 'hrnetv2_pretrained/hrnetv2_w18_imagenet_pretrained.pth'
```
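Once the remaining steps below are done, training can be launched the same way the repo README does it, passing the yaml via --cfg (the path is the config file created above):

```bash
python tools/train.py --cfg experiments/own/own_hrnet_w18.yaml
```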
Create own.py under lib/datasets to read data in your own format: copy the contents of face300w.py, then change the class name and the __getitem__ method.
- center and scale are computed as:

```
scale = max(w, h) / 200
center_w = (x1 + x2) / 2
center_h = (y1 + y2) / 2
```

- Read the annotations in whatever format you generated; mine are JSON.
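A hypothetical train.json entry could look like the following; the field names match what the reading code below expects (image_path, bbox as [x1, y1, x2, y2], keypoints as a flat list of x/y values, truncated here for brevity), while the file name and numbers are placeholders:

```json
[
  {
    "image_path": "0001.jpg",
    "bbox": [112.0, 80.0, 298.0, 305.0],
    "keypoints": [150.2, 210.7, 163.9, 215.1]
  }
]
```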
```python
def calCenterScale(self, bbox):
    # bbox is [x1, y1, x2, y2]
    w = bbox[2] - bbox[0]
    h = bbox[3] - bbox[1]
    center_w = (bbox[0] + bbox[2]) / 2.0
    center_h = (bbox[1] + bbox[3]) / 2.0
    scale = round((max(w, h) / 200.0), 2)
    return center_w, center_h, scale

def __getitem__(self, idx):
    image_path = os.path.join(self.data_root, self.landmarks_frame[idx]["image_path"])
    bbox = self.landmarks_frame[idx]['bbox']
    center_w, center_h, scale = self.calCenterScale(bbox)
    center = torch.Tensor([center_w, center_h])
    pts = np.array(self.landmarks_frame[idx]["keypoints"])
    pts = pts.astype('float').reshape(-1, 2)
    ...
```
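The snippet above assumes self.data_root and self.landmarks_frame were set up in the constructor. A minimal sketch of that setup, modeled loosely on face300w.py (the exact signature and config keys are my assumptions, not code copied from the repo):

```python
import json

import torch.utils.data as data


class Own(data.Dataset):
    def __init__(self, cfg, is_train=True, transform=None):
        self.transform = transform
        self.is_train = is_train
        self.data_root = cfg.DATASET.ROOT
        # pick the annotation file configured in the yaml
        anno_file = cfg.DATASET.TRAINSET if is_train else cfg.DATASET.TESTSET
        with open(anno_file) as f:
            self.landmarks_frame = json.load(f)

    def __len__(self):
        return len(self.landmarks_frame)
```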
Then register the dataset name (the one set in the yaml) in lib/datasets/__init__.py:

```python
from .aflw import AFLW
from .cofw import COFW
from .face300w import Face300W
from .wflw import WFLW
from .own import Own

__all__ = ['AFLW', 'COFW', 'Face300W', 'WFLW', 'Own', 'get_dataset']


def get_dataset(config):
    if config.DATASET.DATASET == 'AFLW':
        return AFLW
    elif config.DATASET.DATASET == 'COFW':
        return COFW
    elif config.DATASET.DATASET == '300W':
        return Face300W
    elif config.DATASET.DATASET == 'WFLW':
        return WFLW
    elif config.DATASET.DATASET == 'OWN':
        return Own
    else:
        raise NotImplementedError()
```
Modify the compute_nme method in lib/core/evaluation.py: add a branch for your own number of landmarks, using the indices of the two outer eye corners as the normalization distance:

```python
def compute_nme(preds, meta):
    targets = meta['pts']
    preds = preds.numpy()
    target = targets.cpu().numpy()

    N = preds.shape[0]
    L = preds.shape[1]
    rmse = np.zeros(N)

    for i in range(N):
        pts_pred, pts_gt = preds[i, ], target[i, ]
        if L == 19:  # aflw
            interocular = meta['box_size'][i]
        elif L == 29:  # cofw
            interocular = np.linalg.norm(pts_gt[8, ] - pts_gt[9, ])
        elif L == 68:  # 300w
            interocular = np.linalg.norm(pts_gt[36, ] - pts_gt[45, ])
        elif L == 98:  # wflw
            interocular = np.linalg.norm(pts_gt[60, ] - pts_gt[72, ])
        elif L == 37:  # own dataset: outer eye corners at indices 0 and 15
            interocular = np.linalg.norm(pts_gt[0, ] - pts_gt[15, ])
        else:
            raise ValueError('Number of landmarks is wrong')
        rmse[i] = np.sum(np.linalg.norm(pts_pred - pts_gt, axis=1)) / (interocular * L)

    return rmse
```
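For reference, the value computed per image above is the standard normalized mean error,

$$\mathrm{NME} = \frac{1}{L \, d_{\text{inter}}} \sum_{i=1}^{L} \lVert p_i - g_i \rVert_2$$

where $p_i$ and $g_i$ are the predicted and ground-truth landmarks, $L$ is the number of landmarks, and $d_{\text{inter}}$ is the inter-ocular distance (for AFLW, the face box size).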
Also modify the fliplr_joints method in lib/utils/transforms.py (no change needed if FLIP: false): the horizontal-flip pairs must use your own landmark annotation indices. If your landmarks are annotated starting from index 0, the -1 offset is not needed, similar to the WFLW dataset; see the sketch below.
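A minimal sketch of the kind of change involved, assuming the repo's pattern of a MATCHED_PARTS dict keyed by dataset name; the "OWN" pairs here are placeholders to be replaced with your full list of left/right-symmetric landmark indices:

```python
# sketch of lib/utils/transforms.py, not the repo's exact code
MATCHED_PARTS = {
    # ... existing entries for 300W / AFLW / COFW / WFLW ...
    # hypothetical 0-indexed pairs for the 37-point layout:
    "OWN": ([0, 15], [1, 14]),
}


def fliplr_joints(x, width, dataset='OWN'):
    """Mirror landmark coordinates horizontally and swap symmetric pairs."""
    matched_parts = MATCHED_PARTS[dataset]
    x[:, 0] = width - x[:, 0]  # mirror the x coordinates

    for pair in matched_parts:
        if dataset in ('WFLW', 'OWN'):  # 0-indexed annotations: no -1 offset
            a, b = pair
        else:                           # 1-indexed annotations (e.g. 300W)
            a, b = pair[0] - 1, pair[1] - 1
        x[a, :], x[b, :] = x[b, :].copy(), x[a, :].copy()
    return x
```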
Finally, point train.py at your own yaml (or pass it via --cfg) and start training. A sample training log:

```
Epoch: [0][0/916]    Time 18.342s (18.342s)  Speed 0.9 samples/s   Data 14.961s (14.961s)  Loss 0.00214 (0.00214)
Epoch: [0][50/916]   Time 0.542s (0.880s)    Speed 29.5 samples/s  Data 0.000s (0.294s)    Loss 0.00076 (0.00085)
Epoch: [0][100/916]  Time 0.537s (0.708s)    Speed 29.8 samples/s  Data 0.000s (0.148s)    Loss 0.00074 (0.00080)
Epoch: [0][150/916]  Time 0.530s (0.650s)    Speed 30.2 samples/s  Data 0.000s (0.099s)    Loss 0.00075 (0.00079)
Epoch: [0][200/916]  Time 0.531s (0.621s)    Speed 30.1 samples/s  Data 0.001s (0.075s)    Loss 0.00074 (0.00077)
Epoch: [0][250/916]  Time 0.532s (0.603s)    Speed 30.1 samples/s  Data 0.000s (0.060s)    Loss 0.00072 (0.00077)
Epoch: [0][300/916]  Time 0.525s (0.592s)    Speed 30.5 samples/s  Data 0.000s (0.050s)    Loss 0.00073 (0.00076)
Epoch: [0][350/916]  Time 0.541s (0.583s)    Speed 29.6 samples/s  Data 0.000s (0.043s)    Loss 0.00071 (0.00075)
Epoch: [0][400/916]  Time 0.536s (0.577s)    Speed 29.9 samples/s  Data 0.000s (0.038s)    Loss 0.00067 (0.00074)
Epoch: [0][450/916]  Time 0.534s (0.572s)    Speed 30.0 samples/s  Data 0.000s (0.034s)    Loss 0.00057 (0.00073)
Epoch: [0][500/916]  Time 0.534s (0.568s)    Speed 30.0 samples/s  Data 0.000s (0.030s)    Loss 0.00056 (0.00072)
Epoch: [0][550/916]  Time 0.528s (0.565s)    Speed 30.3 samples/s  Data 0.000s (0.027s)    Loss 0.00055 (0.00071)
Epoch: [0][600/916]  Time 0.533s (0.562s)    Speed 30.0 samples/s  Data 0.001s (0.025s)    Loss 0.00053 (0.00069)
Epoch: [0][650/916]  Time 0.528s (0.560s)    Speed 30.3 samples/s  Data 0.000s (0.023s)    Loss 0.00051 (0.00068)
Epoch: [0][700/916]  Time 0.535s (0.558s)    Speed 29.9 samples/s  Data 0.000s (0.022s)    Loss 0.00050 (0.00067)
Epoch: [0][750/916]  Time 0.537s (0.556s)    Speed 29.8 samples/s  Data 0.000s (0.020s)    Loss 0.00053 (0.00066)
Epoch: [0][800/916]  Time 0.532s (0.555s)    Speed 30.1 samples/s  Data 0.000s (0.019s)    Loss 0.00047 (0.00065)
Epoch: [0][850/916]  Time 0.531s (0.554s)    Speed 30.1 samples/s  Data 0.000s (0.018s)    Loss 0.00051 (0.00064)
Epoch: [0][900/916]  Time 0.526s (0.552s)    Speed 30.4 samples/s  Data 0.000s (0.017s)    Loss 0.00054 (0.00063)
Train Epoch 0 time:0.5524 loss:0.0006 nme:0.3472
best: True
Test Epoch 0 time:0.3146 loss:0.0005 nme:0.1605 [008]:0.8482 [010]:0.5162
=> saving checkpoint to output\OWN\own_hrnet_w18
```
Validation
- For validation, you can adapt the script from this issue; I found it quite convenient for testing: https://github.com/HRNet/HRNet-Facial-Landmark-Detection/issues/21
```python
# ------------------------------------------------------------------------------
# Created by Gaofeng([email protected])
# ------------------------------------------------------------------------------
import os
import argparse

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import sys
import cv2

sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
import lib.models as models
from lib.config import config, update_config
from PIL import Image
import numpy as np
from lib.utils.transforms import crop
from lib.core.evaluation import decode_preds


def parse_args():
    parser = argparse.ArgumentParser(description='Train Face Alignment')
    parser.add_argument('--cfg', default='experiments/300w/face_alignment_300w_hrnet_w18.yaml',
                        help='experiment configuration filename', type=str)
    parser.add_argument('--model-file', help='model parameters',
                        default='HR18-300W.pth', type=str)
    parser.add_argument('--imagepath', help='Path of the image to be detected',
                        default='111.jpg', type=str)
    parser.add_argument('--face', nargs='+', type=float, default=[911, 1281, 1254, 1731],
                        help='The coordinate [x1,y1,x2,y2] of a face')
    args = parser.parse_args()
    update_config(config, args)
    return args


def prepare_input(image, bbox, image_size):
    """
    :param image: the path of the image to be detected
    :param bbox: the bbox of the target face
    :param image_size: refers to config file
    :return:
    """
    scale = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) / 200
    center_w = (bbox[0] + bbox[2]) / 2
    center_h = (bbox[1] + bbox[3]) / 2
    center = torch.Tensor([center_w, center_h])
    scale *= 1.25
    img = np.array(Image.open(image).convert('RGB'), dtype=np.float32)
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    img = crop(img, center, scale, image_size, rot=0)
    img = img.astype(np.float32)
    img = (img / 255.0 - mean) / std
    img = img.transpose([2, 0, 1])
    img = torch.Tensor(img)
    img = img.unsqueeze(0)
    return img, center, scale


def main():
    args = parse_args()
    cudnn.benchmark = config.CUDNN.BENCHMARK
    cudnn.deterministic = config.CUDNN.DETERMINISTIC
    cudnn.enabled = config.CUDNN.ENABLED
    config.defrost()
    config.MODEL.INIT_WEIGHTS = False
    config.freeze()
    model = models.get_face_alignment_net(config)
    if isinstance(config.GPUS, (list, tuple)):
        gpus = list(config.GPUS)
    else:
        gpus = [config.GPUS]
    model = nn.DataParallel(model, device_ids=gpus).cuda()
    # load model
    state_dict = torch.load(args.model_file)
    model.load_state_dict(state_dict)
    model.eval()
    inp, center, scale = prepare_input(args.imagepath, args.face, config.MODEL.IMAGE_SIZE)
    output = model(inp)
    score_map = output.data.cpu()
    preds = decode_preds(score_map, center, scale, [64, 64])
    preds = preds.numpy()
    cv2.namedWindow('test', 0)
    img_once = cv2.imread(args.imagepath)
    for i in preds[0, :, :]:
        cv2.circle(img_once, tuple(list(int(p) for p in i.tolist())), 2, (255, 255, 0), 1)
    cv2.imshow('test', img_once)
    if cv2.waitKey(0) == 27:
        cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
```
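Saved as, say, tools/test_own.py (the script name and checkpoint path below are my own illustrative choices), it can be run like this, where --face is the face box [x1, y1, x2, y2]:

```bash
python tools/test_own.py --cfg experiments/own/own_hrnet_w18.yaml \
    --model-file output/OWN/own_hrnet_w18/model_best.pth \
    --imagepath test.jpg --face 911 1281 1254 1731
```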
END
- If you are interested, feel free to use this as a reference. Questions and corrections are welcome.
