FCOS3D label assignment
2022-07-07 11:56:00 【烤粽子】
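FCOS3D label assignment is not much different from 2D FCOS. Targets are assigned to feature-map points mainly in the image coordinate frame, as this per-image target function shows: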
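"Image coordinate frame" here means that every location of every FPN level is represented by its (x, y) position in input-image pixels, and all assignment tests are done in those pixels. The following is a minimal sketch of how such points could be laid out. It assumes a padded 928x1600 input (for example, 900x1600 camera images padded to a multiple of 32), feature maps of size ceil(H/stride) x ceil(W/stride), and a `stride // 2` offset as in mmdetection's FCOS heads. These are assumptions, and `fpn_points` is a hypothetical helper, not part of the original code.

```python
import math


def fpn_points(img_h, img_w, strides=(8, 16, 32, 64, 128)):
    """Hypothetical helper: per-level point counts and (x, y) image coordinates.

    Assumes each level's feature map is ceil(H/stride) x ceil(W/stride) and that
    feature cell (i, j) maps to pixel (j*stride + stride//2, i*stride + stride//2).
    """
    counts, points = [], []
    for s in strides:
        h, w = math.ceil(img_h / s), math.ceil(img_w / s)
        counts.append(h * w)
        # row-major order, level by level, so points of one level are contiguous
        for i in range(h):
            for j in range(w):
                points.append((j * s + s // 2, i * s + s // 2))
    return counts, points
```

Under these assumptions, `fpn_points(928, 1600)` gives 23200, 5800, 1450, 375 and 104 points per level (30929 in total). These match `num_points_per_lvl` and the `30929` shapes in the code comments. Keeping each level contiguous lets the target function fill `stride[lvl_begin:lvl_end]` with simple slices.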
```python
import torch

INF = 1e8  # large sentinel distance; assumed to match the INF defined in the head's module


def _get_target_single(self, gt_bboxes, gt_labels, gt_bboxes_3d,
                       gt_labels_3d, centers2d, depths, attr_labels,
                       points, regress_ranges, num_points_per_lvl):
    """Compute regression and classification targets for a single image."""
    num_points = points.size(0)
    num_gts = gt_labels.size(0)
    if not isinstance(gt_bboxes_3d, torch.Tensor):
        gt_bboxes_3d = gt_bboxes_3d.tensor.to(gt_bboxes.device)
    if num_gts == 0:
        return gt_labels.new_full((num_points,), self.background_label), \
               gt_bboxes.new_zeros((num_points, 4)), \
               gt_labels_3d.new_full(
                   (num_points,), self.background_label), \
               gt_bboxes_3d.new_zeros((num_points, self.bbox_code_size)), \
               gt_bboxes_3d.new_zeros((num_points,)), \
               attr_labels.new_full(
                   (num_points,), self.attr_background_label)

    # change orientation to local yaw
    gt_bboxes_3d[..., 6] = -torch.atan2(
        gt_bboxes_3d[..., 0], gt_bboxes_3d[..., 2]) + gt_bboxes_3d[..., 6]

    # Shape comments below use an example image with 30929 points and 2 GTs.
    areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
        gt_bboxes[:, 3] - gt_bboxes[:, 1])  # [tl_x, tl_y, br_x, br_y] --> box areas
    areas = areas[None].repeat(num_points, 1)  # [2] --> [30929, 2]; unused below
    regress_ranges = regress_ranges[:, None, :].expand(
        num_points, num_gts, 2)  # [30929, 2] --> [30929, 2, 2]
    gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
    centers2d = centers2d[None].expand(num_points, num_gts, 2)
    gt_bboxes_3d = gt_bboxes_3d[None].expand(num_points, num_gts,
                                             self.bbox_code_size)
    depths = depths[None, :, None].expand(num_points, num_gts, 1)
    # (x, y) image coordinates of every point
    xs, ys = points[:, 0], points[:, 1]
    xs = xs[:, None].expand(num_points, num_gts)
    ys = ys[:, None].expand(num_points, num_gts)

    # gt center --> offsets
    # centers2d: image coordinates of each GT's (projected 3D) center
    delta_xs = (xs - centers2d[..., 0])[..., None]
    delta_ys = (ys - centers2d[..., 1])[..., None]
    # 0. The expansions above mainly serve this step: build 3D targets laid out
    #    like the network's 3D regression output (2D offset, depth, then the
    #    remaining box parameters from gt_bboxes_3d[..., 3:])
    bbox_targets_3d = torch.cat(
        (delta_xs, delta_ys, depths, gt_bboxes_3d[..., 3:]), dim=-1)

    left = xs - gt_bboxes[..., 0]
    right = gt_bboxes[..., 2] - xs
    top = ys - gt_bboxes[..., 1]
    bottom = gt_bboxes[..., 3] - ys
    bbox_targets = torch.stack((left, top, right, bottom), -1)

    assert self.center_sampling is True, 'Setting center_sampling to '\
        'False has not been implemented for FCOS3D.'
    # condition1: inside a `center bbox`
    radius = self.center_sample_radius  # 1.5
    center_xs = centers2d[..., 0]
    center_ys = centers2d[..., 1]
    center_gts = torch.zeros_like(gt_bboxes)
    stride = center_xs.new_zeros(center_xs.shape)

    # project the points on current lvl back to the `original` sizes
    # 1. For each point, the sampling half-width in input-image pixels is
    #    (stride of its FPN level) * radius
    lvl_begin = 0
    for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):  # [23200, 5800, 1450, 375, 104]
        lvl_end = lvl_begin + num_points_lvl
        stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius  # [8, 16, 32, 64, 128] * 1.5
        lvl_begin = lvl_end

    # 2. Points inside the center box become positive candidates. The center
    #    box is a square of half-width radius * stride (side 2 * 1.5 * stride)
    #    around the projected 3D center.
    center_gts[..., 0] = center_xs - stride
    center_gts[..., 1] = center_ys - stride
    center_gts[..., 2] = center_xs + stride
    center_gts[..., 3] = center_ys + stride

    # distances from each point to the four edges of the center box
    cb_dist_left = xs - center_gts[..., 0]
    cb_dist_right = center_gts[..., 2] - xs
    cb_dist_top = ys - center_gts[..., 1]
    cb_dist_bottom = center_gts[..., 3] - ys
    center_bbox = torch.stack(
        (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1)
    # valid only if the point lies strictly inside that square
    inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0

    # condition2: limit the regression range for each location
    # 3. A point can be positive for a GT only if its largest distance to the
    #    2D box edges falls within its level's regression range
    max_regress_distance = bbox_targets.max(-1)[0]
    inside_regress_range = (
        (max_regress_distance >= regress_ranges[..., 0])
        & (max_regress_distance <= regress_ranges[..., 1]))

    # center-based criterion to deal with ambiguity
    # 4.1 For each point, pick the GT whose projected center is closest
    dists = torch.sqrt(torch.sum(bbox_targets_3d[..., :2]**2, dim=-1))  # offset norm, [30929, 2]
    dists[inside_gt_bbox_mask == 0] = INF  # drop pairs failing condition 1
    dists[inside_regress_range == 0] = INF  # drop pairs failing condition 2
    min_dist, min_dist_inds = dists.min(dim=1)

    labels = gt_labels[min_dist_inds]  # labels of the selected GT
    labels_3d = gt_labels_3d[min_dist_inds]
    attr_labels = attr_labels[min_dist_inds]
    labels[min_dist == INF] = self.background_label  # set as BG (10 here)
    labels_3d[min_dist == INF] = self.background_label  # set as BG
    attr_labels[min_dist == INF] = self.attr_background_label

    # 4.2 Gather each point's box targets for its selected GT
    bbox_targets = bbox_targets[range(num_points), min_dist_inds]  # [30929, 2, 4] --> [30929, 4]
    bbox_targets_3d = bbox_targets_3d[range(num_points), min_dist_inds]

    # 4.3 Centerness targets: offset length divided by 1.414 * radius * stride
    #     (the half-diagonal of the center box), so points inside the box get
    #     a relative distance of roughly [0, 1]
    relative_dists = torch.sqrt(
        torch.sum(bbox_targets_3d[..., :2]**2,
                  dim=-1)) / (1.414 * stride[:, 0])  # [N] / [N]
    centerness_targets = torch.exp(-self.centerness_alpha * relative_dists)  # exp(-2.5 * relative_dists)
    return labels, bbox_targets, labels_3d, bbox_targets_3d, \
        centerness_targets, attr_labels
```
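The vectorized code is easier to follow on a toy case. Below is a minimal pure-Python sketch of steps 1 to 4.3 for a single point, plus the local-yaw conversion at the top of the function. It is not the original implementation. The helper names `to_local_yaw` and `assign_point`, the dict-based GT format, and returning centerness 0.0 for background points are assumptions made for illustration. The defaults `radius=1.5` and `alpha=2.5` follow the values noted in the code comments.

```python
import math

INF = 1e8  # sentinel for "not a candidate", mirroring INF in the listing


def to_local_yaw(x, z, yaw):
    """Hypothetical helper: global yaw -> local (observation) yaw.

    Mirrors `-atan2(x, z) + yaw` above; x, z are the object center in camera
    coordinates (x right, z forward).
    """
    return yaw - math.atan2(x, z)


def assign_point(px, py, stride, regress_range, gts, radius=1.5, alpha=2.5):
    """Hypothetical helper: assign one feature-map point to at most one GT.

    px, py        -- point location in input-image pixels
    stride        -- FPN stride of the point's level
    regress_range -- (lo, hi) allowed max edge distance for that level
    gts           -- list of dicts: 'box' = (x1, y1, x2, y2) 2D box,
                     'center' = (cx, cy) projected 3D center
    Returns (gt_index, centerness), or (None, 0.0) for background.
    """
    r = stride * radius  # half-width of the center-sampling square, in pixels
    best_idx, best_dist = None, INF
    for i, gt in enumerate(gts):
        cx, cy = gt['center']
        x1, y1, x2, y2 = gt['box']
        # condition 1: strictly inside the square of half-width r around the center
        inside_center = min(px - (cx - r), (cx + r) - px,
                            py - (cy - r), (cy + r) - py) > 0
        # condition 2: largest distance to the 2D box edges fits this level's range
        max_dist = max(px - x1, x2 - px, py - y1, y2 - py)
        in_range = regress_range[0] <= max_dist <= regress_range[1]
        if not (inside_center and in_range):
            continue
        # step 4.1: among the candidates, keep the GT whose center is closest
        dist = math.hypot(px - cx, py - cy)
        if dist < best_dist:
            best_idx, best_dist = i, dist
    if best_idx is None:
        return None, 0.0
    # step 4.3: distance relative to the square's half-diagonal, then exp decay
    relative_dist = best_dist / (1.414 * r)
    return best_idx, math.exp(-alpha * relative_dist)
```

For example, take two GTs: one with box `(0, 0, 100, 100)` centered at `(50, 50)`, and one with box `(20, 0, 90, 100)` centered at `(56, 50)`. A stride-8 point at `(55, 50)` with range `(0, 64)` is a candidate for both. It is assigned to the second GT, because its center offset is 1 px versus 5 px. Like FCOS3D, this sketch breaks ties by offset to the projected center rather than by 2D box area as in 2D FCOS.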