当前位置:网站首页>【yolov3损失函数】
【yolov3损失函数】
2022-07-05 11:34:00 【网络星空(luoc)】
解析的代码地址:
github:tensorflow-yolov3
本文解析计算各部分损失compute_loss()部分:
yolov3的损失函数稍微有点复杂,花了点时间看了好多篇有关yolov3损失函数的文章,总算有了一点眉目,下篇文章再写一下yolov3损失函数的理论部分。这篇源码解析主要体现在结构图和源码的注释中,相关部分没有写专门的文字说明。
首先在train.py的第55行调用compute_loss()计算模型训练的损失,然后在yolov3.py中进入compute_loss()函数。
compute_loss()函数
其实这部分代码并没有进行核心的计算,主要是传入相关参数,然后调用loss_layer()函数进行相关计算。
结构图:
源码:
# Compute the training loss over all three detection scales.
def compute_loss(self, label_sbbox, label_mbbox, label_lbbox, true_sbbox, true_mbbox, true_lbbox):
    """Sum the per-scale YOLOv3 losses into three loss terms.

    ``self.loss_layer`` is evaluated once per detection head (small, medium,
    large), each inside its own ``tf.name_scope``. The matching components of
    the three results are then added together.

    Args:
        label_sbbox, label_mbbox, label_lbbox: per-cell training targets for
            the small / medium / large heads (box, confidence and class
            channels). The original notes give e.g. [52, 52, 3, 85] for the
            small head; confirm against the data pipeline.
        true_sbbox, true_mbbox, true_lbbox: ground-truth box coordinates
            only, one tensor per head. The original notes describe them as
            padded to 150 boxes of 4 values; confirm against the data
            pipeline.

    Returns:
        (giou_loss, conf_loss, prob_loss), each being the sum of that
        component over the three heads.
    """
    # (name scope, raw head output, decoded head output, targets, GT boxes);
    # the position in this tuple also selects self.anchors / self.strides.
    heads = (
        ('smaller_box_loss', self.conv_sbbox, self.pred_sbbox, label_sbbox, true_sbbox),
        ('medium_box_loss', self.conv_mbbox, self.pred_mbbox, label_mbbox, true_mbbox),
        ('bigger_box_loss', self.conv_lbbox, self.pred_lbbox, label_lbbox, true_lbbox),
    )
    per_head = []
    for idx, (scope, conv, pred, label, gt_boxes) in enumerate(heads):
        with tf.name_scope(scope):
            per_head.append(self.loss_layer(conv, pred, label, gt_boxes,
                                            anchors=self.anchors[idx],
                                            stride=self.strides[idx]))
    small, medium, large = per_head

    # Combine component-wise; each summed component gets its own name scope.
    with tf.name_scope('giou_loss'):
        giou_loss = small[0] + medium[0] + large[0]
    with tf.name_scope('conf_loss'):
        conf_loss = small[1] + medium[1] + large[1]
    with tf.name_scope('prob_loss'):
        prob_loss = small[2] + medium[2] + large[2]
    return giou_loss, conf_loss, prob_loss
