当前位置:网站首页>YOLOv5的Tricks | 【Trick8】图片采样策略——按数据集各类别权重采样
YOLOv5的Tricks | 【Trick8】图片采样策略——按数据集各类别权重采样
2022-06-11 05:35:00 【Clichong】
如有错误,恳请指出。
这篇文章用来记录一下yolov5在训练过程中提出的一个图片采样策略,简单来说,就是根据图片的权重来决定其采样顺序。
1. 图片采样策略想法
- 图片采样策略想法
在我们训练数据集的时候,一般是对数据集随机采样几张图像然后构建成一个mini-batch来批量输入网络处理。个人猜想,一个可能的想法就是,这种随机的图像采集会不会过于随意,因为有些图像的目标是过少的,那么这种图像可能对网络来说比较简单;而有些图像的目标是比较多的,这种是比较困难的。而对于开始训练的初期就使用这种简答图像对网络的训练可能带来不了多大的学习提升。
所以,如果可以对数据集中的每张图像做一个权重的划分,在训练模型的时候依照图像的权重大小依次按难到易的大概顺序来进行训练,让模型从一开始的困难的样本较快的学习到潜在特征,到之后通过简单的图像样本来对参数进行微调,说不定是一个好的方法。
(以上内容是个人的思考猜测,可能是有误的,欢迎探讨。)
- 图片采样策略思路
那么具体的实现思路就是,对整个数据集的图像目标做类别统计,然后类别的数目越大权重越小(成反比的关系)。然后再使用整个数据集的类别权重对每一张图像做类别权重的叠加。也就是根据每一张的图片的类别权重和来作为采样的权重,决定其采用的顺序。在代码的实现中是从大到小排序的。
2. 图片采样策略代码
- yolov5参考代码
大概的注释都写在代码里了:
def train():
...
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
...
for epoch in range(start_epoch, epochs):
model.train()
# Update image weights (optional, single-GPU only)
if opt.image_weights:
# 根据数据集的类别数目构建每个类别的权重(类别权重与类别数目成反比)
cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
# 对每张图片的目标计算其类别权重和作为图片的采集权重
iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
# 再更具每张图片的采集权重来构建图片的采样顺序
dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
...
def labels_to_class_weights(labels, nc=80):
# Get class weights (inverse frequency) from training labels
# labels是当前数据集的训练集的所有图像: {list: 682}
# 列表的每个对象格式是: (ndarray: (k, 5)) k表示当前图像的目表个数, 5是(class+xywh)
if labels[0] is None: # no labels loaded
return torch.Tensor()
# 把图像的标签列表直接转化为标签列表:{ndarray: (labels, 5)} labels表示全部图像的所有标签个数
labels = np.concatenate(labels, 0) # labels.shape = (866643, 5) for COCO
# 提取类别 labels[:, 0] 数据来为每一类做统计 .astype(np.int): 取整
classes = labels[:, 0].astype(np.int) # labels = [class xywh]
# weight: 统计每个类别出现的次数
weights = np.bincount(classes, minlength=nc) # occurrences per class
# Prepend gridpoint count (for uCE training)
# gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image
# weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start
# 将出现次数为0的类别权重全部取1
weights[weights == 0] = 1 # replace empty bins with 1
# 类别权重取类别出现次数的倒数, 也就是表示类别次数与权重成反比, 标签频率越高的类别权重越低, 因为越不罕见
weights = 1 / weights # number of targets per class
# 归一化操作: 求出每一类别的占比
weights /= weights.sum() # normalize
return torch.from_numpy(weights) # numpy -> tensor
def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
# Produces image weights based on class_weights and image contents
# out:{ndarray: (682,3)} 统计每一张图片中类类别的数目 这里我用的是mask数据集有3个类别 每个位置存储图像中对应类别目标出现的个数
class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
# class_weights:[n_class] -> [1, n_class]
# 每张图片的每个类别个数[label_nums, n_class] * 整个数据集每个类别的权重[1, n_class] = 每张图片的对应每个类别的权重[label_nums, n_class_weight]
# 然后每个类别的权重加在一起等于当前这张图片的权重
image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
# index = random.choices(range(n), weights=image_weights, k=1) # weight image sample
return image_weights
- 构造Dataset使用的地方
class LoadImagesAndLabels(Dataset):
def __init__(self, img_size=640, batch_size=16, image_weights=False, ...):
...
self.indices = range(n)
def __len__(self):
return len(self.img_files)
def __getitem__(self, index):
# 重点使用部分, 就是用权重采样策略替代了随机采样
# 随机采样: index返回的是随机值(shuffle = True),所以注意到其实在
# 权重采样: index是按顺序从0开始, 然后依次提取indices所指向的图像索引
index = self.indices[index] # linear, shuffled, or image_weights
img, labels = load_mosaic(self, index)
...
return torch.from_numpy(img), labels_out, self.img_files[index], shapes
# 因为可以注意到, 构建dataloader的时候yolov5代码中是没有使用shuffle=True这个随机采样的参数的
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache
with torch_distributed_zero_first(rank):
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
augment=augment, # augment images
hyp=hyp, # augmentation hyperparameters
rect=rect, # rectangular training
cache_images=cache,
single_cls=single_cls,
stride=int(stride),
pad=pad,
image_weights=image_weights,
prefix=prefix)
batch_size = min(batch_size, len(dataset))
# 这里对num_worker进行更改
# nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers]) # number of workers
nw = 0 # 可以适当提高这个参数0, 2, 4, 8, 16…
sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
# Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
# 没有使用 shuffle=True 这个参数
dataloader = loader(dataset,
batch_size=batch_size,
num_workers=nw,
sampler=sampler,
pin_memory=True,
collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
return dataloader, dataset
所以从代码中可以看见,如果不使用图像采样策略,这里也不会使用随机的选择策略,而且index从0开始提取,验证如下:
第一次断点调试:index从0开始,想法验证成功

参考资料:
边栏推荐
- Intercept file extension
- Click the icon is not sensitive how to adjust?
- Multi thread tutorial (XXIX) immutable design
- Section III: structural characteristics of cement concrete pavement
- getBackgroundAudioManager控制音乐播放(类名的动态绑定)
- Topological sorting
- Big meal count (time complexity) -- leetcode daily question
- wxParse解析iframe播放视频
- 27. Remove elements
- 自定义View之基础篇
猜你喜欢

Getbackgroundaudiomanager controls music playback (dynamic binding of class name)

In the future, how long will robots or AI have human creativity?

自定义View之基础篇

微信小程序,购买商品属性自动换行,固定div个数,超出部分自动换行

KD-Tree and LSH

Analyzing while experimenting - memory leakage caused by non static inner classes

Analyze while doing experiments -ndk article -jni uses registernatives for explicit method registration

White Gaussian noise (WGN)

Exploration of kangaroo cloud data stack on spark SQL optimization based on CBO

Restoration of binary tree -- number restoration
随机推荐
Click the icon is not sensitive how to adjust?
Manually splicing dynamic JSON strings
Recursively process data accumulation
JVM tuning 6: GC log analysis and constant pool explanation
20多种云协作功能,3分钟聊透企业的数据安全经
NDK learning notes (VII) system configuration, users and groups
Maximum number of points on the line ----- hash table solution
截取文件扩展名
SwiftUI: Navigation all know
Traversal of binary tree -- restoring binary tree by two different Traversals
WinForm (I) introduction to WinForm and use of basic controls
袋鼠云数栈基于CBO在Spark SQL优化上的探索
Vins fusion GPS fusion part
NDK learning notes (IX) POSIX sockect connection oriented communication
JS promise, async, await simple notes
Number of atoms (easy to understand)
code
Multi thread tutorial (30) meta sharing mode
Wechat applet uploads the data obtained from database 1 to database 2
27、移除元素