【Mixup】《mixup: Beyond Empirical Risk Minimization》
2022-07-02 06:26:00 【bryant_meng】
ICLR-2018
1 Background and Motivation
Models keep getting stronger, yet they suffer from memorization (rote-learning the training set instead of generalizing) and sensitivity to adversarial examples (both symptoms of insufficient generalization).

Starting from the Vicinal Risk Minimization (VRM) principle, the authors propose the mixup data augmentation method (training on convex combinations of pairs of examples and their labels) to improve the generalization of existing SOTA models.
Q: What is VRM? Let's start from the Empirical Risk Minimization (ERM) principle.

Simply put, when doing a machine learning task we cannot know the true data distribution (e.g., for cat/dog classification, the cats and dogs of the whole world can never be enumerated), so we cannot minimize the true risk; we can only minimize the risk over part of the data, a sample from the real world, i.e., minimize the average error over the training data. This is minimizing the empirical risk.

The paper notes that "the convergence of ERM is guaranteed as long as the size of the learning machine does not increase with the number of training data."

When the amount of data is fixed but the model keeps growing, a model trained under the ERM principle runs into the following problems:

- it memorizes (instead of generalizing from) the training data, i.e., it starts rote-learning and overfits
- models trained with ERM change their predictions drastically when evaluated on examples just outside the training distribution, also known as adversarial examples (over-sensitive, insufficient generalization)

When the model is too large for the amount of data, the remedy is usually to draw more samples from the true distribution, which is exactly the data augmentation we commonly use against overfitting, formalized by the Vicinal Risk Minimization (VRM) principle.

Specifically:
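The risk formulations, reconstructed here from the paper ($\ell$ is the loss function, $P$ the true joint distribution, and $\nu$ a vicinity distribution that generates virtual pairs $(\tilde{x}, \tilde{y})$ around each training point):

$$R(f) = \int \ell(f(x), y)\, dP(x, y) \qquad \text{(expected risk)}$$

$$R_{\delta}(f) = \frac{1}{n} \sum_{i=1}^{n} \ell(f(x_i), y_i) \qquad \text{(empirical risk, ERM)}$$

$$R_{\nu}(f) = \frac{1}{m} \sum_{j=1}^{m} \ell(f(\tilde{x}_j), \tilde{y}_j), \quad (\tilde{x}_j, \tilde{y}_j) \sim P_{\nu} \qquad \text{(vicinal risk, VRM)}$$

mixup amounts to one particular choice of $\nu$: virtual examples are generated by linearly interpolating two random training pairs.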
Building on this, the authors propose a new data augmentation method: mixup.
2 Related Work
Omitted.
3 Advantages / Contributions
Proposes the mixup data augmentation method (with a well-motivated story) that improves the generalization of state-of-the-art neural network architectures; mixup also

- reduces the memorization of corrupt labels
- increases the robustness to adversarial examples
- stabilizes the training of generative adversarial networks (GANs)
- improves generalization on speech and tabular data (an area I have not explored)
4 Method
mixup constructs virtual training examples by convex combination:

$$\tilde{x} = \lambda x_i + (1 - \lambda) x_j, \qquad \tilde{y} = \lambda y_i + (1 - \lambda) y_j$$

where $\lambda \sim \mathrm{Beta}(\alpha, \alpha)$, $\lambda \in [0, 1]$, and $(x_i, y_i)$, $(x_j, y_j)$ are two examples drawn at random from the training data. The paper sketches this in a few lines of code.
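A close transcription of that snippet (not self-contained: loader1, loader2, net, loss, optimizer, and alpha are assumed defined elsewhere; labels are one-hot vectors; Variable is legacy PyTorch):

# y1, y2 should be one-hot vectors
for (x1, y1), (x2, y2) in zip(loader1, loader2):
    lam = numpy.random.beta(alpha, alpha)      # sample the mixing coefficient
    x = Variable(lam * x1 + (1. - lam) * x2)   # mix the inputs
    y = Variable(lam * y1 + (1. - lam) * y2)   # mix the one-hot labels
    optimizer.zero_grad()
    loss(net(x), y).backward()
    optimizer.step()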
The probability density function of the $\mathrm{Beta}(\alpha, \beta)$ distribution is

$$f(x; \alpha, \beta) = \frac{x^{\alpha - 1}(1 - x)^{\beta - 1}}{B(\alpha, \beta)}, \quad x \in [0, 1]$$

where $B(\alpha, \beta)$ is the beta function. In this paper $\alpha = \beta$.

Below we plot the density for a few of the $\alpha$ values that appear in the paper.
Partial plotting code:

from scipy.stats import beta
import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(0.001, 0.999, 200)  # stay inside (0, 1): the density diverges at the endpoints when alpha < 1
for a in [0.1, 0.5, 1.0, 2.0]:      # a few alpha = beta values
    plt.plot(x, beta.pdf(x, a, a), label=f"alpha={a}")
plt.legend()
plt.show()
As the plots show, the density is symmetric, and at $\alpha = 1$, $\mathrm{Beta}(\alpha, \alpha)$ becomes the uniform distribution.

As $\alpha \rightarrow 0$, the density piles up at the two endpoints, so the sampled $\lambda$ is almost always near 0 or 1; one of the two images then dominates completely, the mixing effectively disappears, and VRM degenerates back to ERM.

As $\alpha \rightarrow \infty$, the density concentrates around 0.5 (its peak there grows without bound), so $\lambda \rightarrow 0.5$ and every virtual example is an even blend of the two inputs.
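A quick sanity check of these two limits by sampling (a minimal sketch, NumPy only):

import numpy as np

rng = np.random.default_rng(0)
for a in (0.01, 1.0, 100.0):
    lam = rng.beta(a, a, size=100_000)
    near_edge = np.mean((lam < 0.05) | (lam > 0.95))  # one image dominates: ERM-like
    near_half = np.mean(np.abs(lam - 0.5) < 0.05)     # even blend of both images
    print(f"alpha={a}: near 0/1 = {near_edge:.0%}, near 0.5 = {near_half:.0%}")

The output shows the probability mass moving from the endpoints (alpha near 0, ERM-like) toward 0.5 (large alpha).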
The authors also found that:

- mixing more than 2 examples at a time brings no further gain, but increases the computational cost
- drawing the two images from the same mini-batch works just as well and saves I/O
- applying mixup only between examples with equal labels brings no noticeable gain (same-class mixup is not very effective?)
What is mixup doing?
mixup encourages the model $f$ to behave linearly in-between training examples, in particular between examples of different classes. Previous augmentations were essentially within-class; mixup introduces a prior on the relationship between different classes, even if it is only the simplest, linear one.

mixup leads to decision boundaries that transition linearly from class to class, providing a smoother estimate of uncertainty.
5 Experiments
5.1 Datasets and Metrics
Datasets:
- CIFAR-10 / CIFAR-100
- ImageNet
- UCI
- the Google commands dataset
Evaluation metrics:

- top-1 error
- top-5 error
5.2 Experiments
1) ImageNet Classification

For $\alpha \in [0.1, 0.4]$, mixup beats ERM; for large $\alpha$, mixup leads to underfitting (the larger $\alpha$, the more $\lambda$ tends toward 0.5, the more deeply the two images are blended, the further the virtual samples drift from the original data, and the harder they are to fit).

The benefit of mixup is more pronounced for larger models and longer training schedules.
2) CIFAR-10 and CIFAR-100

$\alpha$ is set to 1, where the Beta distribution equals the uniform distribution, i.e., $\lambda$ takes any value in $[0, 1]$ with equal probability density.
3) Speech data

mixup is worse than ERM on LeNet but better than ERM on VGG.
4) Memorization of corrupted labels

On the CIFAR-10 dataset.

The larger $\alpha$, the heavier the mixing, making memorization more difficult to achieve.

Evidently, without mixup the overfitting is severe and even the wrong labels get learned (training error on the corrupted labels is especially low).

mixup and dropout also complement each other and can be combined.
5) Robustness to adversarial examples

On ImageNet.

mixup implicitly penalizes the norm of the gradient of the loss (reducing undesirable oscillations of the model); as the results show, the gradient norms under mixup are indeed smaller.

Next, the performance against adversarial examples: mixup is clearly much more robust.
White-box vs. black-box attacks:

- white-box attack: the parameters of the attacked model are accessible to the attacker;
- black-box attack: the parameters of the attacked model are not accessible.
Fast gradient sign method (FGSM):

The following FGSM introduction and code come from the tutorial 对抗样本之FGSM实战.
# FGSM attack code
import torch

def fgsm_attack(image, epsilon, data_grad):
    # Take the sign of the gradient of the loss w.r.t. the input
    sign_data_grad = data_grad.sign()
    # Build the adversarial example by stepping epsilon along that sign
    perturbed_image = image + epsilon * sign_data_grad
    # Clamp to [0, 1] so the perturbed image stays a valid image
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    # Return the adversarial example
    return perturbed_image
import torch.nn.functional as F

def test(model, device, test_loader, epsilon):
    # Accuracy counter
    correct = 0
    # Collected adversarial examples
    adv_examples = []
    # Loop over the whole test set (assumes test_loader uses batch_size=1,
    # since .item() below requires single-element tensors)
    for data, target in test_loader:
        # Send the data and label to the device
        data, target = data.to(device), target.to(device)
        # Set requires_grad attribute of the tensor. Important for the attack
        data.requires_grad = True
        # Forward pass the data through the model
        output = model(data)
        init_pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
        # If the initial prediction is wrong, don't bother attacking, just move on
        if init_pred.item() != target.item():
            continue
        # Calculate the loss
        loss = F.nll_loss(output, target)
        # Zero all existing gradients
        model.zero_grad()
        # Calculate gradients of model in backward pass
        loss.backward()
        # Collect the gradient of the input
        data_grad = data.grad.data
        # Call FGSM attack
        perturbed_data = fgsm_attack(data, epsilon, data_grad)
        # Re-classify the perturbed image
        output = model(perturbed_data)
        ...
6) Tabular data

Uses UCI machine-learning datasets, which come in tabular form.
7) Stabilization of GAN

Comparing a plain GAN against GAN + mixup: the figures show the stabilizing effect of mixup on the training of GANs (orange samples) when modeling two toy datasets (blue samples), i.e., the orange samples should fit the blue ones.

GAN + mixup is visibly more stable.
8) Ablation studies

Explores different design variants of mixup.

For ERM a large weight decay works better, whereas for mixup a small weight decay is preferred.
9) Discussion

With increasingly large $\alpha$, the training error on real data increases, while the generalization gap decreases.

Increasing the model capacity would make the training error less sensitive to large $\alpha$.
6 Conclusion (own) / Future work

1) Future work:

- apply mixup to regression and structured prediction (e.g., segmentation)
- apply mixup to semi-supervised, unsupervised, and deep reinforcement learning
2) Source code:
https://github.com/facebookresearch/mixup-cifar10/blob/main/train.py
import numpy as np
import torch
from torch.autograd import Variable

def mixup_data(x, y, alpha=1.0, use_cuda=True):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size = x.size()[0]
    # Pair each example with a partner drawn from the same mini-batch
    if use_cuda:
        index = torch.randperm(batch_size).cuda()
    else:
        index = torch.randperm(batch_size)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # Mixing the two loss terms instead of the labels avoids building one-hot targets
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

# Inside the training loop of train.py:
for batch_idx, (inputs, targets) in enumerate(trainloader):
    if use_cuda:
        inputs, targets = inputs.cuda(), targets.cuda()
    inputs, targets_a, targets_b, lam = mixup_data(inputs, targets,
                                                   args.alpha, use_cuda)
    inputs, targets_a, targets_b = map(Variable, (inputs,
                                                  targets_a, targets_b))
    outputs = net(inputs)
    loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
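For cross-entropy, mixing the two loss terms the way mixup_criterion does is equivalent to computing the loss against the explicitly mixed one-hot targets, so no soft-label tensor ever needs to be built. A quick self-contained check with toy tensors (not from the repo):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)         # toy predictions for a batch of 4
y_a = torch.randint(0, 10, (4,))    # original labels
y_b = torch.randint(0, 10, (4,))    # labels of the shuffled partners
lam = 0.3

# Loss the way mixup_criterion computes it
loss_pair = lam * F.cross_entropy(logits, y_a) + (1 - lam) * F.cross_entropy(logits, y_b)

# Loss against the explicitly mixed one-hot targets
soft = lam * F.one_hot(y_a, 10).float() + (1 - lam) * F.one_hot(y_b, 10).float()
loss_soft = -(soft * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

print(torch.allclose(loss_pair, loss_soft))  # True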
3) 全网最全:盘点那些图像数据增广方式Mosiac,MixUp,CutMix等 (a roundup of image augmentation methods: Mosaic, MixUp, CutMix, etc.)

So when your dataset has many classes, this approach may help the model tell apart some hard examples; but MixUp is not suitable for every situation. At least with only a single class, I don't think it would be very effective.
5) 如何通俗理解 beta 分布? - 马同学的回答 - 知乎 (an intuitive explanation of the Beta distribution)

The Beta distribution is a conjugate prior for the binomial likelihood, i.e., a $\mathrm{Beta}(\alpha, \beta)$ prior updated with $m$ observed successes and $n$ observed failures yields a $\mathrm{Beta}(\alpha + m, \beta + n)$ posterior.
7) Structured prediction
8) 如何评价mixup: BEYOND EMPIRICAL RISK MINIMIZATION? - 张宏毅的回答 (a Zhihu answer)

The popular explanation for such synthetic training data is that it "enhances the model's invariance to certain transformations". Put the other way around, this is what machine learning often calls "reducing the variance of the model estimate", i.e., controlling model complexity.

9) 如何评价mixup: BEYOND EMPIRICAL RISK MINIMIZATION? - Zhanxing Zhu的回答 (another Zhihu answer)
11) 为什么mixup要使用beta分布? (Why does mixup use the Beta distribution?)

View 1: 为什么mixup要使用beta分布? - Sincere的回答

View 2: 为什么mixup要使用beta分布? - 邹瑜亮的回答

This answer is truly eye-opening!!!

12) On the puzzle of sampling from the Beta distribution, see: 如何评价mixup: BEYOND EMPIRICAL RISK MINIMIZATION? - 张宏毅的回答
13) 《Manifold Mixup: Better Representations by Interpolating Hidden States》 (ICML-2019)

简单涨点 | Flow-Mixup: classifying multi-label medical images with corrupted labels (reported to outperform Mixup and Manifold Mixup)