当前位置:网站首页>目标检测中的知识蒸馏方法
目标检测中的知识蒸馏方法
2022-07-30 05:44:00 【Amber亮】
目标检测中的知识蒸馏方法
知识蒸馏(Knowledge Distillation KD)是一种教师-学生(Teacher-Student)训练结构,通常是已训练好的教师模型提供知识,学生模型通过蒸馏训练来获取教师的知识,能够以轻微的性能损失为代价将复杂教师模型的知识迁移到简单的学生模型中。在蒸馏的过程中,小模型学习到了大模型的泛化能力,保留了接近于大模型的性能。
一、什么是知识蒸馏
1、知识蒸馏的思想
知识蒸馏这一概念最早提出是为了解决模型压缩(轻量化)问题的。它是指从大模型(教师模型)中学习到有用的知识来训练小模型(学生模型),在保证性能差不多的情况下进行模型压缩。在蒸馏的过程中,小模型学习到了大模型的泛化能力,保留了接近于大模型的性能。
用压缩之前的模型作为老师,压缩之后的模型作为学生,通过一步一步地使用一个较大的已经训练好的网络去教导一个较小的网络确切地去做什么。“软标签”指的是大网络在每一层卷积后输出的feature map。然后,通过尝试复制大网络在每一层的输出(不仅仅是最终的损失),小网络被训练以学习大网络的准确行为。
最早知识蒸馏算法主要针对分类问题,分类问题的共同点是模型最后会有一个softmax层,其输出值对应了相应类别的概率值。一个很直白且高效的迁移泛化能力的方法就是:使用“软目标”,软目标是指输入数据通过教师模型所得到的softmax层的输出。软目标有着更高的熵,更小的梯度变化,因此学生模型相比教师模型可以使用更少的数据和更大的学习率(意味着收敛很快,这部分多出来的训练时间不是问题)。
2、KD的训练过程为什么更有效?
softmax层的输出,除了正例之外,负标签也带有大量的信息,比如某些负标签对应的概率远远大于其他负标签。而在传统的训练过程(hard target)中,所有负标签都被统一对待。也就是说,KD的训练方式使得每个样本给Net-S带来的信息量大于传统的训练方式。
hard targets:输入数据所对应的label 例:[0,0,1,0]
soft targets:输入数据通过大模型所得到的softmax层的输出 例:[0.01,0.02,0.98,0.17]
soft target有着更高的熵,更小的梯度变化,因此student相比teacher可以使用更少的数据和更大的学习率(意味着收敛很快,这部分多出来的训练时间不是问题)
3、知识蒸馏的过程
1)训练大模型:先用hard target,也就是正常的label训练大模型;
2)计算“软目标”:利用训练好的大模型来计算“软目标”。也就是大模型“软化后”再经过softmax的output;
3)训练小模型,在小模型的基础上再加一个额外的软目标的损失函数,通过参数调节两个损失函数的比重,蒸馏损失如下图所示:
4)将训练好的小模型进行预测。
二、知识蒸馏的相关文献
1、Distilling the Knowledge in a Neural Network
2、Improved Knowledge Distillation via Teacher Assistant: Bridging the Gap Between Student and Teacher,2019
通过助教提升知识提取能力:减小学生网络与老师网络之间的鸿沟
当学生和教师之间的差距较大时,学生网络性能会下降。引入了多步骤知识提炼,它使用一个中等规模的网络(教师助理)来弥合学生和教师之间的差距。
三、目标检测中的知识蒸馏的相关文献
1、Learning Efficient Object Detection Models with Knowledge Distillation (2017 NIPS)
第一篇是2017年的NIPS提出的,它是用知识蒸馏做目标检测的第一篇论文。核心思想比较简单的,主要由三个loss函数组成,分别对主干网络、分类head和回归head进行了蒸馏。
1)对主干网络,加了一个adaptation layers,让feature map的维度匹配
2)对于分类任务的输出,使用加权cross entropy loss来解决类别失衡严重问题
3)对于回归任务,将教师的回归预测作为上界。学生网络需要尽可能和真实标签接近,一旦学生网络的质量超越教师网络时,教师网络不再指导学生网络。
2、Distilling Object Detectors with Fine-grained Feature Imitation CVPR 2019
复现基于原文开源代码:https://github.com/twangnh/Distilling-Object-Detectors
第二篇是2019年的CVPR提出的,通过细粒度特征模仿蒸馏目标检测。这篇文章FGFI 通过实验发现,全特征模仿(就是均衡地学习特征图上的所有特征)会导致学生模型效果下降,推测是由于检测模型中特征数量多而且复杂,背景信息包含过多的噪声。因此,FGFI 考虑只对GT附近的anchor进行蒸馏,提出了方法 fine-grained feature limitation。首先定位这些知识密集的位置,并让学生模型在这些位置上模仿老师。
那么如何定位这些知识密集的位置呢?FGFI 通过设置掩码定位知识密集的位置,让模型只对GT附近的anchor进行蒸馏
3、Localization Distillation for Object Detection (目标检测中的定位蒸馏)2021
创新点:
1、提出了用于目标检测的定位蒸馏方法,适用于任何结构的检测器。基于bbox的通用分布,我们的定位蒸馏方法可以被公式化为标准的知识蒸馏,通过标准的知识蒸馏,教师模型捕获的定位模糊可以被很好地提取到学生模型中。
2、教师模型不一定是最优的,为了保持学习的有效性,我们提出了一种教师助理策略来填补教师模型和学生模型之间可能存在的差距。
3、对教师网络使用self-LD能够进一步增强预测框的精确度
4、Deep Structured Instance Graph for Distilling Object Detectors (ICCV 2021) 基于 detectron2
代码:https://github.com/dvlab-research/Dsig
5、Distilling Object Detectors via Decoupled Features(CVPR 2021)
代码:基于mmdetection
FGFI假设背景信息包含太多的噪声,所以就剔除了背景信息。但基于解耦特征的目标检测知识蒸馏 DeFeat通过实验证明:背景区域同样包含有用的信息,能够提升student模型指标。
知识蒸馏中只使用正样本虽然可以提升模型指标,但是不使用正样本,只使用背景区域也可以达到相同指标。
与图像分类不同,目标检测器具有复杂的多损失函数,其中语义信息所依赖的特征非常复杂。 DeFeat指出一种在现有方法中经常被忽略的路径:从不包括目标的区域中提取的特征信息对于提取学生检测器也是必不可少的。同时阐明了在蒸馏过程中,不同区域的特征应具有不同的重要性。并为此提出了一种新的基于解耦特征(DeFeat)的提取算法来学习更好的学生检测器。具体来说,将处理两个层次的解耦特征来将有用信息嵌入到学生中,即来自颈部的解耦特征和来自分类头部的解耦proposal。在不同主干的探测器上进行的大量实验表明,该方法能够超越现有的目标检测蒸馏方法。
Defeat在计算loss时,把pos和neg分开独立计算loss,利用超参缓解pos/neg之间的不均衡问题。
两个阶段计算KD Loss:
1、Decoupled feature 在Teacher/Student相同尺度特征图之间计算KD Loss
2、Decoupled RoI取Teachear输出的ROI,分别送入Student/Tearcher的ROIAligned模块,获得对应的ROI Feature,Studenet/Teacher各自预测类别和位置,计算K
7、Deep Structured Instance Graph for Distilling Object Detectors (ICCV 2021) 蒸馏目标检测的深度结构实例图
以往的工作中针对目标检测的蒸馏存在两个问题:
(1)特征的不平衡问题(前景特征和背景特征的不平衡)
(2)缺失目标之间的一种关系
之前的一些方法,像分类上,其实都是点对点的蒸馏,因为在分类上没有instance这个概念。但在目标检测这个任务上,更关注的是instance这个概念,就是一些物体,这些物体之间的关系其实做检测的时候是很关键的。
四、产业应用前景
知识蒸馏是模型压缩、计算加速的重要方法,尤其是将深度神经网络部署到移动端、IOT等边缘设备上时。此外,知识蒸馏也有模型增强的作用,可利用其它资源(如无标签或跨模态的数据)或知识蒸馏的优化策略(如相互学习和自学习)来提高一个复杂学生模型的性能。
知识蒸馏可应用于计算机视觉、自然语言处理、语音识别、推荐系统、安全隐私问题、多模态数据、金融证券等领域。目前,知识蒸馏的算法已经广泛应用到图像语义识别,目标检测等场景中,并且针对不同的研究场景,蒸馏方法都做了部分的定制化修改,同时,在行人检测,人脸识别,姿态检测,图像域迁移,视频检测等方面,知识蒸馏也是作为一种提升模型性能和精度的重要方法。
边栏推荐
猜你喜欢

