当前位置:网站首页>YOLOv7-Pose尝鲜,基于YOLOv7的关键点模型测评
YOLOv7-Pose尝鲜,基于YOLOv7的关键点模型测评
2022-08-04 18:01:00 【pogg_】
【前言】
本文首发于GiantPandaCV,未经许可请勿转载!目前人体姿态估计总体分为Top-down和Bottom-up两种,与目标检测不同,无论是基于热力图或是基于检测器处理的关键点检测算法,都较为依赖计算资源,推理耗时略长,今年出现了以YOLO为基线的关键点检测器。玩过目标检测的童鞋都知道YOLO以及各种变种目前算是工业落地较多的一类检测器,其简单的设计思想,长期活跃的社区生态,使其始终占据着较高的话题度。
【演变】
在ECCV 2022和CVPRW 2022会议上,YoLo-Pose和KaPao(下称为yolo-like-pose)都基于流行的YOLO目标检测框架提出一种新颖的无热力图的方法,类似于很久以前谷歌使用回归计算关键点的思想,yolo-like-pose一不使用检测器进行二阶处理,二部使用热力图拼接,虽然是一种暴力回归关键点的检测算法,但在处理速度上具有一定优势。
kapao
去年11月,滑铁卢大学率先提出了 KaPao:Rethinking Keypoint Representations: Modeling Keypoints and Poses as Objects for Multi-Person Human Pose Estimation,基于YOLOv5进行关键点检测,该文章目前已被ECCV 2022接收,该算法所取得的性能如下:
paper:https://arxiv.org/abs/2111.08557
code:https://github.com/wmcnally/kapao
yolov5-pose
今年4月,yolo-pose也挂在了arvix,在论文中,通过调研发现 HeatMap 的方式普遍使用L1 Loss。然而,L1损失并不一定适合获得最佳的OKS。且由于HeatMap是概率图,因此在基于纯HeatMap的方法中不可能使用OKS作为loss,只有当回归到关键点位置时,OKS才能被用作损失函数。
因此,yolo-pose使用oks loss作为关键点的损失
相关代码在https://github.com/TexasInstruments/edgeai-yolov5/blob/yolo-pose/utils/loss.py也可见到:
if self.kpt_label:
#Direct kpt prediction
pkpt_x = ps[:, 6::3] * 2. - 0.5
pkpt_y = ps[:, 7::3] * 2. - 0.5
pkpt_score = ps[:, 8::3]
#mask
kpt_mask = (tkpt[i][:, 0::2] != 0)
lkptv += self.BCEcls(pkpt_score, kpt_mask.float())
#l2 distance based loss
#lkpt += (((pkpt-tkpt[i])*kpt_mask)**2).mean() #Try to make this loss based on distance instead of ordinary difference
#oks based loss
d = (pkpt_x-tkpt[i][:,0::2])**2 + (pkpt_y-tkpt[i][:,1::2])**2
s = torch.prod(tbox[i][:,-2:], dim=1, keepdim=True)
kpt_loss_factor = (torch.sum(kpt_mask != 0) + torch.sum(kpt_mask == 0))/torch.sum(kpt_mask != 0)
lkpt += kpt_loss_factor*((1 - torch.exp(-d/(s*(4*sigmas**2)+1e-9)))*kpt_mask).mean()
相关性能如下:
yolov7-pose
上个星期,YOLOv7的作者也放出了关于人体关键点检测的模型,该模型基于YOLOv7-w6,
目前作者提供了.pt文件和推理测试的脚本,有兴趣的童靴可以去看看,本文的重点更偏向于对yolov7-pose.pt进行onnx文件的抽取和推理。
【yolov7-pose + onnxruntime】
首先下载好官方的预训练模型,使用提供的脚本进行推理:
% weigths = torch.load('weights/yolov7-w6-pose.pt')
% image = cv2.imread('sample/pose.jpeg')
!python pose.py
一、yolov7-w6 VS yolov7-w6-pose:
- 首先看下yolov7-w6使用的检测头
- f f f 表示一共有四组不同尺度的检测头,分别为15×15,30×30,60×60,120×120,对应输出的节点为114,115,116,117
- nc对应coco的80个类别
- no表示 c l a s s . n u m + o b j + r e g = 80 + 1 + 4 = 85 class_.num+obj+reg = 80+1+4=85 class.num+obj+reg=80+1+4=85
- 再看看yolov7-w6-pose使用的检测头:
上述重复的地方不累述,讲几个点:
- n c = 1 nc=1 nc=1 代表person一个类别
- nkpt表示人体的17个关键点
- n o = 17 ∗ 3 = n k p t ∗ ( x + y + o b j ) = 57 no=17*3=nkpt*(x+y+obj)=57 no=17∗3=nkpt∗(x+y+obj)=57
二、修改export脚本
如果直接使用export脚本进行onnx的抽取一定报错,在上一节我们已经看到pose.pt模型使用的检测头为IKeypoint,那么脚本需要进行相应更改:
在export.py的这个位置插入:
# 原代码:
for k, m in model.named_modules():
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
if isinstance(m, models.common.Conv): # assign export-friendly activations
if isinstance(m.act, nn.Hardswish):
m.act = Hardswish()
elif isinstance(m.act, nn.SiLU):
m.act = SiLU()
model.model[-1].export = not opt.grid # set Detect() layer grid export
# 修改代码:
for k, m in model.named_modules():
m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
if isinstance(m, models.common.Conv): # assign export-friendly activations
if isinstance(m.act, nn.Hardswish):
m.act = Hardswish()
elif isinstance(m.act, nn.SiLU):
m.act = SiLU()
elif isinstance(m, models.yolo.IKeypoint):
m.forward = m.forward_keypoint # assign forward (optional)
# 此处切换检测头
model.model[-1].export = not opt.grid # set Detect() layer grid export
forward_keypoint在原始的yolov7 repo源码中有,作者已经封装好,但估计是还没打算开放使用。
使用以下命令进行抽取:
python export.py --weights 'weights/yolov7-w6-pose.pt' --img-size 960 --simplify True
抽取后的onnx检测头:
三、onnxruntime推理
onnxruntime推理代码:
import onnxruntime
import matplotlib.pyplot as plt
import torch
import cv2
from torchvision import transforms
import numpy as np
from utils.datasets import letterbox
from utils.general import non_max_suppression_kpt
from utils.plots import output_to_keypoint, plot_skeleton_kpts
device = torch.device("cpu")
image = cv2.imread('sample/pose.jpeg')
image = letterbox(image, 960, stride=64, auto=True)[0]
image_ = image.copy()
image = transforms.ToTensor()(image)
image = torch.tensor(np.array([image.numpy()]))
print(image.shape)
sess = onnxruntime.InferenceSession('weights/yolov7-w6-pose.onnx')
out = sess.run(['output'], {
'images': image.numpy()})[0]
out = torch.from_numpy(out)
output = non_max_suppression_kpt(out, 0.25, 0.65, nc=1, nkpt=17, kpt_label=True)
output = output_to_keypoint(output)
nimg = image[0].permute(1, 2, 0) * 255
nimg = nimg.cpu().numpy().astype(np.uint8)
nimg = cv2.cvtColor(nimg, cv2.COLOR_RGB2BGR)
for idx in range(output.shape[0]):
plot_skeleton_kpts(nimg, output[idx, 7:].T, 3)
# matplotlib inline
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(nimg)
plt.show()
plt.savefig("tmp")
推理效果几乎无损,但耗时会缩短一倍左右,另外有几个点:
- image = letterbox(image, 960, stride=64, auto=True)[0] 中stride指的是最大步长,yolov7-w6和yolov5s下采样多了一步,导致在8,16,32的基础上多了64的下采样步长
- output = non_max_suppression_kpt(out, 0.25, 0.65, nc=1, nkpt=17, kpt_label=True) ,nc 和 kpt_label 等信息在netron打印模型文件时可以看到
- 所得到的onnx相比原半精度模型大了将近三倍,后续排查原因
- yolov7-w6-pose极度吃显存,推理一张960×960的图像,需要2-4G的显存,训练更难以想象
边栏推荐
- Cholesterol-PEG-DBCO,CLS-PEG-DBCO,胆固醇-聚乙二醇-二苯基环辛炔科研试剂
- 【软件工程之美 - 专栏笔记】37 | 遇到线上故障,你和高手的差距在哪里?
- 谷歌开源芯片 180 纳米制造工艺
- 谁能解答?从mysql的binlog读取数据到kafka,但是数据类型有Insert,updata,
- 《机器学习理论到应用》电子书免费下载
- About the two architectures of ETL (ETL architecture and ELT architecture)
- Go 言 Go 语,一文看懂 Go 语言文件操作
- 全球电子产品需求放缓:三星越南工厂大幅压缩产能
- 对象实例化之后一定会存放在堆内存中?
- ptables基本语法使用规则
猜你喜欢
CAS:474922-26-4,DSPE-PEG-NH2,DSPE-PEG-amine,磷脂-聚乙二醇-氨基供应
2022 May 1 Mathematical Modeling Question C Explanation
巴比特 | 元宇宙每日必读:微博动漫将招募全球各类虚拟偶像并为其提供扶持...
Matlab drawing 1
LVS+Keepalived群集
对象实例化之后一定会存放在堆内存中?
开发那些事儿:如何通过EasyCVR平台获取监控现场的人流量统计数据?
基于激励的需求响应计划下弹性微电网的短期可靠性和经济性评估(Matlab代码实现)
Cholesterol-PEG-DBCO,CLS-PEG-DBCO,胆固醇-聚乙二醇-二苯基环辛炔科研试剂
防火墙基础之防火墙做出口设备安全防护
随机推荐
华为云计算HCIE之oceanstor仿真器的安装教程
嵌入式开发:使用堆栈保护提高代码完整性
asp dotnet core 通过图片统计 csdn 用户访问
树莓派连接蓝牙音箱
EasyCVR calls the cloud recording API and returns an error and no recording file is generated. What is the reason?
Google Earth Engine APP - one-click online viewing of global images from 1984 to this year and loading an image analysis at the same time
公司自用的国产API管理神器
区间贪心(区间合并)
路由懒加载
darknet source code reading notes-02-list.h and lish.c
Route lazy loading
【日记】nodejs构建API框架以及RESTful API 和 JSON-RPC的取舍
离散化求前缀和
小程序经典案例
Understanding of margin collapse and coincidence
Thrift IDL Sample File
Documentary on Security Reinforcement of Network Range Monitoring System (1)—SSL/TLS Encrypted Transmission of Log Data
Matlab画图1
golang安装和基础配置
clickhouse online and offline table