loss_layer()函数
结构图:
源码:
# Loss layer: YOLOv3 loss terms for a single detection scale.
''' conv, # 卷积特征图里的原始值 pred, # 由conv经过decode()函数解码获得 label, # 标记框GT的坐标值,置信度,概率等多个值 [52,52,3,85] bboxes, # 只含标记框GT的坐标值 (3,150,4) anchors # 基准anchor的宽高 stride # 特征图相对于原图的缩放率 '''
def loss_layer(self, conv, pred, label, bboxes, anchors, stride):
    """Build the GIoU, confidence and classification losses for one scale.

    Args:
        conv: raw (pre-sigmoid) output of this detection head. It is reshaped
            below to (batch, output_size, output_size,
            self.anchor_per_scale, 5 + self.num_class).
        pred: decoded predictions for the same head (the original notes say
            they are produced from ``conv`` by ``decode()``). Channels 0:4
            are read as box (x, y, w, h) and channel 4 as confidence. Assumed
            to share the rank-5 layout of the reshaped ``conv``; confirm.
        label: per-cell targets in the same layout. Channels 0:4 hold the box
            (x, y, w, h), channel 4 the objectness target and channels 5:
            the class targets.
        bboxes: rank-3 tensor of ground-truth boxes (x, y, w, h), presumably
            (batch, max_boxes, 4). Every prediction is compared against all
            of them to find its best IoU.
        anchors: anchor sizes for this scale. NOTE(review): not read by this
            function.
        stride: downsampling factor of this head relative to the network
            input.

    Returns:
        (giou_loss, conf_loss, prob_loss): three scalars. Each is summed over
        the grid cells, anchors and channels, then averaged over the batch.
    """
    conv_shape = tf.shape(conv) # dynamic shape of the raw head output
    batch_size = conv_shape[0] # number of images in the batch
    output_size = conv_shape[1] # grid size of this head (e.g. 13, 26 or 52)
    input_size = stride * output_size # network input size, e.g. 13*32, 26*16, 52*8
    # conv is reshaped here to 5 dimensions:
    # (batch, grid, grid, anchors per scale, 5 + num_class)
    conv = tf.reshape(conv, (batch_size, output_size, output_size,
                             self.anchor_per_scale, 5 + self.num_class))
    # These are logits, i.e. values before any sigmoid is applied.
    conv_raw_conf = conv[:, :, :, :, 4:5] # raw confidence logit
    conv_raw_prob = conv[:, :, :, :, 5:] # raw per-class logits
    pred_xywh = pred[:, :, :, :, 0:4] # predicted box (x, y, w, h)
    pred_conf = pred[:, :, :, :, 4:5] # predicted confidence
    # Ground-truth box (x, y, w, h) targets.
    label_xywh = label[:, :, :, :, 0:4]
    # Objectness target. It is used below both as the positive-sample mask and
    # as the sigmoid cross-entropy label. It is therefore presumably 1 where a
    # ground-truth box is assigned to this cell/anchor and 0 elsewhere
    # (confirm against the label generator).
    respond_bbox = label[:, :, :, :, 4:5]
    # Ground-truth class targets.
    label_prob = label[:, :, :, :, 5:]

    # --- GIoU loss ---
    # GIoU between each predicted box and its label box. bbox_giou drops the
    # last axis, so expand_dims restores a trailing axis of size 1 so the
    # result broadcasts with the (..., 1) masks.
    giou = tf.expand_dims(self.bbox_giou(pred_xywh, label_xywh), axis=-1)
    input_size = tf.cast(input_size, tf.float32) # float for the division below
    # Weight 2 - w*h / input_size**2 up-weights small boxes: the weight is 2 for
    # a zero-area box and 1 for a box as large as the whole input.
    bbox_loss_scale = 2.0 - 1.0 * label_xywh[:, :, :, :, 2:3] * label_xywh[:, :, :, :, 3:4] / (input_size ** 2)
    # Only cells/anchors with an assigned object (respond_bbox) contribute.
    giou_loss = respond_bbox * bbox_loss_scale * (1- giou)

    # --- Confidence loss ---
    # IoU of every prediction against every ground-truth box of its image.
    # pred_xywh becomes (..., 1, 4) and bboxes becomes
    # (batch, 1, 1, 1, num_boxes, 4); bbox_iou broadcasts them and drops the
    # last axis, leaving one IoU per (prediction, ground-truth box) pair.
    iou = self.bbox_iou(pred_xywh[:, :, :, :, np.newaxis, :], bboxes[:, np.newaxis, np.newaxis, np.newaxis, :, :])
    # Best IoU each prediction achieves with any ground-truth box.
    max_iou = tf.expand_dims(tf.reduce_max(iou, axis=-1), axis=-1)
    # Background mask: the prediction has no assigned object AND its best IoU
    # is below self.iou_loss_thresh. Assuming respond_bbox is 0/1, a prediction
    # that has no assigned object but overlaps a ground-truth box at or above
    # the threshold falls in neither mask, so it adds no confidence loss.
    respond_bgd = (1.0 - respond_bbox) * tf.cast( max_iou < self.iou_loss_thresh, tf.float32 )
    # Focal modulating factor from self.focal (default alpha/gamma). It
    # shrinks the loss of predictions whose confidence is already close to
    # the target.
    conf_focal = self.focal(respond_bbox, pred_conf)
    # Sigmoid cross-entropy on the confidence logit, applied to positives
    # (respond_bbox) and to background predictions (respond_bgd). Both terms
    # use respond_bbox as the label, which is 0 wherever respond_bgd is
    # non-zero (for a 0/1 mask). The intent: output confidence 1 for cells
    # containing an object and 0 otherwise.
    conf_loss = conf_focal * (
            respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels=respond_bbox, logits=conv_raw_conf)
            +
            respond_bgd * tf.nn.sigmoid_cross_entropy_with_logits(labels=respond_bbox, logits=conv_raw_conf)
    )

    # --- Classification loss ---
    # Per-class sigmoid cross-entropy, only for cells/anchors with an
    # assigned object.
    prob_loss = respond_bbox * tf.nn.sigmoid_cross_entropy_with_logits(labels=label_prob, logits=conv_raw_prob)

    # Sum over grid rows, grid columns, anchors and channels (axes 1-4), then
    # average over the batch.
    giou_loss = tf.reduce_mean(tf.reduce_sum(giou_loss, axis=[1,2,3,4]))
    conf_loss = tf.reduce_mean(tf.reduce_sum(conf_loss, axis=[1,2,3,4]))
    prob_loss = tf.reduce_mean(tf.reduce_sum(prob_loss, axis=[1,2,3,4]))
    return giou_loss, conf_loss, prob_loss
