Building the dataloader in mmdetection
2022-06-10 16:54:00 【武乐乐~】
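Preface
This post describes how mmdetection builds its dataloader. The dataloader controls how the dataset is read iteratively, and it presupposes a dataset class. For the dataset class itself, see the companion post on building the dataset class in mmdetection.
1. Overall flow
In PyTorch, constructing a DataLoader instance takes the following key arguments (excerpted from the DataLoader source):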
Arguments:
dataset (Dataset): dataset from which to load the data.
batch_size (int, optional): how many samples per batch to load
(default: ``1``).
shuffle (bool, optional): set to ``True`` to have the data reshuffled
at every epoch (default: ``False``).
sampler (Sampler or Iterable, optional): defines the strategy to draw
samples from the dataset. Can be any ``Iterable`` with ``__len__``
implemented. If specified, :attr:`shuffle` must not be specified.
batch_sampler (Sampler or Iterable, optional): like :attr:`sampler`, but
returns a batch of indices at a time. Mutually exclusive with
:attr:`batch_size`, :attr:`shuffle`, :attr:`sampler`,
and :attr:`drop_last`.
num_workers (int, optional): how many subprocesses to use for data
loading. ``0`` means that the data will be loaded in the main process.
(default: ``0``)
collate_fn (callable, optional): merges a list of samples to form a
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
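A brief explanation of each parameter:
dataset: an instance of a class that inherits from Dataset.
batch_size: the number of samples per batch.
shuffle: when True, the data is reshuffled at the start of every epoch.
sampler: an iterable that yields dataset indices, either shuffled or in order.
batch_sampler: iterates over the indices produced by sampler and groups them into lists of batch_size indices, which are then used to fetch the samples from dataset.
collate_fn: merges the samples of one batch into a single list and adjusts their widths and heights so that they can be combined.
These definitions may be confusing at first. That is fine: just remember that dataset, sampler, batch_sampler and dataloader are all iterables, meaning that each of them can be used in a loop such as for … in dataset:.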
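As a rough illustration of what "iterable" means here, the following pure-Python sketch may help. `ToyDataset` and `ToySampler` are hypothetical names invented for this example; they are not PyTorch or mmdetection classes, and the sketch only mimics the roles of dataset and sampler.

```python
import random


class ToyDataset:
    """Hypothetical map-style dataset: supports len() and indexing."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


class ToySampler:
    """Hypothetical sampler: iterating over it yields dataset indices."""

    def __init__(self, dataset, shuffle=False, seed=0):
        self.dataset = dataset
        self.shuffle = shuffle
        self.rng = random.Random(seed)

    def __iter__(self):
        indices = list(range(len(self.dataset)))
        if self.shuffle:
            self.rng.shuffle(indices)  # reshuffled on every pass (epoch)
        return iter(indices)

    def __len__(self):
        return len(self.dataset)


dataset = ToyDataset(["a", "b", "c", "d"])
sampler = ToySampler(dataset, shuffle=True)

# Anything that implements __iter__ can be used in a for loop.
epoch = [dataset[idx] for idx in sampler]
print(epoch)
```

The sampler never touches the data itself; it only decides the order of the indices, which is why the same dataset can be combined with different sampling strategies.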
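With the main DataLoader parameters covered, let us look at how mmdetection implements build_dataloader. I will explain it in two parts:
(1) How a dataloader object is instantiated. As the figure shows, mmdetection mainly supplies four of the arguments. GroupSampler inherits from torch's Sampler class. shuffle is True in most cases. For the batch_sampler argument, mmdetection uses the BatchSampler class that PyTorch already implements.
(2) How one batch of data is read.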
2. Instantiating the dataloader
2.1. The GroupSampler class
For the dataset implementation, see the post on building the dataset class. Here is the GroupSampler source:
import numpy as np
from torch.utils.data import Sampler


class GroupSampler(Sampler):

    def __init__(self, dataset, samples_per_gpu=1):
        assert hasattr(dataset, 'flag')
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.flag = dataset.flag.astype(np.int64)
        # np.bincount() counts how often each flag value (0 and 1) occurs.
        self.group_sizes = np.bincount(self.flag)
        # Total length after padding every group to a multiple of samples_per_gpu.
        self.num_samples = 0
        for i, size in enumerate(self.group_sizes):
            self.num_samples += int(np.ceil(
                size / self.samples_per_gpu)) * self.samples_per_gpu

    def __iter__(self):
        indices = []
        # e.g. self.group_sizes = [942, 4096], where 942 is the number of
        # images whose aspect ratio is < 1.
        for i, size in enumerate(self.group_sizes):
            if size == 0:
                continue
            # Positions in self.flag equal to the current group i. self.flag
            # stores, in dataset order, the aspect-ratio flag of every image.
            indice = np.where(self.flag == i)[0]
            assert len(indice) == size
            np.random.shuffle(indice)  # shuffle the indices inside this group
            # Pad the group with randomly repeated indices.
            num_extra = int(np.ceil(size / self.samples_per_gpu)
                            ) * self.samples_per_gpu - len(indice)
            indice = np.concatenate(
                [indice, np.random.choice(indice, num_extra)])
            indices.append(indice)
        indices = np.concatenate(indices)  # one array, e.g. of length 5011
        # Split into chunks of samples_per_gpu and shuffle the chunk order;
        # with samples_per_gpu=1 this gives 5011 chunks of length 1.
        indices = [
            indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu]
            for i in np.random.permutation(
                range(len(indices) // self.samples_per_gpu))
        ]
        indices = np.concatenate(indices)
        indices = indices.astype(np.int64).tolist()
        assert len(indices) == self.num_samples
        return iter(indices)

    def __len__(self):
        return self.num_samples
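The key point is the __iter__ method, which makes the sampler iterable. The rough idea: suppose the dataset has 5000 images, so the indices run from 0 to 4999. Within each aspect-ratio group the indices are shuffled with np.random.shuffle, and each group is padded to a multiple of the batch size. With a batch size of 2 (and both group sizes even, so no padding is needed) this gives 2500 pairs, each drawn from a single group. The order of the pairs is shuffled, the pairs are concatenated back into one flat list indices, and iter(indices) then iterates over it one index at a time.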
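The grouping logic can be sketched on a tiny example. This is a simplified pure-Python stand-in written for illustration, not the mmdetection code: `group_sample` is a hypothetical helper, and it uses the standard `random` module instead of NumPy.

```python
import math
import random


def group_sample(flags, samples_per_gpu, rng):
    """Return a flat index list in which every consecutive chunk of
    `samples_per_gpu` indices shares the same flag."""
    indices = []
    for flag in sorted(set(flags)):
        group = [i for i, f in enumerate(flags) if f == flag]
        rng.shuffle(group)  # shuffle inside the group only
        # Pad the group with repeated indices up to a multiple of samples_per_gpu.
        target = math.ceil(len(group) / samples_per_gpu) * samples_per_gpu
        group += [rng.choice(group) for _ in range(target - len(group))]
        indices += group
    # Cut into chunks, shuffle the chunk order, then flatten again.
    chunks = [indices[i:i + samples_per_gpu]
              for i in range(0, len(indices), samples_per_gpu)]
    rng.shuffle(chunks)
    return [idx for chunk in chunks for idx in chunk]


flags = [0, 1, 1, 0, 1, 1, 1]  # 2 images in group 0, 5 images in group 1
order = group_sample(flags, samples_per_gpu=2, rng=random.Random(0))
print(len(order))  # 8: group 1 is padded from 5 to 6 indices
for i in range(0, len(order), 2):
    pair = order[i:i + 2]
    print(pair, [flags[j] for j in pair])  # both flags in a pair are equal
```

Because chunks are cut after the per-group padding, a chunk never straddles two groups, so every batch contains images with the same aspect-ratio flag.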
2.2. The BatchSampler class
For this part mmdetection uses the PyTorch implementation. Here is the source:
from typing import List

from torch.utils.data import Sampler


class BatchSampler(Sampler[List[int]]):

    def __init__(self, sampler: Sampler[int], batch_size: int, drop_last: bool) -> None:
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore
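The source shows that BatchSampler is initialized with a sampler. It also implements __iter__: each time a full batch of indices has been collected, the generator runs yield batch, that is, it returns the indices of one batch.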
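The effect of drop_last is easiest to see on a small input. The sketch below re-implements the same loop as a plain generator function; `iter_batches` is a hypothetical name, and `range(7)` stands in for any sampler that yields indices.

```python
def iter_batches(sampler, batch_size, drop_last):
    """Group indices from `sampler` into lists of `batch_size`."""
    batch = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch and not drop_last:
        yield batch  # trailing, smaller batch


sampler = range(7)  # stand-in for any iterable of indices
kept = list(iter_batches(sampler, batch_size=3, drop_last=False))
dropped = list(iter_batches(sampler, batch_size=3, drop_last=True))
print(kept)     # [[0, 1, 2], [3, 4, 5], [6]]
print(dropped)  # [[0, 1, 2], [3, 4, 5]]
```

The batch counts agree with the two branches of __len__: (7 + 3 - 1) // 3 = 3 batches when the remainder is kept, and 7 // 3 = 2 when it is dropped.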
3. Reading one batch of data
I would like to use a figure here, since the flow is hard to describe in words:
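In code form, the flow of reading one batch can be approximated as follows. This is a single-process sketch under simplifying assumptions, not the actual DataLoader implementation (which can also use worker processes): `load_batches` and `pad_collate` are hypothetical names, plain lists stand in for the dataset and for the output of the batch sampler, and zero-padding stands in for the width and height adjustment done by collate_fn.

```python
def load_batches(dataset, batch_sampler, collate_fn):
    """Yield one collated batch per list of indices."""
    for batch_indices in batch_sampler:                # 1. get a list of indices
        samples = [dataset[i] for i in batch_indices]  # 2. dataset.__getitem__
        yield collate_fn(samples)                      # 3. merge into one batch


def pad_collate(samples):
    """Hypothetical collate_fn: pad variable-length lists with zeros."""
    width = max(len(s) for s in samples)
    return [s + [0] * (width - len(s)) for s in samples]


dataset = [[1], [2, 2], [3, 3, 3], [4]]  # stand-in for a Dataset
batch_sampler = [[0, 2], [1, 3]]         # stand-in for BatchSampler output
batches = list(load_batches(dataset, batch_sampler, pad_collate))
print(batches)
# [[[1, 0, 0], [3, 3, 3]], [[2, 2], [4, 0]]]
```

Each batch is padded only to the size of its own largest sample, which is why the two batches above end up with different widths.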
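Summary
This post described how mmdetection implements the dataset and sampler needed to construct a DataLoader, and showed how the dataloader iterates over each batch internally. If you have questions, feel free to add me on WeChat (vx: wulele2541612007) and I will invite you to the discussion group.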