Data Augmentation: Mixup Principles and Code Walkthrough
2022-08-05 02:32:00 【00000cj】
Paper: mixup: Beyond Empirical Risk Minimization
The problem
- Empirical Risk Minimization (ERM) allows large neural networks to memorize the training data rather than learn to generalize from it; this persists even under strong regularization, and even in classification problems where labels are assigned at random.
- Neural networks trained with the ERM principle change their predictions drastically when evaluated on samples just outside the training distribution, which is known as the adversarial-example problem.
One remedy is Vicinal Risk Minimization (VRM): use data augmentation to construct additional samples in the vicinity of the original ones. However, describing the vicinity of each training sample (e.g., by flipping or scaling) requires human knowledge, so VRM has two shortcomings:
- The augmentation procedure is dataset-dependent and therefore requires expert knowledge.
- Data augmentation only models vicinal relations between samples of the same class.
Mixup
To address these problems, the paper proposes mixup, a data-agnostic augmentation method:
\(\widetilde{x}=\lambda x_{i}+(1-\lambda)x_{j}, \qquad \widetilde{y}=\lambda y_{i}+(1-\lambda)y_{j}\)
where \(x_{i},x_{j}\) are two images drawn at random from the training set and \(y_{i},y_{j}\) are their one-hot labels. Following the prior that linear interpolations of feature vectors should lead to linear interpolations of the associated targets, a new sample \((\widetilde{x},\widetilde{y})\) is constructed. \(\lambda\) is drawn from a \(Beta(\alpha, \alpha)\) distribution, where \(\alpha\) is a hyperparameter.
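For intuition, here is a minimal sketch of this construction (mixup_pair is an illustrative helper, not from the paper's code):

import numpy as np

def mixup_pair(x_i, x_j, y_i, y_j, alpha=1.0):
    """Mix one pair of samples: x_* are image arrays, y_* are one-hot label vectors."""
    lam = np.random.beta(alpha, alpha)       # lambda ~ Beta(alpha, alpha)
    x_tilde = lam * x_i + (1 - lam) * x_j    # linear interpolation of inputs
    y_tilde = lam * y_i + (1 - lam) * y_j    # same interpolation of targets
    return x_tilde, y_tilde

With alpha = 1.0, \(Beta(1, 1)\) is the uniform distribution on [0, 1]; smaller alpha pushes \(\lambda\) toward 0 or 1, i.e., weaker mixing.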
In addition, the authors report several conclusions drawn from experiments:
- Combining three or more samples brings no further accuracy gain while increasing the computational cost.
- Their implementation draws a batch from a single data loader and applies mixup within that batch after random shuffling; they found this strategy works just as well while reducing I/O overhead.
- Applying mixup only between samples of the same class does not improve accuracy.
Implementation
torchvision version
Here the roll method shifts the images in a batch backward by one position, and the rolled batch is mixed with the original one, so each image in the batch is mixed with its neighbor; see the PyTorch documentation of torch.Tensor.roll for details, and the small demo below.
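For example (a toy illustration, not part of the torchvision code):

import torch

batch = torch.arange(4)     # stand-in for a batch of 4 images
print(batch.roll(1, 0))     # tensor([3, 0, 1, 2])
# image 0 is paired with image 3, image 1 with image 0, and so on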
from typing import Tuple

import torch
from torch import Tensor


class RandomMixup(torch.nn.Module):
    """Randomly apply Mixup to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
        super().__init__()
        if num_classes < 1:
            raise ValueError(
                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
            )
        if alpha <= 0:
            raise ValueError("Alpha param can't be zero.")
        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )
        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
        """
        if batch.ndim != 4:
            raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
        if target.ndim != 1:
            raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
        if not batch.is_floating_point():
            raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
        if target.dtype != torch.int64:
            raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on mixup paper, page 3.
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"num_classes={self.num_classes}"
            f", p={self.p}"
            f", alpha={self.alpha}"
            f", inplace={self.inplace}"
            f")"
        )
        return s
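A quick usage sketch (shapes and class count are illustrative): the transform expects a float image batch and int64 class indices, and returns the mixed batch together with soft one-hot targets. torchvision's reference training scripts apply it batch-wise, e.g. inside the DataLoader's collate function.

mixup = RandomMixup(num_classes=10, p=0.5, alpha=1.0)
images = torch.rand(8, 3, 224, 224)             # (B, C, H, W), float
labels = torch.randint(0, 10, (8,))             # (B,), torch.int64
mixed_images, soft_targets = mixup(images, labels)
print(mixed_images.shape, soft_targets.shape)   # [8, 3, 224, 224] and [8, 10]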
mmclassification version
Here randperm is used to shuffle the images within the batch, and the shuffled batch is mixed with the original one; the way \(\lambda\) is obtained also differs from torchvision (np.random.beta instead of a Dirichlet sample).
import numpy as np
import torch

# BaseMixupLayer and one_hot_encoding are defined elsewhere in mmclassification
# (mmcls.models.utils.augment); they are assumed to be in scope here.


class BatchMixupLayer(BaseMixupLayer):
    r"""Mixup layer for a batch of data.

    Mixup is a method to reduce the memorization of corrupt labels and
    increase the robustness to adversarial examples. It's
    proposed in `mixup: Beyond Empirical Risk Minimization
    <https://arxiv.org/abs/1710.09412>`

    This method simply linearly mixes pairs of data and their labels.

    Args:
        alpha (float): Parameters for Beta distribution to generate the
            mixing ratio. It should be a positive number. More details
            are in the note.
        num_classes (int): The number of classes.
        prob (float): The probability to execute mixup. It should be in
            range [0, 1]. Defaults to 1.0.

    Note:
        The :math:`\alpha` (``alpha``) determines a random distribution
        :math:`Beta(\alpha, \alpha)`. For each batch of data, we sample
        a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random
        distribution.
    """

    def __init__(self, *args, **kwargs):
        super(BatchMixupLayer, self).__init__(*args, **kwargs)

    def mixup(self, img, gt_label):
        one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
        lam = np.random.beta(self.alpha, self.alpha)
        batch_size = img.size(0)
        index = torch.randperm(batch_size)

        mixed_img = lam * img + (1 - lam) * img[index, :]
        mixed_gt_label = lam * one_hot_gt_label + (
            1 - lam) * one_hot_gt_label[index, :]
        return mixed_img, mixed_gt_label

    def __call__(self, img, gt_label):
        return self.mixup(img, gt_label)
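A usage sketch, assuming mmclassification is installed and that the base class forwards alpha, num_classes and prob (values here are illustrative):

mixup_layer = BatchMixupLayer(alpha=1.0, num_classes=10, prob=1.0)
img = torch.rand(8, 3, 32, 32)           # a batch of images
gt_label = torch.randint(0, 10, (8,))    # integer class labels
mixed_img, mixed_gt = mixup_layer(img, gt_label)  # mixed_gt is soft, shape (8, 10)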
Mixup in object detection
In the paper Bag of Freebies for Training Object Detection Neural Networks, after two images are mixed only the gt boxes of both images are merged; the class labels themselves are not mixed. Instead, the paper states that "weighted loss indicates the overall loss is the summation of multiple objects with ratio 0 to 1 according to image blending ratio they belong to in the original training images", i.e., when computing the loss, each object's loss is weighted by the mixup coefficient of its source image, and the weighted losses are summed.
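A minimal sketch of this idea (detection_mixup and its fields are illustrative, not the paper's code; the paper blends images at their natural geometry by padding both to the larger size, which is omitted here by assuming equal shapes):

import numpy as np

def detection_mixup(img1, boxes1, labels1, img2, boxes2, labels2, alpha=1.5):
    """Blend pixels, keep every gt box, and record per-box loss weights."""
    lam = np.random.beta(alpha, alpha)          # the paper reports Beta(1.5, 1.5) works well
    mixed_img = lam * img1 + (1 - lam) * img2   # images assumed to share (H, W, C)
    boxes = np.concatenate([boxes1, boxes2], axis=0)
    labels = np.concatenate([labels1, labels2], axis=0)   # labels are NOT mixed
    # each object's loss is later scaled by the blending ratio of its source image
    weights = np.concatenate([np.full(len(boxes1), lam),
                              np.full(len(boxes2), 1.0 - lam)])
    return mixed_img, boxes, labels, weights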