Loss function: implementing DIoU loss by hand
2022-06-30 13:22:00 【Computer vision Archer】
Here is the standalone DIoU code:
import torch

# Assumes pred and target are [N, 4] tensors laid out as (x_center, y_center, w, h)
# and iou is an [N] tensor holding the IoU of each pred/target pair.

# Squared distance d ** 2 between the centers of the two boxes.
# d = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2); the square is used directly because
# math.sqrt fails on multi-element tensors and torch.sqrt has an infinite gradient at 0.
d2 = (pred[:, 0] - target[:, 0]) ** 2 + (pred[:, 1] - target[:, 1]) ** 2
# Left x: center x minus half the width
pred_l = pred[:, 0] - pred[:, 2] / 2
target_l = target[:, 0] - target[:, 2] / 2
# Top y: center y minus half the height
pred_t = pred[:, 1] - pred[:, 3] / 2
target_t = target[:, 1] - target[:, 3] / 2
# Right x
pred_r = pred[:, 0] + pred[:, 2] / 2
target_r = target[:, 0] + target[:, 2] / 2
# Bottom y
pred_b = pred[:, 1] + pred[:, 3] / 2
target_b = target[:, 1] + target[:, 3] / 2

# Squared diagonal c ** 2 of the smallest box enclosing both boxes
bound_l = torch.min(pred_l, target_l)  # left
bound_r = torch.max(pred_r, target_r)  # right
bound_t = torch.min(pred_t, target_t)  # top
bound_b = torch.max(pred_b, target_b)  # bottom
c2 = (bound_r - bound_l) ** 2 + (bound_b - bound_t) ** 2
dloss = iou - d2 / c2.clamp(min=1e-16)
loss = 1 - dloss.clamp(min=-1.0, max=1.0)
Step 1: compute the distance d between the centers of the two boxes.
First, we need to know what pred and target contain.
In pred[:, :2], the first colon selects every row (one box per row) and :2 takes the first two values, the center of the rectangle (x, y).
In pred[:, 2:], 2: takes the last two values, the width and height of the rectangle (w, h). This is the (x_center, y_center, w, h) layout YOLOX uses for its box regression outputs.
target[:, :2] and target[:, 2:] have the same meaning.
The center distance is $d = \sqrt{(x_2 - x_1)^2 + (y_2 - y_1)^2}$, where (x_1, y_1) and (x_2, y_2) are the centers of the two boxes.
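A minimal pure-Python sketch of step 1, assuming each box is a plain (x_center, y_center, w, h) tuple instead of a tensor row. The name center_distance is borrowed from the author's Step 2-1 pseudocode, but this stand-alone version is only illustrative. Note that the center sits in indices 0 and 1, not in the last two columns.

```python
import math


def center_distance(box_a, box_b):
    """Euclidean distance between the centers of two (cx, cy, w, h) boxes."""
    # Indices 0 and 1 hold the center; indices 2 and 3 hold width and height
    dx = box_b[0] - box_a[0]
    dy = box_b[1] - box_a[1]
    return math.sqrt(dx ** 2 + dy ** 2)


# Centers (0, 0) and (3, 4) form a 3-4-5 right triangle
print(center_distance((0, 0, 2, 2), (3, 4, 2, 2)))  # 5.0
```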

Step 2: from each box's center and size, compute its left, top, right and bottom edges (l, t, r, b): l = x - w/2, t = y - h/2, r = x + w/2, b = y + h/2.
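A sketch of that conversion, again assuming (x_center, y_center, w, h) tuples; to_ltrb is a hypothetical helper name. Each edge is the center shifted by half of the matching side, which is what pred_l, pred_t, pred_r and pred_b compute column-wise in the tensor code.

```python
def to_ltrb(box):
    """Convert a (cx, cy, w, h) box to (left, top, right, bottom)."""
    cx, cy, w, h = box
    # x edges use the width, y edges use the height
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


print(to_ltrb((2, 3, 4, 2)))  # (0.0, 2.0, 4.0, 4.0)
```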

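Step 3: compute c, the diagonal length of the smallest rectangle enclosing both boxes. Its edges are the outermost edges of the two boxes: the smaller left and top values and the larger right and bottom values.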
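A minimal scalar sketch of step 3, assuming (x_center, y_center, w, h) tuples; to_ltrb and enclosing_diagonal are hypothetical helper names that mirror the bound_l/bound_r/bound_t/bound_b lines of the tensor code.

```python
import math


def to_ltrb(box):
    """Convert a (cx, cy, w, h) box to (left, top, right, bottom)."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def enclosing_diagonal(box_a, box_b):
    """Diagonal length c of the smallest rectangle containing both boxes."""
    al, at, ar, ab = to_ltrb(box_a)
    bl, bt, br, bb = to_ltrb(box_b)
    # The enclosing box takes the outermost edge on every side
    left, top = min(al, bl), min(at, bt)
    right, bottom = max(ar, br), max(ab, bb)
    return math.sqrt((right - left) ** 2 + (bottom - top) ** 2)


# Boxes spanning (0, 0)-(2, 2) and (3, 4)-(5, 6): the enclosing box is 5 x 6
print(enclosing_diagonal((1, 1, 2, 2), (4, 5, 2, 2)))  # about 7.81 (= sqrt(61))
```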

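d is the distance between the centers of the predicted box and the target box, and c is the diagonal of the smallest rectangle enclosing both.
Consider the two extremes:
- When the centers of the two boxes coincide, d/c = 0 and the IoU can be anywhere from 0 to 1.
- When the two boxes are far apart, d/c approaches 1 and the IoU is 0.
Since both centers lie inside the enclosing rectangle, d/c (and therefore d²/c²) always lies between 0 and 1.
dloss = IoU - d²/c² therefore lies between -1 and 1,
so loss = 1 - dloss lies between 0 and 2.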
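The two extremes can be checked numerically with a scalar version of the dloss branch. This is a sketch assuming (x_center, y_center, w, h) tuples; diou_loss and to_ltrb are hypothetical names, and like the tensor code it uses d ** 2 and c ** 2 directly and clamps dloss to [-1, 1].

```python
def to_ltrb(box):
    """Convert a (cx, cy, w, h) box to (left, top, right, bottom)."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def diou_loss(box_a, box_b, eps=1e-16):
    """1 - DIoU for two (cx, cy, w, h) boxes (hypothetical scalar helper)."""
    al, at, ar, ab = to_ltrb(box_a)
    bl, bt, br, bb = to_ltrb(box_b)
    # IoU: overlap area over union area; the overlap is 0 for disjoint boxes
    inter = max(0.0, min(ar, br) - max(al, bl)) * max(0.0, min(ab, bb) - max(at, bt))
    union = box_a[2] * box_a[3] + box_b[2] * box_b[3] - inter
    iou = inter / (union + eps)
    # Squared center distance d ** 2 and squared enclosing diagonal c ** 2
    d2 = (box_a[0] - box_b[0]) ** 2 + (box_a[1] - box_b[1]) ** 2
    c2 = (max(ar, br) - min(al, bl)) ** 2 + (max(ab, bb) - min(at, bt)) ** 2
    dloss = iou - d2 / max(c2, eps)
    return 1 - min(max(dloss, -1.0), 1.0)


print(diou_loss((0, 0, 2, 2), (0, 0, 2, 2)))     # 0.0, identical boxes
print(diou_loss((0, 0, 2, 2), (0, 0, 4, 4)))     # 0.75, same center, IoU = 0.25
print(diou_loss((0, 0, 1, 1), (1000, 0, 1, 1)))  # about 1.998, far apart
```

The last case shows the loss approaching, but not reaching, the upper bound of 2.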
Below is the full IoU/GIoU/DIoU code. The IOUloss class is YOLOX's own loss function; the DIoU branch (dloss) is my own addition.
YOLOX (a high-performance anchor-free YOLO; documentation: https://yolox.readthedocs.io/) can be downloaded from GitHub:
https://github.com/Megvii-BaseDetection/YOLOX
import torch
import torch.nn as nn


class IOUloss(nn.Module):
    def __init__(self, reduction="none", loss_type="iou"):
        super(IOUloss, self).__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    def forward(self, pred, target):
        assert pred.shape[0] == target.shape[0]

        # pred and target are both [N, 4], laid out as (x_center, y_center, w, h):
        # pred[:, :2] is the center (x, y), pred[:, 2:] is the size (w, h)
        pred = pred.view(-1, 4)
        target = target.view(-1, 4)
        # Top-left corner of the intersection: (x, y) - (w, h) / 2
        tl = torch.max(
            (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
        )
        # Bottom-right corner of the intersection: (x, y) + (w, h) / 2
        br = torch.min(
            (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
        )

        area_p = torch.prod(pred[:, 2:], 1)  # w * h
        area_g = torch.prod(target[:, 2:], 1)

        # en is 1 where the boxes overlap and 0 otherwise
        en = (tl < br).type(tl.type()).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        area_u = area_p + area_g - area_i
        iou = (area_i) / (area_u + 1e-16)

        if self.loss_type == "iou":
            loss = 1 - iou ** 2
        elif self.loss_type == "giou":
            # Corners of the smallest box enclosing both boxes
            c_tl = torch.min(
                (pred[:, :2] - pred[:, 2:] / 2), (target[:, :2] - target[:, 2:] / 2)
            )
            c_br = torch.max(
                (pred[:, :2] + pred[:, 2:] / 2), (target[:, :2] + target[:, 2:] / 2)
            )
            area_c = torch.prod(c_br - c_tl, 1)
            giou = iou - (area_c - area_u) / area_c.clamp(1e-16)
            loss = 1 - giou.clamp(min=-1.0, max=1.0)
        elif self.loss_type == "diou":
            # Squared distance d ** 2 between the centers of the two boxes.
            # d = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2); the square is used directly
            # because torch.sqrt has an infinite gradient at 0 (coincident centers).
            d2 = (pred[:, 0] - target[:, 0]) ** 2 + (pred[:, 1] - target[:, 1]) ** 2
            # Left x
            pred_l = pred[:, 0] - pred[:, 2] / 2
            target_l = target[:, 0] - target[:, 2] / 2
            # Top y
            pred_t = pred[:, 1] - pred[:, 3] / 2
            target_t = target[:, 1] - target[:, 3] / 2
            # Right x
            pred_r = pred[:, 0] + pred[:, 2] / 2
            target_r = target[:, 0] + target[:, 2] / 2
            # Bottom y
            pred_b = pred[:, 1] + pred[:, 3] / 2
            target_b = target[:, 1] + target[:, 3] / 2
            # Squared diagonal c ** 2 of the smallest box enclosing both boxes
            bound_l = torch.min(pred_l, target_l)  # left
            bound_r = torch.max(pred_r, target_r)  # right
            bound_t = torch.min(pred_t, target_t)  # top
            bound_b = torch.max(pred_b, target_b)  # bottom
            c2 = (bound_r - bound_l) ** 2 + (bound_b - bound_t) ** 2
            dloss = iou - d2 / c2.clamp(min=1e-16)
            loss = 1 - dloss.clamp(min=-1.0, max=1.0)
            # Reference pseudocode the DIoU branch was derived from:
            # Step 1
            # def DIoU(a, b):
            #     d = a.center_distance(b)
            #     c = a.bound_diagonal_distance(b)
            #     return IoU(a, b) - (d ** 2) / (c ** 2)
            # Step 2-1
            # def center_distance(self, other):
            #     '''Distance between the centers of two boxes'''
            #     return euclidean_distance(self.center, other.center)
            # Step 2-2
            # def euclidean_distance(p1, p2):
            #     '''Euclidean distance between two points'''
            #     x1, y1 = p1
            #     x2, y2 = p2
            #     return math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
            # Step 3
            # def bound_diagonal_distance(self, other):
            #     '''Diagonal length of the box enclosing both boxes'''
            #     bound = self.boundof(other)
            #     return euclidean_distance((bound.x, bound.y), (bound.r, bound.b))
            # Step 3-2
            # def boundof(self, other):
            #     '''Smallest box that contains both self and other'''
            #     xmin = min(self.x, other.x)
            #     ymin = min(self.y, other.y)
            #     xmax = max(self.r, other.r)
            #     ymax = max(self.b, other.b)
            #     return BBox(xmin, ymin, xmax, ymax)
            # Step 3-3: reuses euclidean_distance from Step 2-2

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss
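To see how the three branches of IOUloss differ, here is a scalar pure-Python mirror of them for a single box pair. It is a sketch, not YOLOX code: tensors are replaced by (x_center, y_center, w, h) tuples and branch_losses, to_ltrb and _clamp are hypothetical names. For two disjoint boxes the "iou" loss stays at 1 no matter how far apart they are, while the "diou" loss keeps growing with the center distance.

```python
def to_ltrb(box):
    """Convert a (cx, cy, w, h) box to (left, top, right, bottom)."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)


def _clamp(v):
    """Clamp to [-1, 1], like .clamp(min=-1.0, max=1.0) in the class."""
    return min(max(v, -1.0), 1.0)


def branch_losses(pred, target, eps=1e-16):
    """Scalar versions of the iou / giou / diou branches (hypothetical helper)."""
    pl, pt, pr, pb = to_ltrb(pred)
    tl, tt, tr, tb = to_ltrb(target)
    inter = max(0.0, min(pr, tr) - max(pl, tl)) * max(0.0, min(pb, tb) - max(pt, tt))
    union = pred[2] * pred[3] + target[2] * target[3] - inter
    iou = inter / (union + eps)
    # Width and height of the smallest enclosing box
    cw = max(pr, tr) - min(pl, tl)
    ch = max(pb, tb) - min(pt, tt)
    giou = iou - (cw * ch - union) / max(cw * ch, eps)
    d2 = (pred[0] - target[0]) ** 2 + (pred[1] - target[1]) ** 2
    diou = iou - d2 / max(cw ** 2 + ch ** 2, eps)
    return {"iou": 1 - iou ** 2, "giou": 1 - _clamp(giou), "diou": 1 - _clamp(diou)}


target = (0, 0, 2, 2)
near = branch_losses((3, 0, 2, 2), target)
far = branch_losses((10, 0, 2, 2), target)
print(near["iou"], far["iou"])    # 1.0 1.0
print(near["diou"], far["diou"])  # the far box gets the larger loss
```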