当前位置:网站首页>NMS原理及其代码实现
NMS原理及其代码实现
2022-08-05 00:15:00 【Le0v1n】
1. 为什么要用NMS
YOLOv3在预测阶段, 每个目标至少会生成3个proposals, 但一个目标一般只显示一个proposal, 因此需要对proposals进行去重,
这里去重的方法是NMS. 而NMS的筛选依据是IoU.
2. NMS的步骤
先对所有proposals进行置信度(confidence)的排序, 按照置信度的大小进行降序排序(从大到小排序).
将最大置信度的proposal ( p r o p o s a l m a x {\rm proposal_{max}} proposalmax)取出来, 与剩下的proposals( p r o p o s a l s r e s t {\rm proposals_{rest}} proposalsrest)进行IoU的计算, 这里的目的是筛选后面的proposals, 意思是说:
如果 p r o p o s a l s r e s t {\rm proposals_{rest}} proposalsrest中的某一个proposal与 p r o p o s a l m a x {\rm proposal_{max}} proposalmax之间计算得到的IoU的值小于设定的阈值( t h r e s h {\rm thresh} thresh), 那么就认为这个proposal是可用的(可保留的)
一旦有IoU > t h r e s h {\rm thresh} thresh的proposal, 我们则认为该proposal和 p r o p o s a l m a x {\rm proposal_{max}} proposalmax预测的是同一个目标(object), 因此该框就冗余了(因为它的置信度没有 p r o p o s a l m a x {\rm proposal_{max}} proposalmax的置信度高), 需要去除.
置信度最大的proposal是一定会保留的, 我们是在挑选剩余的proposals
p r o p o s a l m a x {\rm proposal_{max}} proposalmax与所有 p r o p o s a l s r e s t {\rm proposals_{rest}} proposalsrest的IoU计算和筛选完毕后, 置信度指针指向下一个(第二高置信度的proposal)
重复2, 3 -> 递归
NMS的核心主要事项就是: IoU > 阈值 的proposals需要去除, < 阈值的proposals则保留, 目的是不影响预测其他object的proposals受到影响.
3. NMS的代码实现
import torch
def iou(proposal, proposals, isMin=False):
"""计算proposals的IoU 在计算IoU时, 需要求二者交集和并集. 假设两个框的坐标分别为: (x_11, y_11, x_12, y_12)和(x_21, y2_1, x_22, y_22) 交集框的坐标: (max(x_11, x_21), max(y_11, y_21), min(x_12, x_22), min(y_12, y_22)) Args: proposal (_type_): 置信度最高的proposal -> [4] proposals (_type_): 剩余的proposals -> [N, 4] isMin (bool, optional): IoU的计算模式, 有两种: 1. (True) 交集 / 最小面积 2. (False -> Default) 交集 / 并集 Return: IoU (float): 返回proposal与proposals的IoU """
# 计算当前框的面积: proposal = [x, y, w, h]
box_area = (proposal[2] - proposal[0]) * (proposal[3] - proposal[1])
# 计算proposals中所有框的面积 proposals = [N, [x, y, w, h]]
boxes_area = (proposals[:, 2] - proposals[:, 0]) * (proposals[:, 3] - proposals[:, 1])
# 计算交集proposal和proposals的计算
xx_1 = torch.maximum(proposal[0], proposals[:, 0]) # 交集的左上角x坐标
yy_1 = torch.maximum(proposal[1], proposals[:, 1]) # 交集的左上角y坐标
xx_2 = torch.minimum(proposal[2], proposals[:, 2]) # 交集的右下角x坐标
yy_2 = torch.minimum(proposal[3], proposals[:, 3]) # 交集的右下角y坐标
# 特殊情况: 两个框没有挨着 -> 没有交集
w, h = torch.maximum(torch.Tensor([0]), xx_2 - xx_1), torch.maximum(torch.Tensor([0]), yy_2 - yy_1)
# 获取交集的框的面积
intersection_area = w * h
if isMin: # 如果一个框在另一框的内部
return intersection_area / torch.min(box_area, boxes_area)
else: # 两个框相交 -> 交集 / 并集
return intersection_area / (box_area + boxes_area - intersection_area)
def nms(proposals, thresh=0.3, isMin=False):
"""非极大值抑制用来去除冗余的proposals Args: proposals (torch.tensor): 网络推理得到的proposals -> [conf, x, y, w, h] thresh (float, optional): NMS筛选的阈值. Defaults to 0.3. isMin (bool, optional): IoU的计算方式, 默认为交集/并集. Defaults to False. """
# 根据proposals的置信度进行降序排序
sorted_proposals = proposals[proposals[:, 0].argsort(descending=True)]
# 定义一个ls, 用来保存需要保留的proposals
keep_boxes = []
while len(sorted_proposals) > 0:
# 取出置信度最高的proposal并存放到ls中
_box = sorted_proposals[0]
keep_boxes.append(_box)
if len(sorted_proposals) > 1:
# 取出剩余的proposals
_boxes = sorted_proposals[1:]
# 置信度最高的proposal与其他proposals进行IoU的计算
""" 需要注意的是, NMS在筛选的时候是保留IoU小于thresh的. 为什么? 两个proposal的IoU越小, 说明两个proposal框起来的对象越不一样, 别忘了, NMS是为了去重, 所以需要保留小于IoU的proposals torch.where(条件): 返回符合条件的索引 """
sorted_proposals = _boxes[torch.where(iou(_box, _boxes, isMin) < thresh)]
# 当剩下最后一个时候, 就不进行IoU计算了(自己与自己计算IoU没有意义)
else:
break
# 将ls转换为高维的tensor
return torch.stack(keep_boxes)
if __name__ == "__main__":
proposal = torch.tensor(data=[0, 0, 4, 4])
proposals = torch.tensor(data=[[4, 4, 5, 5], # 没有交集
[1, 1, 5, 5]]) # 有交集
print(iou(proposal, proposals)) # tensor([0.0000, 0.3913])
boxes = torch.tensor(data=[
[0.5, 1, 1, 10, 10],
[0.9, 1, 1, 11, 11], # 和上面那个很相似
[0.4, 8, 8, 12, 12] # 和上面两个都不相似
])
print(nms(boxes, thresh=0.1))
"""仅保留了2个 tensor([[ 0.9000, 1.0000, 1.0000, 11.0000, 11.0000], [ 0.4000, 8.0000, 8.0000, 12.0000, 12.0000]]) """
print(nms(boxes, thresh=0.3))
"""全部都保留了 tensor([[ 0.9000, 1.0000, 1.0000, 11.0000, 11.0000], [ 0.5000, 1.0000, 1.0000, 10.0000, 10.0000], [ 0.4000, 8.0000, 8.0000, 12.0000, 12.0000]]) """
```
边栏推荐
- 中日颜色风格
- Mysql_13 事务
- 看图识字,DELL SC4020 / SCv2000 控制器更换过程
- 刘润直播预告 | 顶级高手,如何创造财富
- RK3399平台开发系列讲解(内核调试篇)2.50、嵌入式产品启动速度优化
- 资深游戏建模师告知新手,游戏场景建模师必备软件有哪些?
- 3. Actual combat---crawl the result page corresponding to Baidu's specified entry (a simple page collector)
- 【七夕情人节特效】-- canvas实现满屏爱心
- KT148A语音芯片ic工作原理以及芯片的内部架构描述
- ansible学习笔记分享-含剧本示例
猜你喜欢
what?测试/开发程序员要被淘汰了?年龄40被砍到了32?一瞬间,有点缓不过神来......
【LeetCode】滑动窗口题解汇总
[Happy Qixi Festival] How does Nacos realize the service registration function?
2022 Niu Ke Summer Multi-School Training Camp 5 (BCDFGHK)
Mysql_14 存储引擎
资深游戏建模师告知新手,游戏场景建模师必备软件有哪些?
入门3D游戏建模师知识必备
[CVA Valuation Training Camp] Financial Modeling Guide - Lecture 1
jenkins send mail system configuration
Huggingface入门篇 II (QA)
随机推荐
KT148A语音芯片ic工作原理以及芯片的内部架构描述
对写作的一些感悟
入门3D游戏建模师知识必备
一、爬虫基本概念
TinyMCE禁用转义
阅读笔记:如何理解DevOps?
SQL关联表更新
Statistical words (DAY 101) Huazhong University of Science and Technology postgraduate examination questions
情侣牵手[贪心 & 抽象]
typeScript - Partially apply a function
MAUI Blazor 权限经验分享 (定位,使用相机)
关于使用read table 语句
mysql基础
Huggingface入门篇 II (QA)
日志(logging模块)
导入JankStats检测卡帧库遇到问题记录
数据类型-整型(C语言)
网站最终产品页使用单一入口还是多入口?
统计单词(DAY 101)华中科技大学考研机试题
E - Many Operations (按位考虑 + dp思想记录操作后的结果