【面经】米哈游数据开发面经

Servlet basic principles and application of common API methods

MySQL开窗函数

【MySQL功法】第5话 · SQL单表查询

Twenty-two, Kotlin advanced learning: simply learn RecyclerView to achieve list display;

mysql delete duplicate data in the table, (retain only one row)
phpok website vulnerability exploitation analysis

Jdbc & Mysql timeout分析

MySQL 数据类型及占用空间
Volatility memory forensics - command shows
随机推荐
Redis publish/subscribe
MySQL 索引优化及失效场景
Flink-stream/batch/OLAP integrated to get Flink engine
Flink PostgreSQL CDC configuration and FAQ
【SQL】first_value 应用场景 - 首单 or 复购
MySQL 5.7 installation tutorial (all steps, nanny tutorials)
史上超强最常用SQL语句大全
根据ip地址获取地理位置及坐标(离线方式)
Using PyQt5 to add an interface to YoloV5 (1)
Usage of exists in sql
利用自定义注解,统计方法执行时间
Bubble sort, selection sort, insertion sort, quick sort
Flink PostgreSQL CDC配置和常见问题
Flink CDC 实现Postgres到MySQL流式加工传输案例
十三、Kotlin进阶学习:内联函数let、also、with、run、apply的用法。
mysql delete duplicate data in the table, (retain only one row)
MySQL - 多表查询与案例详解
学生成绩管理系统(C语言版)
[MATLAB]图像处理——交通标志的识别
第一个WebAssembly程序