Loss function: hand-written DIoU loss implementation
2022-06-30 12:45:00 【计算机视觉-Archer】
Below is the DIoU code on its own:
```python
import torch

# Assumes pred and target are (N, 4) tensors laid out as (cx, cy, w, h)
# and iou is the (N,) IoU tensor computed in IOUloss.forward below.
'''
Squared distance d^2 between the two box centers
'''
# d = sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2); the center is in columns 0 and 1.
# Only d ** 2 is needed, and math.sqrt cannot take a multi-element tensor anyway.
d2 = (pred[:, 0] - target[:, 0]) ** 2 + (pred[:, 1] - target[:, 1]) ** 2
# left x = cx - w / 2
pred_l = pred[:, 0] - pred[:, 2] / 2
target_l = target[:, 0] - target[:, 2] / 2
# top y = cy - h / 2
pred_t = pred[:, 1] - pred[:, 3] / 2
target_t = target[:, 1] - target[:, 3] / 2
# right x = cx + w / 2
pred_r = pred[:, 0] + pred[:, 2] / 2
target_r = target[:, 0] + target[:, 2] / 2
# bottom y = cy + h / 2
pred_b = pred[:, 1] + pred[:, 3] / 2
target_b = target[:, 1] + target[:, 3] / 2
'''
Squared diagonal c^2 of the smallest box enclosing both boxes
'''
bound_l = torch.min(pred_l, target_l)  # left
bound_r = torch.max(pred_r, target_r)  # right
bound_t = torch.min(pred_t, target_t)  # top
bound_b = torch.max(pred_b, target_b)  # bottom
c2 = (bound_r - bound_l) ** 2 + (bound_b - bound_t) ** 2
dloss = iou - d2 / c2.clamp(min=1e-16)
loss = 1 - dloss.clamp(min=-1.0, max=1.0)
```
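Step 1: compute the distance d between the two box centers
First we need to know what pred and target hold.
`pred[:, :2]`: the first `:` selects every box (after `view(-1, 4)` each row is one box), and `:2` takes the first two values, the box center.
`pred[:, 2:]`: the first `:` again selects every box, and `2:` takes the last two values, the box size.
`target[:, :2]` and `target[:, 2:]` work the same way. YOLOX stores boxes as (cx, cy, w, h); every formula here treats the two axes symmetrically, so reading the pairs as (Y, X) and (H, W) instead gives the same result as long as pred and target use the same order.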
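The slicing can be mimicked on plain Python lists. This is a sketch with two made-up boxes, assuming the (cx, cy, w, h) layout; it also shows that a negative index such as `[-1]` picks a size column, not a center coordinate.

```python
# Two made-up boxes flattened into one list, 4 numbers per box: (cx, cy, w, h)
flat = [4.0, 4.0, 6.0, 2.0, 7.0, 8.0, 2.0, 4.0]

# Like pred.view(-1, 4): one row per box
boxes = [flat[i:i + 4] for i in range(0, len(flat), 4)]

centers = [row[:2] for row in boxes]  # like pred[:, :2] -> box centers
sizes = [row[2:] for row in boxes]    # like pred[:, 2:] -> box sizes

# Negative indices count from the end of each row, so they select the size
last = [row[-1] for row in boxes]     # like pred[:, -1] -> h of each box

print(centers)  # [[4.0, 4.0], [7.0, 8.0]]
print(sizes)    # [[6.0, 2.0], [2.0, 4.0]]
print(last)     # [2.0, 4.0]
```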
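d is the Euclidean distance between the pred center $(x_1, y_1)$ and the target center $(x_2, y_2)$:

$$d = \sqrt{(x_2 - x_1)^2 + (y_2 - y_1)^2}$$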
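The center distance d can be computed for two made-up boxes as a quick sanity check. This is a scalar sketch assuming (cx, cy, w, h) tuples; `center_distance` is an illustrative helper name.

```python
import math


def center_distance(box_a, box_b):
    """Distance d between the centers of two (cx, cy, w, h) boxes."""
    xa, ya = box_a[0], box_a[1]  # center of box_a: columns 0 and 1
    xb, yb = box_b[0], box_b[1]
    return math.sqrt((xb - xa) ** 2 + (yb - ya) ** 2)


pred_box = (4.0, 4.0, 6.0, 2.0)    # made-up box: center (4, 4), size 6 x 2
target_box = (7.0, 8.0, 2.0, 4.0)  # made-up box: center (7, 8), size 2 x 4
print(center_distance(pred_box, target_box))  # 5.0
```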
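Using the layout above, compute the left, right, top and bottom coordinates (l, r, t, b) of each box. For a box with center $(x, y)$ and size $(w, h)$, with y growing downward as in image coordinates:

$$l = x - \frac{w}{2}, \quad r = x + \frac{w}{2}, \quad t = y - \frac{h}{2}, \quad b = y + \frac{h}{2}$$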
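The edge conversion is a one-liner per coordinate. This sketch assumes (cx, cy, w, h) tuples and returns the edges in (l, t, r, b) order; `to_ltrb` is an illustrative name.

```python
def to_ltrb(box):
    """Edges (l, t, r, b) of a (cx, cy, w, h) box; y grows downward as in images."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


print(to_ltrb((4.0, 4.0, 6.0, 2.0)))  # (1.0, 3.0, 7.0, 5.0)
print(to_ltrb((7.0, 8.0, 2.0, 4.0)))  # (6.0, 6.0, 8.0, 10.0)
```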
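Then compute the diagonal length c of the smallest rectangle enclosing both boxes, where subscripts 1 and 2 denote the pred and target boxes:

$$c = \sqrt{\left(\max(r_1, r_2) - \min(l_1, l_2)\right)^2 + \left(\max(b_1, b_2) - \min(t_1, t_2)\right)^2}$$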
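The enclosing-box diagonal follows from taking the outermost edges of the two boxes. A scalar sketch under the same (cx, cy, w, h) assumption; `to_ltrb` and `enclosing_diagonal` are illustrative names.

```python
import math


def to_ltrb(box):
    """Edges (l, t, r, b) of a (cx, cy, w, h) box."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def enclosing_diagonal(box_a, box_b):
    """Diagonal length c of the smallest rectangle containing both boxes."""
    la, ta, ra, ba = to_ltrb(box_a)
    lb, tb, rb, bb = to_ltrb(box_b)
    left, top = min(la, lb), min(ta, tb)          # outermost left/top edges
    right, bottom = max(ra, rb), max(ba, bb)      # outermost right/bottom edges
    return math.sqrt((right - left) ** 2 + (bottom - top) ** 2)


c = enclosing_diagonal((4.0, 4.0, 6.0, 2.0), (7.0, 8.0, 2.0, 4.0))
# The enclosing box spans x 1..8 and y 3..10, so c = sqrt(7**2 + 7**2), about 9.90
print(c)
```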
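Here d is the distance between the center of the predicted box and the center of the target box. The DIoU term and the loss are

$$\mathrm{DIoU} = \mathrm{IoU} - \frac{d^2}{c^2}, \qquad \mathcal{L}_{\mathrm{DIoU}} = 1 - \mathrm{DIoU}$$

Now consider the extreme cases:
- When the two box centers coincide, $d^2/c^2 = 0$ and the IoU can be anywhere in $(0, 1]$.
- When the two boxes are very far apart, $d^2/c^2$ approaches 1 and the IoU is 0.

Both centers lie inside the enclosing box, so $d < c$ and $d^2/c^2$ lies in $[0, 1)$. Therefore $dloss = IoU - d^2/c^2$ lies in $(-1, 1]$, and $loss = 1 - dloss$ lies in $[0, 2)$.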
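To see the range argument numerically, the sketch below slides one made-up 2 x 2 box away from an identical box at the origin and prints $d^2/c^2$. It is a standalone scalar helper (the name `distance_ratio` is illustrative), not code from the class.

```python
def distance_ratio(box_a, box_b):
    """d^2 / c^2 for two (cx, cy, w, h) boxes."""
    ax, ay, aw, ah = box_a
    bx, by, bw, bh = box_b
    d2 = (bx - ax) ** 2 + (by - ay) ** 2  # squared center distance
    left = min(ax - aw / 2, bx - bw / 2)
    right = max(ax + aw / 2, bx + bw / 2)
    top = min(ay - ah / 2, by - bh / 2)
    bottom = max(ay + ah / 2, by + bh / 2)
    c2 = (right - left) ** 2 + (bottom - top) ** 2  # squared enclosing diagonal
    return d2 / c2


# Move a 2 x 2 box further and further from an identical box at the origin
gaps = [0.0, 10.0, 100.0, 1000.0]
ratios = [distance_ratio((0.0, 0.0, 2.0, 2.0), (g, 0.0, 2.0, 2.0)) for g in gaps]
for g, r in zip(gaps, ratios):
    print(g, round(r, 4))
```

The ratio climbs from 0 toward 1 (roughly 0.68, 0.96 and 0.996 for gaps of 10, 100 and 1000) without reaching it. Once the boxes stop overlapping the IoU is 0, so the loss $1 + d^2/c^2$ stays below 2.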
```python
import torch
import torch.nn as nn


class IOUloss(nn.Module):
    def __init__(self, reduction="none", loss_type="iou"):
        super(IOUloss, self).__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    def forward(self, pred, target):
        assert pred.shape[0] == target.shape[0]

        pred = pred.view(-1, 4)
        target = target.view(-1, 4)
        # pred and target are both (cx, cy, w, h): [:, :2] is the center, [:, 2:] the size
        # center - size / 2 is a box's top-left corner; the larger one is the intersection's
        tl = torch.max(
            (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
        )
        # center + size / 2 is a box's bottom-right corner; the smaller one is the intersection's
        br = torch.min(
            (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
        )

        area_p = torch.prod(pred[:, 2:], 1)  # w * h
        area_g = torch.prod(target[:, 2:], 1)

        # en is 1 when the boxes overlap on both axes, else 0
        en = (tl < br).type(tl.type()).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        area_u = area_p + area_g - area_i
        iou = (area_i) / (area_u + 1e-16)

        if self.loss_type == "iou":
            loss = 1 - iou ** 2
        elif self.loss_type == "giou":
            c_tl = torch.min(
                (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
            )
            c_br = torch.max(
                (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
            )
            area_c = torch.prod(c_br - c_tl, 1)
            giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
            loss = 1 - giou.clamp(min=-1.0, max=1.0)
        # pred[:, :2]   pred[:, 2:]
        # (cx, cy)      (w, h)
        # target[:, :2] target[:, 2:]
        # (cx, cy)      (w, h)
        elif self.loss_type == "diou":
            '''
            Squared distance d^2 between the two box centers
            '''
            # d = sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2); the center is in columns 0 and 1.
            # Only d ** 2 is needed, and math.sqrt cannot take a multi-element tensor.
            d2 = (pred[:, 0] - target[:, 0]) ** 2 + (pred[:, 1] - target[:, 1]) ** 2
            # left x = cx - w / 2
            pred_l = pred[:, 0] - pred[:, 2] / 2
            target_l = target[:, 0] - target[:, 2] / 2
            # top y = cy - h / 2
            pred_t = pred[:, 1] - pred[:, 3] / 2
            target_t = target[:, 1] - target[:, 3] / 2
            # right x = cx + w / 2
            pred_r = pred[:, 0] + pred[:, 2] / 2
            target_r = target[:, 0] + target[:, 2] / 2
            # bottom y = cy + h / 2
            pred_b = pred[:, 1] + pred[:, 3] / 2
            target_b = target[:, 1] + target[:, 3] / 2
            '''
            Squared diagonal c^2 of the smallest box enclosing both boxes
            '''
            bound_l = torch.min(pred_l, target_l)  # left
            bound_r = torch.max(pred_r, target_r)  # right
            bound_t = torch.min(pred_t, target_t)  # top
            bound_b = torch.max(pred_b, target_b)  # bottom
            c2 = (bound_r - bound_l) ** 2 + (bound_b - bound_t) ** 2
            dloss = iou - d2 / c2.clamp(min=1e-16)
            loss = 1 - dloss.clamp(min=-1.0, max=1.0)

        # Step-by-step pseudocode for DIoU (comments only, not executed):
        # Step1
        # def DIoU(a, b):
        #     d = a.center_distance(b)
        #     c = a.bound_diagonal_distance(b)
        #     return IoU(a, b) - (d ** 2) / (c ** 2)
        # Step2-1
        # def center_distance(self, other):
        #     '''
        #     Distance between the centers of two boxes
        #     '''
        #     return euclidean_distance(self.center, other.center)
        # Step2-2
        # def euclidean_distance(p1, p2):
        #     '''
        #     Euclidean distance between two points
        #     '''
        #     x1, y1 = p1
        #     x2, y2 = p2
        #     return math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
        # Step3
        # def bound_diagonal_distance(self, other):
        #     '''
        #     Diagonal length of the box enclosing both boxes
        #     '''
        #     bound = self.boundof(other)
        #     return euclidean_distance((bound.x, bound.y), (bound.r, bound.b))
        # Step3-2
        # def boundof(self, other):
        #     '''
        #     Smallest rectangle that contains both self and other
        #     '''
        #     xmin = min(self.x, other.x)
        #     ymin = min(self.y, other.y)
        #     xmax = max(self.r, other.r)
        #     ymax = max(self.b, other.b)
        #     return BBox(xmin, ymin, xmax, ymax)
        # Step3-3: euclidean_distance, same as Step2-2

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss
```
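The commented Step1 to Step3 pseudocode calls a `BBox` type and an `IoU` function that the post never defines. As a sketch filling those gaps, the block below makes the pseudocode runnable on single boxes. `BBox` (fields `x`, `y` for the left/top corner and `r`, `b` for the right/bottom corner, as `boundof` implies), its `center` and `area` properties, and `IoU` are hypothetical helpers written for illustration, not YOLOX code.

```python
import math
from dataclasses import dataclass


def euclidean_distance(p1, p2):
    """Euclidean distance between two points (Step2-2)."""
    x1, y1 = p1
    x2, y2 = p2
    return math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)


@dataclass
class BBox:
    """Hypothetical box type: (x, y) is the left/top corner, (r, b) the right/bottom."""
    x: float
    y: float
    r: float
    b: float

    @property
    def center(self):
        return ((self.x + self.r) / 2, (self.y + self.b) / 2)

    @property
    def area(self):
        return (self.r - self.x) * (self.b - self.y)

    def center_distance(self, other):  # Step2-1
        return euclidean_distance(self.center, other.center)

    def boundof(self, other):  # Step3-2: smallest box containing both
        return BBox(min(self.x, other.x), min(self.y, other.y),
                    max(self.r, other.r), max(self.b, other.b))

    def bound_diagonal_distance(self, other):  # Step3
        bound = self.boundof(other)
        return euclidean_distance((bound.x, bound.y), (bound.r, bound.b))


def IoU(a, b):
    """Hypothetical helper: intersection over union, 0 when the boxes do not overlap."""
    inter_w = max(0.0, min(a.r, b.r) - max(a.x, b.x))
    inter_h = max(0.0, min(a.b, b.b) - max(a.y, b.y))
    inter = inter_w * inter_h
    return inter / (a.area + b.area - inter)


def DIoU(a, b):  # Step1
    d = a.center_distance(b)
    c = a.bound_diagonal_distance(b)
    return IoU(a, b) - (d ** 2) / (c ** 2)


a = BBox(-2.0, -2.0, 2.0, 2.0)  # 4 x 4 box centered at the origin
b = BBox(-1.0, -1.0, 1.0, 1.0)  # 2 x 2 box with the same center
print(1 - DIoU(a, a))  # 0.0: identical boxes
print(1 - DIoU(a, b))  # 0.75: IoU = 4 / 16, and d = 0 so there is no distance penalty
```

Concentric boxes are penalized only through the IoU term; the distance term starts to matter once the centers separate.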
Source of the `IOUloss` class (YOLOX): https://github.com/Megvii-BaseDetection/YOLOX