当前位置:网站首页>11%的参数就能优于Swin,微软提出快速预训练蒸馏方法TinyViT
11%的参数就能优于Swin,微软提出快速预训练蒸馏方法TinyViT
2022-08-05 05:15:00 【FightingCV】
【写在前面】
视觉Transformer(VIT)由于其卓越的建模能力,近年来在计算机视觉领域引起了极大的关注。然而,大多数流行的VIT模型都受到大量参数的限制,限制了它们在资源有限的设备上的适用性。为了缓解这一问题,作者提出了TinyViT,这是一种新的微小而高效的小视觉Transformer家族,使用本文提出的快速蒸馏框架在大规模数据集上进行预训练。其核心思想是将知识从大型预训练的模型转移到小型模型,同时使小型模型能够从大量的预训练数据中获得红利。更具体地说,作者在预训练期间应用蒸馏来进行知识转移。大型教师模型的logits被稀疏并预存储在磁盘中,以节省显存成本和计算开销。小的学生Transformer是从具有计算和参数约束的预训练的大型模型中自动缩小的。综合实验证明了TinyViT的有效性。它在ImageNet-1k上仅用21M个参数就达到了84.8%的TOP-1准确率,与SwinB在ImageNet-21k上的预训练相当,而使用的参数少了4.2倍。此外,提高图像分辨率,TinyViT可以达到86.5%的准确率,略好于Swin-L,而只使用11%的参数。最后,作者展示了TinyViT在各种下游任务上的良好迁移能力。
1. 论文和代码地址
TinyViT: Fast Pretraining Distillation for Small Vision Transformers
论文地址:https://arxiv.org/abs/2207.10666[1]
代码地址:https://github.com/microsoft/Cream/tree/main/TinyViT[2]
2. Motivation
Transformer已经在计算机视觉领域掀起了一场风暴,并在研究和实践中越来越受欢迎。视觉Transformer(VIT)的最新趋势之一是继续增加模型大小,同时在标准基准上产生更好的性能。例如,V-MoE使用3.05亿张图像训练一个具有147亿个参数的超大模型,实现了最先进的图像分类性能。同时,Swin使用30亿个参数和70M个预训练图片,为了在下游检测和分段任务上获得不错的结果。如此大的模型尺寸和随之而来的高昂预训练成本使得这些模型不适合用于有限计算预算的应用,例如移动和物联网边缘设备。
与大规模放大模型不同,这项工作将注意力转向缩小视觉Transformer的尺寸,旨在生成一系列新的微型模型,并提高它们在下游任务中的迁移能力。具体来说,作者探讨了以下关键问题:如何有效地将现有大型Transformer的知识转移到小型Transformer上,以及如何释放大规模数据的力量来提高小型模型的代表性?在计算机视觉中,很早就认识到,在大数据集上预训练的大模型往往能获得更好的结果,而小模型随着数据的增长很容易变得饱和(或不足)。小模型有没有可能从海量数据中吸取知识,进一步展示自己的能力?
为了回答这个问题,作者引入了一种快速的知识提取方法来对小模型进行预训练,并证明了在大模型的指导下,小模型也可以获得海量预训练数据的红利。更具体地说,作者观察到小模型的直接预训练受到性能饱和的影响,特别是当数据规模增加时。但是,如果在预训练中进行蒸馏,使用一个强大的模型作为老师,大规模预训练数据的潜力可以被释放到小模型,如上图所示。同时,提取的小模型可以很好地转移到下游任务,因为它们已经学习了如何从大模型以及大规模预训练数据中进行泛化的大量知识。
使用蒸馏的预训练模型效率低且成本高,因为在每次迭代中,相当大比例的计算资源消耗在通过大型教师模型传递训练数据上,而不是训练目标小学生。此外,巨型教师可能会占用最多的GPU内存,显著减慢学生的训练速度(由于batch大小有限)。为了解决这个问题,作者提出了一种快速且可扩展的精馏策略。更具体地说,作者提出预先生成一个稀疏概率向量作为每幅输入图像的软标签,并将其和相应的数据增强信息一起存储到标签文件中。在训练过程中,作者重用存储的稀疏软标签和扩充来精确复制蒸馏过程,成功地省略了大型教师模型的向前计算和存储。这种策略有两个优点:1)速度快。在很大程度上节省了训练过程中生成教师软标签的内存开销和计算开销。因此,小模型的蒸馏速度可以大大加快,因为它能够使用大得多的batch。此外,由于每个时代的教师log是独立的,所以可以并行保存,而不是传统方法中逐epoch地保存。2)可拓展。它可以模拟任何类型的数据增强,并生成相应的软标签。只需要将大型教师模型前向传播一次,并对任意学生模型重用软标签。
作者验证了本文的快速预训练蒸馏框架不仅在现有的小型视觉Transformer(如DeiT-T和Swin-T)上,而且在新设计的微型架构上的有效性。具体地说,作者采用渐进式模型收缩方法来缩小大型模型并生成一族微小视觉Transformer(TinyViT)。通过在ImageNet-21k上的快速预训练蒸馏,具有21M参数的TinyViT在ImageNet-1k上达到了84.8%的TOP-1准确率,比预训练的Swin-B(88M参数下的85.2%)小4.2倍。在更高的分辨率下,本文的模型可以达到86.5%的TOP-1精度,在对齐设置下在ImageNet-1k上建立了新的最先进的性能。此外,TinyViT模型在下游任务上表现出了良好的迁移能力。例如,TinyViT-21M在目标检测基准上的AP为50.2,使用28M参数时比Swin-T高2.1个百分点。
总而言之,这项工作的主要贡献有两个:
1)为了充分利用大规模的预训练数据,释放小模型的能力,作者提出了一种快速预训练蒸馏框架。这是探索小模型预训练的第一个工作。
2)作者发布了一系列新的微小视觉Transformer模型,它们在计算和精度之间找到了很好的折衷。在预训练精馏的情况下,这种模型在下游任务上表现出了良好的转移能力。
3. 方法
3.1 Fast Pretraining Distillation
作者观察到,在海量数据上直接对小模型进行预训练并没有带来太大的收益,特别是当它们转移到下游任务时,如上图所示。为了解决这个问题,作者诉诸知识蒸馏来进一步揭示小模型预训练的力量。不同于以往侧重于微调阶段蒸馏的工作,作者将重点放在预训练蒸馏上,它不仅允许小模型向大尺度模型学习,而且提高了它们对下游任务的迁移能力。
使用蒸馏进行预训练效率低且成本高,因为在每次迭代中,相当一部分计算资源被消耗在通过大型教师模型传递训练数据上,而不是训练目标小学生。此外,教师可能会占用最多的GPU内存,从而减慢目标学生的训练速度(由于batch大小有限)。为了解决这一问题,作者提出了一种快速预训练蒸馏框架。如上图所示,作者预先存储了数据扩充和教师预测的信息。在训练过程中,作者重用存储的信息来精确复制蒸馏过程,成功地省略了大型教师模型的前向计算和内存占用。
在数学上,对于具有强数据扩充的输入图像x,例如RandAugment和CutMix,作者存储和教师预测,其中和是教师模型和扩充图像。值得注意的是,将同一图像多次通过相同的数据增强pipeline将产生不同的增强图像。因此,需要在每次迭代中为每个图像保存对,如上图所示。
在训练过程中,只需要从存储的文件中恢复对,并优化以下目标函数用于学生模型蒸馏:
其中S(·)和CE(·)分别是学生模型和交叉熵损失。注意,本文的框架是无标签的,即不需要ground truth标签,因为只使用教师模型生成的软标签来进行训练。因此,它可以利用大量没有标签的现成网络数据进行大规模的预训练。这种无标签策略在实践中是可行的,因为软标签足够准确,同时携带了大量用于分类的区别性信息,如类别关系。作者还观察到,与ground truth蒸馏将导致轻微的性能下降。原因可能是并非ImageNet-21k中的所有标签都是互斥的,包括相关的对,如“椅子”和“家具”,“马”和“动物”。因此,一成不变的ground truth标签不能准确地描述对象,在某些情况下,它在训练期间抑制子类或父类。此外,由于在训练过程中去掉了繁琐的教师,因此本文的蒸馏框架与没有蒸馏的训练模型一样快。
此外,由于两个关键组件:稀疏软标签(Sparse soft labels)和数据增强编码(Data augmentation encoding),本文的蒸馏框架是快速的。它们可以在极大地减少存储消耗的同时,提高训练期间的内存效率。
Sparse soft labels
设教师模型为预测输出C个logits。如果C很大,例如,对于ImageNet-21k,C=21,841,则通常要消耗大量存储空间来保存所有增强图像的全部密集logits。因此,作者只保存了logits中最重要的部分,即稀疏软标签。形式上,作者选择中的前K个值,即,,并将它们与它们的索引一起存储到标签文件中。在训练过程中,作者只使用存储的稀疏标签进行标签平滑的蒸馏,这被定义为:
其中是为学生模型蒸馏恢复的教师logits,即。当稀疏因子K较小时,即K≪C时,可以将Logit的存储量减少数量级。
Data augmentation encoding
数据增强涉及一组参数d,例如旋转度和裁剪坐标,以变换输入图像。由于每次迭代中每个图像的d都不同,因此直接保存它会降低内存效率。为了解决这个问题,作者用单个参数来编码d,其中是上图中的编码器。然后在训练过程中,在存储文件中加载后恢复,其中,被视为解码器。因此,可以准确地重建数据增强。实际上,解码器的常见选择是伪随机数生成器(即PCG)。它将单个参数作为输入,并生成一系列参数。对于编码器,只需通过的生成器并重用解码器来实现。输出表示教师模型。保存后,以便解码器在训练学生时再现d。因此,实现变得更高效。
3.2 Model Architectures
作者通过使用渐进模型收缩方法缩小大型模型种子,提出了一个新的微型视觉Transformer家族。具体地说,作者从一个大型模型开始,并定义了一组基本的收缩因子。然后在每一步中,通过调整收缩因子,围绕当前模型生成较小的候选模型。作者选择同时满足参数数量和吞吐量限制的模型。验证精度最高的模型将被用于下一步的进一步简化,直到达到目标。这是在由收缩因子跨越的模型空间中的一种受限局部搜索形式。
作者采用分层视觉Transformer作为基本架构,以方便密集预测的下游任务,如检测,需要多尺度特征。更具体地说,本文的基本模型由四个阶段组成,分辨率逐渐降低,类似于Swin和LeVit。块嵌入由两个卷积组成,核大小为3,步长为2,填充为1。作者在阶段1和下采样块中应用了轻量级和高效的MBConv,因为较早层的卷积由于其强烈的归纳偏差而能够有效地学习低级表示。后三级由Transformer块构成,并对注意加窗以减少计算量。
Contraction factors
收缩因子。作者考虑了以下因子来形成一个模型:
:分别为四个阶段的嵌入维度。减少它们会导致一个更薄的网络和更少的头在多头自注意力。
:分别为四个阶段的块数。通过减小这些值可以降低模型的深度。
:最后三个阶段的窗口大小。随着这些值变小,该模型具有更少的参数和更高的吞吐量。
:MBConv块的通道扩充率。通过减小这一因子,可以得到较小的模型尺寸。
:所有Transformer组的MLP膨胀率。如果缩小该值,MLP的隐藏尺寸将更小。
:多头注意中每个头的大小。当缩小时,头的数量将增加,从而带来更低的计算成本。
4.实验
上表展示了难例样本的影响。
上表展示了不同训练策略的消融实验结果。
上图展示了ImageNet-21k上输出预测的皮尔逊相关性。
上图展示了TinyViT-21M的IN-1K精度和存储成本随不同的存储对数K的变化情况。
上表展示了不同教师模型的影响。
上图展示了本文方法在IN-1k上和SOTA结果的对比图。
上表展示了本文方法在IN-1k上和SOTA结果的对比。
上表展示了TinyViT-21M有无与训练进行linear probe和few-shot 图像分类的结果。
上表展示了本文方法在COCO数据集上进行目标检测的结果。
5. 总结
作者利用本文提出的快速蒸馏框架TinyViT,提出了一种新的微小而高效的视觉Transformer家族,它在大规模数据集上进行了预训练。广泛的实验证明了TinyViT在ImageNet-1k上的有效性,以及它在各种下游基准上的卓越可迁移性。
已建立深度学习公众号——FightingCV,欢迎大家关注!!!
ICCV、CVPR、NeurIPS、ICML论文解析汇总:https://github.com/xmu-xiaoma666/FightingCV-Paper-Reading
面向小白的Attention、重参数、MLP、卷积核心代码学习:https://github.com/xmu-xiaoma666/External-Attention-pytorch
加入交流群,请添加小助手wx:FightngCV666
参考资料
[1]https://arxiv.org/abs/2207.10666: https://arxiv.org/abs/2207.10666
[2]https://github.com/microsoft/Cream/tree/main/TinyViT: https://github.com/microsoft/Cream/tree/main/TinyViT
边栏推荐
猜你喜欢
解决:Unknown column ‘id‘ in ‘where clause‘ 问题
[Go through 8] Fully Connected Neural Network Video Notes
【论文精读】R-CNN 之预测框回归(Bounding box regression)问题详述
将照片形式的纸质公章转化为电子公章(不需要下载ps)
flink部署操作-flink on yarn集群安装部署
vscode+pytorch use experience record (personal record + irregular update)
Flink EventTime和Watermarks案例分析
el-pagination左右箭头替换成文字上一页和下一页
大型Web网站高并发架构方案
第四讲 back propagation 反向传播
随机推荐
Matplotlib(二)—— 子图
Flink Oracle CDC写入到HDFS
【数据库和SQL学习笔记】7.SQL中的插入(INSERT)、删除(DELETE)、更新(UPDATE)
Mesos learning
Redux
flink中文文档-目录v1.4
基于Flink CDC实现实时数据采集(三)-Function接口实现
What are the characteristics of the interface of the physical layer?What does each contain?
Calling Matlab configuration in pycharm: No module named 'matlab.engine'; 'matlab' is not a package
Service
【零基础开发NFT智能合约】如何使用工具自动生成NFT智能合约带白名单可Mint无需写代码
有用番茄来监督自己的同道中人吗?加一下我的自习室,一起加油
学习总结week2_3
DOM及其应用
关于基于若依框架的路由跳转
flink部署操作-flink on yarn集群安装部署
BFC详解(Block Formmating Context)
Day1:用原生JS把你的设备变成一台架子鼓!
Flink EventTime和Watermarks案例分析
Matplotlib(一)—— 基础