mmdet line-by-line code walkthrough: the positive and negative sample sampler
2022-06-30 08:41:00 【Wu lele~】
Preface
This is the fourth article in my line-by-line reading of mmdet. Code location: mmdet/core/bbox/samplers/random_sampler.py. Random sampling of positive and negative samples is mainly a training-time step: once MaxIoUAssigner has determined which gt each anchor is matched to, a subset of the resulting positive and negative samples is drawn for the loss computation. This article uses the RPN config as its example, because RPN relies on random sampling to counter the imbalance between positive and negative samples. RetinaNet instead uses focal loss to handle that imbalance, so it has no random sampling step.
Previous articles in this series:
AnchorGenerator Reading
MaxIOUAssigner Reading
DeltaXYWHBBoxCoder Reading
1、 Construct a simple sampler
from mmdet.core.bbox import build_sampler

# Build a sampler from a config dict
sampler = dict(
    type='RandomSampler',       # use the random sampler
    num=256,                    # total number of positive and negative samples
    pos_fraction=0.5,           # fraction of positive samples
    neg_pos_ub=-1,              # upper bound on negatives, as a multiple of positives; -1 means no bound
    add_gt_as_proposals=False)  # whether to add gt boxes as proposals; off here (the BaseSampler default is True)
sp = build_sampler(sampler)

# The details below do not matter; just know what they do:
# randomly generate an assign_result, bboxes and gt_bboxes.
from mmdet.core.bbox import AssignResult
from mmdet.core.bbox.demodata import ensure_rng, random_boxes
rng = ensure_rng(None)
assign_result = AssignResult.random(rng=rng)
bboxes = random_boxes(assign_result.num_preds, rng=rng)
gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
gt_labels = None

# Call sample to draw the positive and negative samples
sampling_result = sp.sample(assign_result, bboxes, gt_bboxes, gt_labels)
2、BaseSampler class
from abc import ABCMeta, abstractmethod


class BaseSampler(metaclass=ABCMeta):
    """Base class of samplers"""

    def __init__(self,
                 num,
                 pos_fraction,
                 neg_pos_ub=-1,
                 add_gt_as_proposals=True,
                 **kwargs):
        self.num = num
        self.pos_fraction = pos_fraction
        self.neg_pos_ub = neg_pos_ub
        self.add_gt_as_proposals = add_gt_as_proposals
        # By default the sampler samples both positives and negatives itself
        self.pos_sampler = self
        self.neg_sampler = self

    @abstractmethod
    def _sample_pos(self, assign_result, num_expected, **kwargs):
        """Sample positive samples"""
        pass

    @abstractmethod
    def _sample_neg(self, assign_result, num_expected, **kwargs):
        """Sample negative samples"""
        pass

    def sample(self,
               assign_result,
               bboxes,
               gt_bboxes,
               gt_labels=None,
               **kwargs):
        pass  # body omitted here; it is walked through in section 3.1
The base class is easy to follow. Its core is the sample method, which internally calls _sample_pos and _sample_neg. A subclass that inherits from it only needs to implement those two methods.
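To make the division of labour concrete, here is a minimal sketch of the same pattern on plain Python lists. ToyBaseSampler and FirstKSampler are hypothetical names invented for illustration, not mmdet classes, and the sketch leaves out neg_pos_ub and add_gt_as_proposals. It assumes the gt_inds convention used in this article: values greater than 0 are positives, 0 is a negative, and -1 is ignored.

```python
from abc import ABC, abstractmethod


class ToyBaseSampler(ABC):
    """Hypothetical stand-in for BaseSampler that works on plain lists."""

    def __init__(self, num, pos_fraction):
        self.num = num
        self.pos_fraction = pos_fraction

    @abstractmethod
    def _sample_pos(self, gt_inds, num_expected):
        """Return the indices of the sampled positives."""

    @abstractmethod
    def _sample_neg(self, gt_inds, num_expected):
        """Return the indices of the sampled negatives."""

    def sample(self, gt_inds):
        # Shared workflow: sample positives first, then fill up with negatives.
        num_expected_pos = int(self.num * self.pos_fraction)
        pos_inds = self._sample_pos(gt_inds, num_expected_pos)
        num_expected_neg = self.num - len(pos_inds)
        neg_inds = self._sample_neg(gt_inds, num_expected_neg)
        return pos_inds, neg_inds


class FirstKSampler(ToyBaseSampler):
    """Hypothetical subclass: deterministically keeps the first k candidates."""

    def _sample_pos(self, gt_inds, num_expected):
        pos = [i for i, g in enumerate(gt_inds) if g > 0]
        return pos[:num_expected]

    def _sample_neg(self, gt_inds, num_expected):
        neg = [i for i, g in enumerate(gt_inds) if g == 0]
        return neg[:num_expected]


# > 0: positive, 0: negative, -1: ignored
toy_gt_inds = [0, 2, 0, 1, -1, 0, 0, 3]
toy_pos, toy_neg = FirstKSampler(num=4, pos_fraction=0.5).sample(toy_gt_inds)
```

The subclass only decides which candidates to keep; the order of operations and the bookkeeping of how many samples are still needed stay in the base class.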
3、RandomSampler class
3.1 sample Method
We use the RandomSampler class to walk through the code. First, the sample method, which RandomSampler inherits from BaseSampler:
# Expected number of positives: 256 * 0.5 = 128
num_expected_pos = int(self.num * self.pos_fraction)
# _sample_pos returns the indices of the sampled positives
pos_inds = self.pos_sampler._sample_pos(
    assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
pos_inds = pos_inds.unique()        # keep only the unique indices
num_sampled_pos = pos_inds.numel()  # number of positives actually sampled
num_expected_neg = self.num - num_sampled_pos  # expected number of negatives
# neg_pos_ub is -1 here, so the if block is skipped; with only 2 sampled
# positives, for example, 256 - 2 = 254 negatives are requested
if self.neg_pos_ub >= 0:
    _pos = max(1, num_sampled_pos)
    # Upper bound on negatives: neg_pos_ub times the number of positives
    neg_upper_bound = int(self.neg_pos_ub * _pos)
    # The number of negatives may not exceed this bound
    if num_expected_neg > neg_upper_bound:
        num_expected_neg = neg_upper_bound
# _sample_neg returns the indices of the sampled negatives
neg_inds = self.neg_sampler._sample_neg(
    assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
neg_inds = neg_inds.unique()  # likewise, deduplicate the indices
# Wrap the result in a SamplingResult (gt_flags is built earlier in sample; that part is not shown in this excerpt)
sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
                                 assign_result, gt_flags)
The code is fairly easy to follow. First the expected number of positives is computed and the positives are sampled. The number of negatives is determined only after that. If an upper bound is configured (neg_pos_ub >= 0), the number of negatives may not exceed neg_pos_ub times the number of sampled positives; this product is neg_upper_bound. If no bound is configured, the number of negatives is the total number minus the number of sampled positives.
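The arithmetic can be checked in isolation. The sketch below is a plain-Python restatement of the counting logic, not mmdet code; expected_counts and its num_available_pos argument (how many positives the assigner produced) are hypothetical names introduced for illustration.

```python
def expected_counts(num, pos_fraction, num_available_pos, neg_pos_ub=-1):
    """Return (num_sampled_pos, num_expected_neg) for one sampling call."""
    num_expected_pos = int(num * pos_fraction)
    # At most num_expected_pos positives are kept; fewer if not enough exist.
    num_sampled_pos = min(num_available_pos, num_expected_pos)
    # Negatives fill whatever is left of the budget.
    num_expected_neg = num - num_sampled_pos
    if neg_pos_ub >= 0:
        # Cap negatives at neg_pos_ub times the positives (at least 1 positive assumed).
        neg_upper_bound = int(neg_pos_ub * max(1, num_sampled_pos))
        num_expected_neg = min(num_expected_neg, neg_upper_bound)
    return num_sampled_pos, num_expected_neg


no_cap = expected_counts(256, 0.5, num_available_pos=2)                # (2, 254)
capped = expected_counts(256, 0.5, num_available_pos=2, neg_pos_ub=3)  # (2, 6)
plenty = expected_counts(256, 0.5, num_available_pos=500)              # (128, 128)
```

Note that num_expected_neg is only the number of negatives requested. If fewer negatives are available, _sample_neg returns all of them and the final batch is smaller than num.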
3.2 _sample_pos Method
Next, the method that samples the positives:
def _sample_pos(self, assign_result, num_expected, **kwargs):
    """Randomly sample some positive samples."""
    # Indices of the positive samples, i.e. entries with gt_inds > 0
    pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
    if pos_inds.numel() != 0:
        pos_inds = pos_inds.squeeze(1)
    # If there are no more positives than the expected 128, return them all.
    if pos_inds.numel() <= num_expected:
        return pos_inds
    # Otherwise randomly pick exactly 128 of them from pos_inds.
    else:
        return self.random_choice(pos_inds, num_expected)
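The random_choice helper called above is not shown in this article. The sketch below is a plain-Python stand-in, not mmdet's implementation; it assumes the helper simply draws num distinct entries from the candidates without replacement. The rng argument is a hypothetical addition so that the example can be seeded.

```python
import random


def random_choice(gallery, num, rng=None):
    """Pick `num` distinct elements from `gallery` uniformly at random."""
    assert len(gallery) >= num, "cannot sample more elements than available"
    if rng is None:
        rng = random.Random()
    # random.Random.sample draws without replacement
    return rng.sample(list(gallery), num)


candidate_pos = [3, 8, 15, 16, 23, 42]
picked = random_choice(candidate_pos, 4, rng=random.Random(0))
```

Sampling without replacement matters here: every returned index is distinct, so the later unique() call in sample does not shrink the set of sampled indices.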
3.3 _sample_neg Method
This mirrors the method for sampling positives, so a quick look is enough.
def _sample_neg(self, assign_result, num_expected, **kwargs):
    """Randomly sample some negative samples."""
    # Indices of the negative samples, i.e. entries with gt_inds == 0
    neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
    if neg_inds.numel() != 0:
        neg_inds = neg_inds.squeeze(1)
    # If there are no more negatives than expected, return them all.
    if len(neg_inds) <= num_expected:
        return neg_inds
    else:
        return self.random_choice(neg_inds, num_expected)
Summary
The next article will start on the model module. Coming soon.