Mmdet line by line deltaxywhbboxcoder
2022-06-30 08:42:00 【Wu lele~】
Preface
This is the third article in my line-by-line reading of MMDetection. The code is at mmdet/core/bbox/coder/delta_xywh_bbox_coder.py. The previous articles are:
AnchorGenerator Reading
MaxIOUAssigner Reading
1、BaseBBoxCoder parent class
This class is the parent of every bbox coder class. The code is easy to follow: any subclass that inherits from it must implement two methods, encode and decode.
```python
from abc import ABCMeta, abstractmethod


class BaseBBoxCoder(metaclass=ABCMeta):
    """Base bounding box coder"""

    def __init__(self, **kwargs):
        pass

    @abstractmethod
    def encode(self, bboxes, gt_bboxes):
        """Encode deltas between bboxes and ground truth boxes"""
        pass

    @abstractmethod
    def decode(self, bboxes, bboxes_pred):
        """Decode the predicted bboxes according to prediction and base boxes"""
        pass
```
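To see what the abstract interface enforces, here is a minimal sketch that mirrors the base class above without importing mmdet. IdentityBBoxCoder and IncompleteCoder are hypothetical toy subclasses (not part of mmdet): the first implements both methods trivially, the second forgets decode. Python's abc machinery refuses to instantiate both the abstract base and the incomplete subclass.

```python
from abc import ABCMeta, abstractmethod


class BaseBBoxCoder(metaclass=ABCMeta):
    """Local mirror of the base class above, so no mmdet import is needed."""

    def __init__(self, **kwargs):
        pass

    @abstractmethod
    def encode(self, bboxes, gt_bboxes):
        """Encode deltas between bboxes and ground truth boxes."""

    @abstractmethod
    def decode(self, bboxes, bboxes_pred):
        """Decode the predicted bboxes according to prediction and base boxes."""


class IdentityBBoxCoder(BaseBBoxCoder):
    """Hypothetical toy coder: the 'delta' is simply the target box itself."""

    def encode(self, bboxes, gt_bboxes):
        # Ignore the reference boxes and return copies of the targets.
        return [list(gt) for gt in gt_bboxes]

    def decode(self, bboxes, bboxes_pred):
        # The prediction already holds absolute coordinates.
        return [list(pred) for pred in bboxes_pred]


class IncompleteCoder(BaseBBoxCoder):
    """Hypothetical subclass that forgets to implement decode()."""

    def encode(self, bboxes, gt_bboxes):
        return gt_bboxes


coder = IdentityBBoxCoder()
print(coder.encode([[1, 1, 3, 3]], [[2, 2, 3, 3]]))  # [[2, 2, 3, 3]]

# Collect the classes that cannot be instantiated because of missing methods.
not_instantiable = []
for cls in (BaseBBoxCoder, IncompleteCoder):
    try:
        cls()
    except TypeError:
        not_instantiable.append(cls.__name__)
print(not_instantiable)  # ['BaseBBoxCoder', 'IncompleteCoder']
```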
2、DeltaXYWHBBoxCoder class
2.1、Theoretical basis
Most anchor-based object detectors use this class. To help the network converge, these detectors do not regress box coordinates directly; they regress the offset between an anchor and its gt bbox. Because the network predicts offsets, training needs the ground-truth offset t* between each gt bbox and its anchor, computed as:

$$t_x^* = \frac{x - x_a}{w_a},\quad t_y^* = \frac{y - y_a}{h_a},\quad t_w^* = \log\frac{w}{w_a},\quad t_h^* = \log\frac{h}{h_a}$$

where [x, y, w, h] are the center coordinates, width and height of the gt bbox, and [xa, ya, wa, ha] are those of the anchor. Simply put, tx* and ty* are the center differences normalized by the anchor's width and height, while tw* and th* are the logarithms of the width and height ratios.
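As a quick check of these formulas, the sketch below computes t* by hand in plain Python for the two proposal/gt pairs used in section 2.2, without normalization (means 0, stds 1). The helpers to_cxcywh and encode_one are illustrative names, not mmdet functions.

```python
import math


def to_cxcywh(box):
    """Convert [xmin, ymin, xmax, ymax] to (cx, cy, w, h)."""
    x1, y1, x2, y2 = box
    return (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1


def encode_one(anchor, gt):
    """Compute the regression target t* = (tx, ty, tw, th) for one box pair."""
    xa, ya, wa, ha = to_cxcywh(anchor)
    x, y, w, h = to_cxcywh(gt)
    return ((x - xa) / wa,      # center offset, normalized by anchor width
            (y - ya) / ha,      # center offset, normalized by anchor height
            math.log(w / wa),   # log of the width ratio
            math.log(h / ha))   # log of the height ratio


proposals = [[1, 1, 3, 3], [4, 4, 6, 6]]
gts = [[2, 2, 3, 3], [2, 2, 5, 5]]
targets = [encode_one(p, g) for p, g in zip(proposals, gts)]
for t in targets:
    print([round(v, 4) for v in t])
# [0.25, 0.25, -0.6931, -0.6931]
# [-0.75, -0.75, 0.4055, 0.4055]
```

With means of 0 and stds of 1, these should match what coder.encode(proposals, gt) returns in section 2.2, up to float32 precision.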
2.2、Initialization
First, construct a coder object:
```python
import torch
from mmdet.core.bbox import build_bbox_coder

if __name__ == '__main__':
    bbox_coder = dict(
        type='DeltaXYWHBBoxCoder',
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0])
    coder = build_bbox_coder(bbox_coder)
    # Construct two proposals and their gt boxes, in [xmin, ymin, xmax, ymax] format
    proposals = torch.tensor([[1, 1, 3, 3], [4, 4, 6, 6]])
    gt = torch.tensor([[2, 2, 3, 3], [2, 2, 5, 5]])
    target_t = coder.encode(proposals, gt)  # call the encoding method
    print(target_t)
```
Here target_means and target_stds normalize the t* above: the mean is subtracted from each component and the result is divided by the standard deviation.
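The normalization is a per-component affine map, (t* - mean) / std, and decoding undoes it with t = t' * std + mean. The sketch below applies it to the raw t* of the first pair from section 2.2 (proposal [1, 1, 3, 3], gt [2, 2, 3, 3], so t* = (0.25, 0.25, log 0.5, log 0.5)). It assumes target_stds = [0.1, 0.1, 0.2, 0.2], which, as far as I know, is the value mmdet's Faster R-CNN base config uses for the RoI head; normalize and denormalize are illustrative names.

```python
import math


def normalize(t, means, stds):
    """Apply (t - mean) / std element-wise, as at the end of encoding."""
    return [(v - m) / s for v, m, s in zip(t, means, stds)]


def denormalize(t_norm, means, stds):
    """Invert the normalization: t = t_norm * std + mean (first step of decoding)."""
    return [v * s + m for v, m, s in zip(t_norm, means, stds)]


# Raw t* for proposal [1, 1, 3, 3] and gt [2, 2, 3, 3].
t_star = [0.25, 0.25, math.log(0.5), math.log(0.5)]

means = [0.0, 0.0, 0.0, 0.0]
stds = [0.1, 0.1, 0.2, 0.2]  # assumed RoI-head setting; section 2.2 uses 1.0

t_norm = normalize(t_star, means, stds)
print([round(v, 4) for v in t_norm])  # [2.5, 2.5, -3.4657, -3.4657]

# Round trip: denormalizing recovers the raw target.
recovered = denormalize(t_norm, means, stds)
print(all(math.isclose(a, b) for a, b in zip(recovered, t_star)))  # True
```

Smaller stds scale the targets up, so the regression loss sees values of roughly unit magnitude instead of small fractions.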
2.3、Encoding process
I have commented this part of the code (the bbox2delta function in the same file). The idea is to first convert both the proposals and the gt bboxes from [xmin, ymin, xmax, ymax] to [cx, cy, w, h], then compute t*, and finally subtract the mean from t* and divide by the standard deviation.
```python
import torch


def bbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)):
    # The number of candidate boxes must equal the number of gt boxes
    assert proposals.size() == gt.size()  # [N, 4]
    proposals = proposals.float()
    gt = gt.float()
    # proposals: [xmin, ymin, xmax, ymax] --> [cx, cy, w, h]
    px = (proposals[..., 0] + proposals[..., 2]) * 0.5  # [N]
    py = (proposals[..., 1] + proposals[..., 3]) * 0.5
    pw = proposals[..., 2] - proposals[..., 0]
    ph = proposals[..., 3] - proposals[..., 1]
    # gt: [xmin, ymin, xmax, ymax] --> [cx, cy, w, h]
    gx = (gt[..., 0] + gt[..., 2]) * 0.5
    gy = (gt[..., 1] + gt[..., 3]) * 0.5
    gw = gt[..., 2] - gt[..., 0]
    gh = gt[..., 3] - gt[..., 1]
    # Compute t*
    dx = (gx - px) / pw
    dy = (gy - py) / ph
    dw = torch.log(gw / pw)
    dh = torch.log(gh / ph)
    deltas = torch.stack([dx, dy, dw, dh], dim=-1)  # [N] --> [N, 4]
    # Subtract the mean and divide by the standard deviation
    means = deltas.new_tensor(means).unsqueeze(0)  # [1, 4]
    stds = deltas.new_tensor(stds).unsqueeze(0)  # [1, 4]
    deltas = deltas.sub_(means).div_(stds)  # [N, 4]
    return deltas
```
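Dividing the center offsets by the proposal's width/height and taking the log of the size ratios makes t* independent of where a box pair sits in the image and of its absolute scale. The check below uses a hypothetical single-box encode_one helper implementing the same four formulas as the code above (without normalization); it shifts and rescales a box pair and confirms the targets do not change.

```python
import math


def encode_one(p, g):
    """Same formulas as bbox2delta above, for one box pair (no normalization)."""
    px, py = (p[0] + p[2]) / 2, (p[1] + p[3]) / 2
    pw, ph = p[2] - p[0], p[3] - p[1]
    gx, gy = (g[0] + g[2]) / 2, (g[1] + g[3]) / 2
    gw, gh = g[2] - g[0], g[3] - g[1]
    return ((gx - px) / pw, (gy - py) / ph, math.log(gw / pw), math.log(gh / ph))


def shift(box, dx, dy):
    """Translate a [xmin, ymin, xmax, ymax] box."""
    return [box[0] + dx, box[1] + dy, box[2] + dx, box[3] + dy]


def scale(box, s):
    """Scale a box (and implicitly the image) by a factor s."""
    return [v * s for v in box]


p, g = [4, 4, 6, 6], [2, 2, 5, 5]
base = encode_one(p, g)
shifted = encode_one(shift(p, 100, 50), shift(g, 100, 50))
scaled = encode_one(scale(p, 8), scale(g, 8))

print(all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(base, shifted)))  # True
print(all(math.isclose(a, b, abs_tol=1e-12) for a, b in zip(base, scaled)))   # True
```

Without this normalization, the same relative misalignment would produce much larger targets for big boxes than for small ones.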
2.4、Decoding process
Decoding usually happens at test time. The offsets t predicted by the network are applied to the base boxes, either the anchors (in one-stage detectors) or the RoIs (in the second stage of two-stage detectors), to obtain the predicted boxes. This is the reverse of encoding: first multiply by the standard deviation and add the mean to recover t, then apply t to the base boxes.
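Written out, decoding inverts the two encoding steps. With $(x_a, y_a, w_a, h_a)$ the center and size of a base box, $t'$ the raw network output, and $\mu$, $\sigma$ the target_means and target_stds, the de-normalization and the inverse transform are:

$$t_i = t'_i \, \sigma_i + \mu_i,\quad i \in \{x, y, w, h\}$$

$$\hat{x} = x_a + w_a t_x,\quad \hat{y} = y_a + h_a t_y,\quad \hat{w} = w_a e^{t_w},\quad \hat{h} = h_a e^{t_h}$$

These correspond line for line to gx, gy, gw and gh in the code below, where $t_w$ and $t_h$ are first clamped to $\pm|\log(\text{wh\_ratio\_clip})|$.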
```python
import numpy as np
import torch


def delta2bbox(rois, deltas, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.),
               max_shape=None, wh_ratio_clip=16 / 1000):
    # Mean and standard deviation: [4] --> [1, 4]
    means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4)
    stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4)
    denorm_deltas = deltas * stds + means  # [N, 4]
    # Split into dx, dy, dw, dh
    dx = denorm_deltas[:, 0::4]  # [N, 1]
    dy = denorm_deltas[:, 1::4]
    dw = denorm_deltas[:, 2::4]
    dh = denorm_deltas[:, 3::4]
    # Clamp the log size ratios so that exp() cannot blow up
    max_ratio = np.abs(np.log(wh_ratio_clip))
    dw = dw.clamp(min=-max_ratio, max=max_ratio)
    dh = dh.clamp(min=-max_ratio, max=max_ratio)
    # Convert rois/proposals to [cx, cy, w, h]: [N] --> [N, 1] --> expand to dx's shape
    px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx)
    py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy)
    # Compute the width/height of each roi
    pw = (rois[:, 2] - rois[:, 0]).unsqueeze(1).expand_as(dw)
    ph = (rois[:, 3] - rois[:, 1]).unsqueeze(1).expand_as(dh)
    # Invert the encoding
    gw = pw * dw.exp()
    gh = ph * dh.exp()
    gx = px + pw * dx
    gy = py + ph * dy
    # [cx, cy, w, h] --> [xmin, ymin, xmax, ymax]
    x1 = gx - gw * 0.5
    y1 = gy - gh * 0.5
    x2 = gx + gw * 0.5
    y2 = gy + gh * 0.5
    # Clip boxes that extend beyond the image; max_shape is (height, width)
    if max_shape is not None:
        x1 = x1.clamp(min=0, max=max_shape[1])
        y1 = y1.clamp(min=0, max=max_shape[0])
        x2 = x2.clamp(min=0, max=max_shape[1])
        y2 = y2.clamp(min=0, max=max_shape[0])
    # Stack back into [N, 4] boxes
    bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas)
    return bboxes
```
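To tie encoding and decoding together, here is a single-box sketch in plain Python. The helpers encode_one, decode_one, clamp and to_cxcywh are hypothetical names, and means are 0 and stds 1. It decodes the target of proposal [4, 4, 6, 6] / gt [2, 2, 5, 5] back into the gt box, then shows the effect of the log-ratio clamp: with wh_ratio_clip = 16 / 1000, max_ratio = |log 0.016| ≈ 4.135, so a decoded box can be at most 1000 / 16 = 62.5 times wider or taller than its base box.

```python
import math


def clamp(v, lo, hi):
    """Limit v to the closed interval [lo, hi]."""
    return min(max(v, lo), hi)


def to_cxcywh(box):
    """[xmin, ymin, xmax, ymax] -> (cx, cy, w, h)."""
    x1, y1, x2, y2 = box
    return (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1


def encode_one(p, g):
    """t* for one box pair, with means 0 and stds 1."""
    px, py, pw, ph = to_cxcywh(p)
    gx, gy, gw, gh = to_cxcywh(g)
    return [(gx - px) / pw, (gy - py) / ph, math.log(gw / pw), math.log(gh / ph)]


def decode_one(p, t, wh_ratio_clip=16 / 1000, max_shape=None):
    """Apply deltas t to box p, following the same steps as the decode code above."""
    max_ratio = abs(math.log(wh_ratio_clip))
    dx, dy = t[0], t[1]
    dw = clamp(t[2], -max_ratio, max_ratio)  # clamp the log width ratio
    dh = clamp(t[3], -max_ratio, max_ratio)  # clamp the log height ratio
    px, py, pw, ph = to_cxcywh(p)
    gx, gy = px + pw * dx, py + ph * dy
    gw, gh = pw * math.exp(dw), ph * math.exp(dh)
    box = [gx - gw / 2, gy - gh / 2, gx + gw / 2, gy + gh / 2]
    if max_shape is not None:  # max_shape is (height, width)
        h, w = max_shape
        box = [clamp(box[0], 0.0, float(w)), clamp(box[1], 0.0, float(h)),
               clamp(box[2], 0.0, float(w)), clamp(box[3], 0.0, float(h))]
    return box


p, g = [4, 4, 6, 6], [2, 2, 5, 5]
print([round(v, 6) for v in decode_one(p, encode_one(p, g))])  # [2.0, 2.0, 5.0, 5.0]

# An extreme width delta of 10 is clamped to ~4.135, giving width 2 * 62.5 = 125.
wide = decode_one(p, [0.0, 0.0, 10.0, 0.0])
print(round(wide[2] - wide[0], 6))  # 125.0

# With max_shape=(20, 30) the same box is clipped to x in [0, 30], y in [0, 20].
clipped = decode_one(p, [0.0, 0.0, 10.0, 0.0], max_shape=(20, 30))
print(clipped)  # [0.0, 4.0, 30.0, 6.0]
```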
Summary
Overall, this part of the source code is fairly simple. The next article covers the anchor sampler.