其中调用的3个小函数:
Focal loss让one-stage detector也变得强大起来,它解决了class imbalance(类别不平衡)的问题:同时解决了正负样本不平衡以及区分简单样本与困难样本的问题。注意:下面的focal()只计算调制因子 alpha·|target − actual|^gamma,在loss_layer()中它与置信度的交叉熵损失相乘。
def focal(self, target, actual, alpha=1, gamma=2):
    """Return the focal modulating weight ``alpha * |target - actual| ** gamma``.

    The weight is computed element-wise. The further ``actual`` is from
    ``target``, the larger the weight, so well-predicted samples contribute
    less to any loss this weight multiplies.
    """
    abs_error = tf.abs(target - actual)
    weight = tf.pow(abs_error, gamma)
    return alpha * weight
计算giou
# Generalized IoU (GIoU) of two sets of boxes
def bbox_giou(self, boxes1, boxes2, eps=1e-10):
    """Return the element-wise generalized IoU of two sets of boxes.

    Boxes are given in centre format ``(x, y, w, h)`` along the last axis.
    ``boxes1`` and ``boxes2`` must broadcast against each other on the
    leading axes, and the last axis is dropped from the result.

    GIoU = IoU - (area(C) - union) / area(C), where C is the smallest
    axis-aligned box enclosing both boxes. Its value lies in (-1, 1].

    Args:
        boxes1, boxes2: tensors whose last dimension is 4 (x, y, w, h).
        eps: lower bound applied to the union and enclosing areas before
            dividing. A pair of zero-area boxes then yields a finite GIoU (0)
            instead of NaN. Pairs whose areas are at least ``eps`` are
            unaffected.

    Returns:
        Tensor of GIoU values with the last axis removed.
    """
    # (x, y, w, h) -> (x_min, y_min, x_max, y_max)
    boxes1 = tf.concat([boxes1[..., :2] - boxes1[..., 2:] * 0.5,
                        boxes1[..., :2] + boxes1[..., 2:] * 0.5], axis=-1)
    boxes2 = tf.concat([boxes2[..., :2] - boxes2[..., 2:] * 0.5,
                        boxes2[..., :2] + boxes2[..., 2:] * 0.5], axis=-1)
    # Reorder corners so min <= max even if a width/height is negative.
    boxes1 = tf.concat([tf.minimum(boxes1[..., :2], boxes1[..., 2:]),
                        tf.maximum(boxes1[..., :2], boxes1[..., 2:])], axis=-1)
    boxes2 = tf.concat([tf.minimum(boxes2[..., :2], boxes2[..., 2:]),
                        tf.maximum(boxes2[..., :2], boxes2[..., 2:])], axis=-1)

    boxes1_area = (boxes1[..., 2] - boxes1[..., 0]) * (boxes1[..., 3] - boxes1[..., 1])
    boxes2_area = (boxes2[..., 2] - boxes2[..., 0]) * (boxes2[..., 3] - boxes2[..., 1])

    # Intersection rectangle (top-left / bottom-right), clamped at 0 when
    # the boxes do not overlap.
    left_up = tf.maximum(boxes1[..., :2], boxes2[..., :2])
    right_down = tf.minimum(boxes1[..., 2:], boxes2[..., 2:])
    inter_section = tf.maximum(right_down - left_up, 0.0)
    inter_area = inter_section[..., 0] * inter_section[..., 1]
    union_area = boxes1_area + boxes2_area - inter_area

    # Plain IoU. Guard against 0 / 0 for two zero-area boxes: a NaN here
    # would survive any later masking by multiplication (0 * NaN = NaN) and
    # poison the whole loss.
    iou = inter_area / tf.maximum(union_area, eps)

    # Smallest enclosing box C: its top-left and bottom-right corners.
    enclose_left_up = tf.minimum(boxes1[..., :2], boxes2[..., :2])
    enclose_right_down = tf.maximum(boxes1[..., 2:], boxes2[..., 2:])
    enclose = tf.maximum(enclose_right_down - enclose_left_up, 0.0)
    enclose_area = enclose[..., 0] * enclose[..., 1]

    # GIoU formula, with the same 0 / 0 guard on the enclosing area.
    giou = iou - 1.0 * (enclose_area - union_area) / tf.maximum(enclose_area, eps)
    return giou
