Loss function: a hand-written implementation of DIoU loss
2022-06-30 12:45:00 【计算机视觉-Archer】
Below is the DIoU-only code:
import torch

# Assumes pred and target are [N, 4] tensors in (cx, cy, w, h) layout and
# iou is the [N] tensor of IoU values computed beforehand (as in IOUloss.forward).
# Squared distance between the two box centers: d^2 = (x2 - x1)^2 + (y2 - y1)^2
# (math.sqrt only accepts Python scalars, so everything stays in tensor ops)
d2 = (pred[:, 0] - target[:, 0]) ** 2 + (pred[:, 1] - target[:, 1]) ** 2
# left x = cx - w / 2
pred_l = pred[:, 0] - pred[:, 2] / 2
target_l = target[:, 0] - target[:, 2] / 2
# top y = cy - h / 2
pred_t = pred[:, 1] - pred[:, 3] / 2
target_t = target[:, 1] - target[:, 3] / 2
# right x = cx + w / 2
pred_r = pred[:, 0] + pred[:, 2] / 2
target_r = target[:, 0] + target[:, 2] / 2
# bottom y = cy + h / 2
pred_b = pred[:, 1] + pred[:, 3] / 2
target_b = target[:, 1] + target[:, 3] / 2
# Squared diagonal of the smallest box enclosing both boxes
bound_l = torch.min(pred_l, target_l)  # left
bound_r = torch.max(pred_r, target_r)  # right
bound_t = torch.min(pred_t, target_t)  # top
bound_b = torch.max(pred_b, target_b)  # bottom
c2 = (bound_r - bound_l) ** 2 + (bound_b - bound_t) ** 2
dloss = iou - d2 / c2.clamp(min=1e-16)
loss = 1 - dloss.clamp(min=-1.0, max=1.0)
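For reference, the quantity computed above can be written with the same symbols this post uses: d is the distance between the two box centers and c is the diagonal of the smallest box enclosing both. The clamp to [-1, 1] in the code is left out of the formula.

```latex
\mathrm{DIoU} = \mathrm{IoU} - \frac{d^2}{c^2}, \qquad
\mathcal{L}_{\mathrm{DIoU}} = 1 - \mathrm{DIoU} = 1 - \mathrm{IoU} + \frac{d^2}{c^2}
```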
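Step 1: compute the distance d between the two box centers
First, we need to know what pred and target hold.
pred[:, :2]: the first `:` selects every row, i.e. every box (the boxes from all images in the batch are flattened into rows); the second part, `:2`, takes the first two values, which are the box center.
pred[:, 2:]: again every row; `2:` takes the last two values, which are the box size. The size uses the same axis order as the center, so column 0 pairs with column 2 and column 1 with column 3; in YOLOX the layout is (cx, cy, w, h).
target[:, :2] and target[:, 2:] follow the same layout.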
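$d = \sqrt{(x_p - x_t)^2 + (y_p - y_t)^2}$, where $(x_p, y_p)$ and $(x_t, y_t)$ are the centers of the predicted and target boxes.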
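A minimal sketch of step 1 for a single pair of boxes, written in plain Python instead of PyTorch so it runs without dependencies. It assumes each box is a tuple (cx, cy, w, h), the same column order as above; `center_distance` is a hypothetical helper name, not part of the original code.

```python
import math


def center_distance(pred, target):
    """Hypothetical helper: distance d between the centers of two (cx, cy, w, h) boxes."""
    px, py = pred[0], pred[1]      # center of the predicted box
    tx, ty = target[0], target[1]  # center of the target box
    return math.hypot(px - tx, py - ty)


print(center_distance((3.0, 4.0, 2.0, 2.0), (0.0, 0.0, 2.0, 2.0)))  # 5.0
```

Only the first two entries of each box matter here; the width and height come into play when computing the edges and the enclosing box.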

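Using this layout, compute the left, top, right and bottom edges (l, t, r, b) of each box: l = cx - w/2, r = cx + w/2, t = cy - h/2, b = cy + h/2.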
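The edge computation as a small sketch, again on plain tuples. It assumes image coordinates where y grows downward, so "top" is the smaller y value; `to_ltrb` is a hypothetical name chosen for illustration.

```python
def to_ltrb(box):
    """Hypothetical helper: convert a (cx, cy, w, h) box to (left, top, right, bottom)."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


print(to_ltrb((3.0, 4.0, 2.0, 6.0)))  # (2.0, 1.0, 4.0, 7.0)
```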

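Then compute c, the diagonal length of the smallest rectangle that encloses both boxes: its left and top edges are the minima of the two boxes' left and top edges, and its right and bottom edges are the maxima.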
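A sketch of the enclosing-box diagonal for one pair of (cx, cy, w, h) boxes. The hypothetical `to_ltrb` helper is repeated so the block runs on its own; `enclosing_diagonal` is likewise an illustrative name.

```python
import math


def to_ltrb(box):
    """Hypothetical helper: (cx, cy, w, h) -> (left, top, right, bottom)."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def enclosing_diagonal(pred, target):
    """Diagonal c of the smallest rectangle containing both boxes."""
    pl, pt, pr, pb = to_ltrb(pred)
    tl, tt, tr, tb = to_ltrb(target)
    left, top = min(pl, tl), min(pt, tt)          # min of left/top edges
    right, bottom = max(pr, tr), max(pb, tb)      # max of right/bottom edges
    return math.hypot(right - left, bottom - top)


# Boxes (0, 0, 1, 1) and (2, 3, 3, 4) in ltrb form; the enclosing box is 3 x 4
print(enclosing_diagonal((0.5, 0.5, 1.0, 1.0), (2.5, 3.5, 1.0, 1.0)))  # 5.0
```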

d is the distance between the centers of the predicted box and the target box.
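Now look at the extreme cases:
A. When the two box centers coincide, d/c = 0 and the IoU can be anything in (0, 1].
B. When the two boxes are very far apart, d/c approaches 1 and IoU = 0.
So d/c lies in [0, 1): both centers lie inside the enclosing rectangle, so the distance between them can never reach its diagonal length c.
Hence dloss = IoU - d²/c² lies in (-1, 1],
and setting loss = 1 - dloss gives a loss in [0, 2).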
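The extreme cases can be checked numerically with a per-pair sketch of the whole loss. This is plain Python on (cx, cy, w, h) tuples with positive widths and heights; `diou_loss` and `to_ltrb` are hypothetical names, and unlike the tensor code there is no clamp, so the raw range is visible.

```python
def to_ltrb(box):
    """Hypothetical helper: (cx, cy, w, h) -> (left, top, right, bottom)."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def diou_loss(pred, target):
    """1 - DIoU for two (cx, cy, w, h) boxes, without clamping."""
    pl, pt, pr, pb = to_ltrb(pred)
    tl, tt, tr, tb = to_ltrb(target)
    # Intersection size, zero when the boxes do not overlap
    iw = max(0.0, min(pr, tr) - max(pl, tl))
    ih = max(0.0, min(pb, tb) - max(pt, tt))
    inter = iw * ih
    union = pred[2] * pred[3] + target[2] * target[3] - inter
    iou = inter / union
    # Squared center distance and squared enclosing-box diagonal
    d2 = (pred[0] - target[0]) ** 2 + (pred[1] - target[1]) ** 2
    c2 = (max(pr, tr) - min(pl, tl)) ** 2 + (max(pb, tb) - min(pt, tt)) ** 2
    return 1 - (iou - d2 / c2)


# Case A: same center, IoU = 4 / 16, so the loss is just 1 - IoU
print(diou_loss((0.0, 0.0, 2.0, 2.0), (0.0, 0.0, 4.0, 4.0)))  # 0.75
# Case B: far apart, IoU = 0 and d^2 / c^2 is close to 1, so the loss is just under 2
print(diou_loss((0.0, 0.0, 2.0, 2.0), (1000.0, 0.0, 2.0, 2.0)))
```

Identical boxes give a loss of exactly 0, which is the lower end of the [0, 2) range.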
import torch
import torch.nn as nn


class IOUloss(nn.Module):
    def __init__(self, reduction="none", loss_type="iou"):
        super(IOUloss, self).__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    def forward(self, pred, target):
        assert pred.shape[0] == target.shape[0]

        # pred and target are both [N, 4]: (cx, cy, w, h), center first, then size
        pred = pred.view(-1, 4)
        target = target.view(-1, 4)
        # Top-left corner of the intersection: center - size / 2
        tl = torch.max(
            (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
        )
        # Bottom-right corner of the intersection: center + size / 2
        br = torch.min(
            (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
        )

        area_p = torch.prod(pred[:, 2:], 1)  # w * h
        area_g = torch.prod(target[:, 2:], 1)

        # en is 1 only when the boxes overlap along both axes
        en = (tl < br).type(tl.type()).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        area_u = area_p + area_g - area_i
        iou = (area_i) / (area_u + 1e-16)

        if self.loss_type == "iou":
            loss = 1 - iou ** 2
        elif self.loss_type == "giou":
            # Smallest box enclosing both boxes
            c_tl = torch.min(
                (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
            )
            c_br = torch.max(
                (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
            )
            area_c = torch.prod(c_br - c_tl, 1)
            giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
            loss = 1 - giou.clamp(min=-1.0, max=1.0)
        elif self.loss_type == "diou":
            # Squared distance between the two box centers: d^2 = (x2 - x1)^2 + (y2 - y1)^2
            # (math.sqrt only accepts Python scalars, so use tensor ops)
            d2 = (pred[:, 0] - target[:, 0]) ** 2 + (pred[:, 1] - target[:, 1]) ** 2
            # left x = cx - w / 2
            pred_l = pred[:, 0] - pred[:, 2] / 2
            target_l = target[:, 0] - target[:, 2] / 2
            # top y = cy - h / 2
            pred_t = pred[:, 1] - pred[:, 3] / 2
            target_t = target[:, 1] - target[:, 3] / 2
            # right x = cx + w / 2
            pred_r = pred[:, 0] + pred[:, 2] / 2
            target_r = target[:, 0] + target[:, 2] / 2
            # bottom y = cy + h / 2
            pred_b = pred[:, 1] + pred[:, 3] / 2
            target_b = target[:, 1] + target[:, 3] / 2
            # Squared diagonal of the smallest box enclosing both boxes
            bound_l = torch.min(pred_l, target_l)  # left
            bound_r = torch.max(pred_r, target_r)  # right
            bound_t = torch.min(pred_t, target_t)  # top
            bound_b = torch.max(pred_b, target_b)  # bottom
            c2 = (bound_r - bound_l) ** 2 + (bound_b - bound_t) ** 2
            dloss = iou - d2 / c2.clamp(min=1e-16)
            loss = 1 - dloss.clamp(min=-1.0, max=1.0)
            # Reference pseudocode for a single pair of boxes:
            # Step1
            # def DIoU(a, b):
            #     d = a.center_distance(b)
            #     c = a.bound_diagonal_distance(b)
            #     return IoU(a, b) - (d ** 2) / (c ** 2)
            # Step2-1
            # def center_distance(self, other):
            #     '''Distance between the centers of the two boxes'''
            #     return euclidean_distance(self.center, other.center)
            # Step2-2
            # def euclidean_distance(p1, p2):
            #     '''Euclidean distance between two points'''
            #     x1, y1 = p1
            #     x2, y2 = p2
            #     return math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
            # Step3
            # def bound_diagonal_distance(self, other):
            #     '''Diagonal length of the box enclosing both boxes'''
            #     bound = self.boundof(other)
            #     return euclidean_distance((bound.x, bound.y), (bound.r, bound.b))
            # Step3-2
            # def boundof(self, other):
            #     '''Smallest rectangle that contains both self and other'''
            #     xmin = min(self.x, other.x)
            #     ymin = min(self.y, other.y)
            #     xmax = max(self.r, other.r)
            #     ymax = max(self.b, other.b)
            #     return BBox(xmin, ymin, xmax, ymax)
            # Step3-3: euclidean_distance, same as Step2-2

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()
        return loss
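The commented-out Step1 to Step3 pseudocode can be turned into a runnable per-box sketch, which is handy for cross-checking the tensor version of IOUloss on a single pair of boxes. BBox stores (x, y, r, b) = (left, top, right, bottom), as the pseudocode's boundof suggests. The `center` and `area` properties and the `iou` method are assumptions filled in here, because the pseudocode calls `self.center` and `IoU(a, b)` without defining them.

```python
import math
from dataclasses import dataclass


def euclidean_distance(p1, p2):
    """Euclidean distance between two points (Step2-2 / Step3-3)."""
    x1, y1 = p1
    x2, y2 = p2
    return math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)


@dataclass
class BBox:
    x: float  # left
    y: float  # top
    r: float  # right
    b: float  # bottom

    @property
    def center(self):  # assumed: not defined in the pseudocode
        return ((self.x + self.r) / 2, (self.y + self.b) / 2)

    @property
    def area(self):  # assumed helper
        return (self.r - self.x) * (self.b - self.y)

    def center_distance(self, other):  # Step2-1
        return euclidean_distance(self.center, other.center)

    def boundof(self, other):  # Step3-2: smallest box containing both
        return BBox(min(self.x, other.x), min(self.y, other.y),
                    max(self.r, other.r), max(self.b, other.b))

    def bound_diagonal_distance(self, other):  # Step3
        bound = self.boundof(other)
        return euclidean_distance((bound.x, bound.y), (bound.r, bound.b))

    def iou(self, other):  # assumed: stands in for IoU(a, b)
        iw = max(0.0, min(self.r, other.r) - max(self.x, other.x))
        ih = max(0.0, min(self.b, other.b) - max(self.y, other.y))
        inter = iw * ih
        return inter / (self.area + other.area - inter)


def DIoU(a, b):  # Step1
    d = a.center_distance(b)
    c = a.bound_diagonal_distance(b)
    return a.iou(b) - (d ** 2) / (c ** 2)


a = BBox(0.0, 0.0, 2.0, 2.0)
b = BBox(1.0, 0.0, 3.0, 2.0)
# IoU = 1/3, d = 1, c^2 = 13, so DIoU is about 0.256 and the loss about 0.744
print(1 - DIoU(a, b))
```

Because this version works on one box at a time with Python floats, it is only meant as a readable reference; the batched tensor code is what you would use in training.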
https://github.com/Megvii-BaseDetection/YOLOX