当前位置:网站首页>DeiT学习笔记
DeiT学习笔记
2022-07-07 05:23:00 【麻花地】
DeiT学习笔记
Training data-efficient image transformers & distillation through attention
Abstract
最近,纯粹基于注意力的神经网络被证明可以解决图像理解任务,例如图像分类。这些高性能的视觉transformers使用大型基础设施预先训练了数亿张图像,因此限制了其采用。
在这项工作中,我们仅通过在Imagenet上进行训练来生产具有竞争力的无卷积变压器。我们用不到3天的时间在一台电脑上训练他们。我们的参考视觉transformers(86M参数)在没有外部数据的情况下,在ImageNet上达到了83.1%(单作物)的最高精度。
更重要的是,我们引入了一种针对变形金刚的师生策略。它依赖于一个蒸馏令牌,确保学生通过注意力从老师那里学习。我们展示了这种基于令牌的蒸馏的兴趣,尤其是当使用convnet作为教师时。这使得我们报告的结果在Imagenet(我们获得高达85.2%的准确性)和转移到其他任务时都与ConvNet竞争。我们分享我们的代码和模型。
我们提供了我们方法的开源实现。可在以下网址获得:https://github.com/facebookresearch/deit.
1 Introduction
卷积神经网络已经成为图像理解任务的主要设计范式,这一点最初在图像分类任务中得到了证明。他们成功的一个因素是提供了一个大型训练集,即Imagenet[13,42]。受自然语言处理中基于注意力的模型的成功[14,52]的推动,人们对利用ConvNet内注意力机制的架构越来越感兴趣[2,34,61]。最近,一些研究人员提出了混合架构,将transformer成分移植到ConvNet中,以解决视觉任务[6,43]。
Dosovitskiy等人[15]介绍的视觉变换器(ViT)是一种直接继承自自然语言处理[52]的架构,但应用在以原始图像块为输入进行图像分类。他们的论文展示了使用大型私有标记图像数据集(JFT-300M[46],3亿张图像)训练的变压器的出色结果。论文得出结论,transformers“在数据量不足的情况下,无法很好地进行泛化”,这些模型的训练涉及大量计算资源。
在本文中,我们在一个8-GPU节点上用两到三天的时间(53小时的预训练,以及可选的20小时的微调)训练视觉变换器,这与具有相似数量的参数和效率的ConvNet相竞争。它使用Imagenet作为唯一的训练集。我们基于Dosovitskiy等人[15]的视觉变换架构和timm库[55]中的改进。使用我们的数据高效图像变换器(DeiT),我们报告了比以前结果更大的改进,见图1。我们的消融研究详细介绍了成功训练的超参数和关键因素,例如重复增强。
我们要解决另一个问题:如何提取这些模型?我们介绍了一种基于令牌的策略,具体针对transformers,由DeiT表示, 并表明它有利于取代通常的蒸馏。
总之,我们的工作做出了以下贡献:
1)我们表明,我们的神经网络不包含卷积层,在没有外部数据的情况下,可以在ImageNet上实现与最先进技术相比的竞争结果。它们在三天内在具有4个GPU的单个节点上学习1。我们的两个新模型DeiT-S和DeiT-Ti的参数更少,可以看作是ResNet-50和ResNet-18的对应物
2)我们介绍了一种新的基于蒸馏令牌的蒸馏过程,该过程与类令牌的作用相同**,只是它旨在再现教师估计的标签。这两个令牌通过注意力在transformer中进行交互。**这种特定于变压器的策略显著优于 vanilla distillation。
3)有趣的是,通过我们的蒸馏,图像transformers从convnet中学到的东西比从另一个具有类似性能的transformers中学到的要多。
4)我们在Imagenet上预先学习的模型在以下几个流行的公共基准上转移到不同的下游任务(如细粒度分类)时具有竞争力:CIFAR-10、CIFAR-100、Oxford-102 flowers、Stanford Cars和iNaturalist-18/19。
本文的组织结构如下:我们回顾了第2节中的相关工作,并在第3节中重点介绍了用于图像分类的变换器。我们在第4节介绍了变压器的蒸馏策略。实验第5节提供了对ConvNet和最近的transformer的分析和比较,以及对我们的transformer特定蒸馏的比较评估。第6节详细介绍了我们的培训计划。它包括对我们的数据高效训练选择的广泛消融,这使我们对DeiT中涉及的关键成分有了一些了解。我们在第7节中得出结论。
2 Related work
Image Classification
它是计算机视觉的核心,经常被用作衡量图像理解进展的基准。任何进展通常转化为其他相关任务(如检测或分割)的改进。自2012年的AlexNet[32]以来,ConvNet一直主导着这一基准,并已成为事实上的标准。ImageNet数据集[42]的最新发展反映了卷积神经网络架构和学习[32、44、48、50、51、57]的进展。
尽管多次尝试使用Transformer进行图像分类[7],但到目前为止,它们的性能一直不如ConvNet。然而,结合ConvNet和transformers的混合架构,包括自注意力机制,最近在图像分类[56]、检测[6,28]、视频处理[45,53]、无监督物体发现[35]和统一文本视觉任务[8,33,37]方面显示出了竞争性的结果。
最近,视觉变换器(ViT)[15]在不使用任何卷积的情况下,缩小了与ImageNet上最先进技术的差距。这种性能是显著的,因为用于图像分类的convnet方法受益于多年的调整和优化[22,55]。然而,根据这项研究[15],为了使学习到的变压器有效,需要对大量经过整理的数据进行预训练。在本文中,我们在不需要大型训练数据集的情况下实现了强大的性能,即仅使用Imagenet1k。
The Transformer architecture,
Vaswani等人[52]介绍的机器翻译目前是所有自然语言处理(NLP)任务的参考模型。ConvNet在图像分类方面的许多改进都受到transformers的启发。例如,Squeeze and Excitation[2]、Selective Kernel[34]和 Split-Attention Networks[61]利用了类似于transformers自注意力(SA)机制的机制。
Knowledge Distillation (KD)
Hinton等人[24]介绍了一种训练范式,**其中学生模型利用来自强大教师网络的“软”标签。这是教师softmax函数的输出向量,而不仅仅是分数的最大值,它给出了一个“硬”标签。**这样的训练可以提高学生模型的性能(或者,可以将其视为将教师模型压缩为较小模型(即学生模型)的一种形式)。一方面,教师的软标签将具有与标签平滑类似的效果[58]。另一方面,如Wei等人[54]所示,**教师的监督考虑了数据增加的影响,这有时会导致真实标签和图像之间的错位。**例如,让我们考虑一个带有“猫”标签的图像,该标签表示一个大景观和一个角落中的一只小猫。如果cat不再处于数据增强的裁剪中,它会隐式地更改图像的标签。KD可以在使用教师模型的学生模型中以软方式传递归纳偏差[1],其中归纳偏差将以硬方式合并。例如,通过使用卷积模型作为教师,可以在变压器模型中诱导由于卷积引起的偏差。在本文中,我们研究了convnet或transformer教师对transformer学生的蒸馏。我们介绍了一种针对变压器的新蒸馏程序,并展示了其优越性。
3 Vision transformer: overview
略
4 Distillation through attention
在本节中,我们假设我们可以使用强大的图像分类器作为教师模型。它可以是convnet,也可以是分类器的混合。我们解决了如何利用这个老师学习变压器的问题。正如我们将在第5节中通过比较精度和图像吞吐量之间的权衡所看到的,用变压器代替卷积神经网络是有益的。本节涵盖两个蒸馏轴:hard distillation versus soft distillation, and classical distillation versus the distillation token.(硬蒸馏与软蒸馏,以及经典蒸馏与蒸馏令牌)。
Soft distillation
[24,54]最小化了教师模型的softmax和学生模型的softmax之间的Kullback-Leibler发散。
假设 Z t Z_t Zt是教师模型的logits, Z s Z_s Zs是学生模型的logits。我们用τ表示蒸馏温度,λ表示平衡地面真值标签y上的Kullback-Leibler发散损失(KL)和交叉熵( L C E L_{CE} LCE)的系数,ψ表示softmax函数。蒸馏的目标是
Hard-label distillation.
我们介绍了一种蒸馏的变体,其中我们将老师的hard decision视为真正的标签。假设 y t = a r g m a x c Z t ( c ) y_t=argmax_cZ_t(c) yt=argmaxcZt(c)是教师的艰难决定,与此硬标签蒸馏相关的目标是:
对于给定的图像,与教师相关联的硬标签可能会根据特定的数据增加而变化。我们将看到,这种选择优于传统选择,同时无参数且概念更简单:教师预测 y t y_t yt与真正的标签y起着相同的作用。
还请注意,硬标签也可以通过标签平滑[47]转换为软标签,其中真实标签的概率为1− ε、 剩下的ε在剩下的类中共享。在我们所有使用真实标签的实验中,我们将该参数固定为ε=0.1。
Distillation token.
我们现在关注我们的提议,如图2所示。我们在初始嵌入(补丁和类令牌)中添加了一个新令牌,即蒸馏令牌。我们的蒸馏令牌与类令牌类似:它通过自注意力与其他嵌入交互,并在最后一层之后由网络输出。其目标由损失的蒸馏成分给出。蒸馏嵌入允许我们的模型从教师的输出中学习,就像在常规蒸馏中一样,同时与类嵌入保持互补。
有趣的是,我们观察到 class and distillation tokens (学习类和蒸馏令牌)收敛到不同的向量:这些令牌之间的平均余弦相似性等于0.06。随着在每一层计算类和蒸馏嵌入,它们通过网络逐渐变得更相似,一直到最后一层,它们的相似性很高(cos=0.93),但仍然低于1。这是意料之中的,因为它们旨在产生相似但不完全相同的目标。
我们验证了我们的蒸馏令牌向模型中添加了一些东西,而不是简单地添加与相同目标标签相关联的额外 class token:我们使用具有两个 class token的transformer来代替教师伪标签。即使我们随机且独立地初始化它们,在训练过程中,它们会收敛到相同的向量(cos=0.999),并且输出嵌入也是准相同的。这个额外的类令牌不会给分类性能带来任何影响。**相比之下,我们的蒸馏策略比原始蒸馏基线有显著改进,**这已通过我们在第5.2节中的实验进行了验证。
Fine-tuning with distillation.
我们在更高分辨率的微调阶段使用真实标签和教师预测。我们使用具有相同目标分辨率的教师,通常通过Touvron等人[50]的方法从低分辨率教师处获得。我们也只测试了真正的标签,但这降低了教师的利益,并导致较低的表现。
Classification with our approach: joint classifiers.
在测试时,transformer生成的类或蒸馏嵌入都与线性分类器关联,并且能够推断图像标签。然而,我们的参考方法是这两个独立头部的后期融合,为此我们添加了两个分类器的softmax输出以进行预测。我们在第5节中评估了这三个选项。
5 Experiments
本节介绍了一些分析实验和结果。我们首先讨论我们的蒸馏策略。然后比较分析了ConvNet和视觉变换器的效率和准确性。
5.1 Transformer models
如前所述,我们的架构设计与Dosovitskiy等人[15]提出的架构设计相同,没有卷积。我们唯一的区别是训练策略和升华令牌。此外,我们不使用MLP头进行预训练,只使用线性分类器。为了避免任何混淆,我们参考了ViT在先前工作中获得的结果,并引用了DeiT的前缀。如果没有指定,DeiT指的是我们的参考模型DeiT-B,它与ViT-B具有相同的架构。当我们以更大的分辨率微调DeiT时,我们在末尾附加产生的操作分辨率,例如DeiT-B↑最后,当使用我们的蒸馏过程时,我们用一个alembic符号将其标识为DeiT.
ViT-B(因此也包括DeiT-B)的参数固定为D=768、h=12和D=D/h=64。我们引入了两个较小的模型,即DeiT-S和DeiT-Ti,对于它们,我们改变头的数量,保持d不变。表1总结了我们在本文中考虑的模型。
5.2 Distillation
我们的蒸馏方法产生的视觉转换器在精度和吞吐量之间的权衡方面与最佳ConvNet不相上下,见表5。有趣的是,在准确性和吞吐量之间的权衡方面,蒸馏模型优于其老师。我们在ImageNet-1k上的最佳模型是85.2%,top-1精度优于在JFT-300M上以384分辨率预训练的最佳Vit-B模型(84.15%)。作为参考,通过在JFT-300M上以512分辨率训练的ViTH模型(600M参数)获得了**88.55%**的额外训练数据。此后,我们将提供一些分析和观察结果。
Convnets teachers.
我们观察到,使用convnet教师比使用transformer具有更好的性能。表2比较了不同教师架构的蒸馏结果。正如Abnar等人[1]所解释的那样,convnet是一个更好的老师,这可能是由于变压器通过蒸馏继承的电感偏置。在我们随后的所有蒸馏实验中,默认教师是RegNetY-16GF[40](84M参数),我们使用与DeiT相同的数据和相同的数据增强进行训练。该教师在ImageNet上达到82.9%的top-1准确率。
Comparison of distillation methods.
我们在表3中比较了不同蒸馏策略的性能。即使仅使用类别标记,硬蒸馏也明显优于软蒸馏:硬蒸馏在分辨率224×224时达到83.0%,而软蒸馏精度为81.8%。第4节中的蒸馏策略进一步提高了性能,表明两个令牌提供了对分类有用的补充信息:两个令牌上的分类器明显优于独立类和蒸馏分类器,独立类和蒸馏分类器本身已经优于蒸馏基线。
蒸馏令牌的结果略好于类令牌。它还与convnets预测更相关。这种性能上的差异可能是因为它更受益于ConvNet的归纳偏差。我们将在下一段中提供更多细节和分析。蒸馏令牌对于初始训练具有不可否认的优势。
Agreement with the teacher & inductive bias?
如上所述,教师的架构具有重要影响。它是否继承了现有的有利于训练的归纳偏见?虽然我们认为很难正式回答这个问题,但我们在表4中分析了convnet教师、我们的图像转换器DeiT(仅从标签学习)和我们的transformer DeiT(仅从标签学习)之间的决策一致性.
我们提取的模型与convnet的相关性比与从头学习的变压器的相关性更高。正如预期的那样,与蒸馏嵌入相关联的分类器更接近与类嵌入相关联的convnet,相反,与类嵌入相关联的分类器更类似于未经蒸馏学习的DeiT。不出所料,联合类+distil分类器提供了一个中间地带。
Number of epochs.
增加历元数可以显著提高蒸馏训练的性能,见图3。有300个时代,我们的蒸馏网络DeiT-B 已经优于DeiT-B。但对于后者,性能随着时间的延长而饱和,我们提取的网络显然受益于更长的训练时间。
6 Training details & ablation
7 Conclusion
在本文中,我们介绍了DeiT,这是一种图像变换器,由于改进了培训,尤其是新的蒸馏程序。在近十年的时间里,卷积神经网络在架构和优化方面都进行了优化,包括通过广泛的架构搜索,这种搜索容易过度拟合,例如EfficientNets[51]。对于DeiT,我们已经开始使用已有的ConvNet数据增强和正则化策略,除了我们的新蒸馏令牌之外,没有引入任何重要的架构。因此,更适合或学习变压器的数据增强研究可能会带来进一步的收益。
因此,考虑到我们的结果,在图像变换器已经与ConvNet相当的情况下,考虑到在给定精度下较低的内存占用,我们相信它们将迅速成为一种选择方法。
边栏推荐
- Quick analysis of Intranet penetration helps the foreign trade management industry cope with a variety of challenges
- Fast parsing intranet penetration escorts the document encryption industry
- 藏书馆App基于Rainbond实现云原生DevOps的实践
- 解析机器人科技发展观对社会研究论
- [step on the pit series] H5 cross domain problem of uniapp
- CDC (change data capture technology), a powerful tool for real-time database synchronization
- Vulnerability recurrence fastjson deserialization
- Openjudge noi 2.1 1752: chicken and rabbit in the same cage
- [quick start of Digital IC Verification] 12. Introduction to SystemVerilog testbench (svtb)
- Standard function let and generic extension function in kotlin
猜你喜欢
BiSeNet的特点
Excel import function of jeesite form page
海信电视开启开发者模式
王爽 《汇编语言》之寄存器
在Rainbond中实现数据库结构自动化升级
复杂网络建模(一)
[quick start of Digital IC Verification] 12. Introduction to SystemVerilog testbench (svtb)
[quick start of Digital IC Verification] 10. Verilog RTL design must know FIFO
调用 pytorch API完成线性回归
藏书馆App基于Rainbond实现云原生DevOps的实践
随机推荐
Lua programming learning notes
eBPF Cilium实战(1) - 基于团队的网络隔离
Hisense TV starts the developer mode
Merging binary trees by recursion
利用 Helm 在各类 Kubernetes 中安装 Rainbond
Leetcode 187 Repeated DNA sequence (2022.07.06)
Notes on PHP penetration test topics
JS copy picture to clipboard read clipboard
机器人教育在动手实践中的真理
LeetCode简单题之字符串中最大的 3 位相同数字
Game attack and defense world reverse
Make LIVELINK's initial pose consistent with that of the mobile capture actor
Vulnerability recurrence fastjson deserialization
海信电视开启开发者模式
调用 pytorch API完成线性回归
offer收割机:两个长字符串数字相加求和(经典面试算法题)
Fast parsing intranet penetration escorts the document encryption industry
Offer harvester: add and sum two long string numbers (classic interview algorithm question)
复杂网络建模(三)
Find the mode in the binary search tree (use medium order traversal as an ordered array)