两个框的交并比
# IoU (intersection over union) of two sets of boxes
def bbox_iou(self, boxes1, boxes2, eps=1e-10):
    """Return the element-wise IoU of two sets of boxes.

    Boxes are given in centre format ``(x, y, w, h)`` along the last axis.
    ``boxes1`` and ``boxes2`` must be broadcast-compatible on the leading
    axes; for example, expand one side with ``np.newaxis`` to compare every
    pair. The last axis is dropped from the result.

    Args:
        boxes1, boxes2: tensors whose last dimension is 4 (x, y, w, h).
        eps: lower bound applied to the union area before dividing. A pair
            of zero-area boxes (e.g. zero-padded entries) then yields IoU 0
            instead of NaN. Pairs whose union area is at least ``eps`` are
            unaffected.

    Returns:
        Tensor of IoU values with the last axis removed. Values are in
        [0, 1] for boxes with non-negative width and height.
    """
    # Areas from width * height, computed before converting to corners.
    boxes1_area = boxes1[..., 2] * boxes1[..., 3]
    boxes2_area = boxes2[..., 2] * boxes2[..., 3]

    # (x, y, w, h) -> (x_min, y_min, x_max, y_max)
    boxes1 = tf.concat([boxes1[..., :2] - boxes1[..., 2:] * 0.5,
                        boxes1[..., :2] + boxes1[..., 2:] * 0.5], axis=-1)
    boxes2 = tf.concat([boxes2[..., :2] - boxes2[..., 2:] * 0.5,
                        boxes2[..., :2] + boxes2[..., 2:] * 0.5], axis=-1)

    # Intersection rectangle, clamped at 0 when the boxes do not overlap.
    left_up = tf.maximum(boxes1[..., :2], boxes2[..., :2])
    right_down = tf.minimum(boxes1[..., 2:], boxes2[..., 2:])
    inter_section = tf.maximum(right_down - left_up, 0.0)
    inter_area = inter_section[..., 0] * inter_section[..., 1]
    union_area = boxes1_area + boxes2_area - inter_area

    # Guard against 0 / 0 for two zero-area boxes: a NaN here would survive
    # any later masking by multiplication (0 * NaN = NaN) and poison the loss.
    iou = 1.0 * inter_area / tf.maximum(union_area, eps)
    return iou
边栏推荐
- 以交互方式安装ESXi 6.0
- COMSOL -- establishment of geometric model -- establishment of two-dimensional graphics
- Differences between IPv6 and IPv4 three departments including the office of network information technology promote IPv6 scale deployment
- 【Office】Excel中IF函数的8种用法
- 全网最全的新型数据库、多维表格平台盘点 Notion、FlowUs、Airtable、SeaTable、维格表 Vika、飞书多维表格、黑帕云、织信 Informat、语雀
- IPv6与IPv4的区别 网信办等三部推进IPv6规模部署
- 【爬虫】charles unknown错误
- 管理多个Instagram帐户防关联小技巧大分享
- SET XACT_ ABORT ON
- The art of communication III: Listening between people
猜你喜欢
随机推荐
How to protect user privacy without password authentication?
How to understand super browser? What scenarios can it be used in? What brands are there?
7.2 daily study 4
Unity xlua monoproxy mono proxy class
Summary of thread and thread synchronization under window
解决访问国外公共静态资源速度慢的问题
【爬虫】wasm遇到的bug
C # implements WinForm DataGridView control to support overlay data binding
13. (map data) conversion between Baidu coordinate (bd09), national survey of China coordinate (Mars coordinate, gcj02), and WGS84 coordinate system
Error assembling WAR: webxml attribute is required (or pre-existing WEB-INF/web.xml if executing in
Prevent browser backward operation
Shell script file traversal STR to array string splicing
Question and answer 45: application of performance probe monitoring principle node JS probe
DDoS attack principle, the phenomenon of being attacked by DDoS
11.(地图数据篇)OSM数据如何下载使用
redis集群中hash tag 使用
OneForAll安装使用
spark调优(一):从hql转向代码
-26374 and -26377 errors during coneroller execution
iTOP-3568开发板NPU使用安装RKNN Toolkit Lite2







![[crawler] bugs encountered by wasm](/img/29/6782bda4c149b7b2b334238936e211.png)
