当前位置:网站首页>知识蒸馏2:目标检测中的知识蒸馏
知识蒸馏2:目标检测中的知识蒸馏
2022-07-30 17:21:00 【@BangBang】
1. 目标检测知识蒸馏介绍
1. Faster-RCNN 知识蒸馏
1.1 原理介绍
本博客介绍目标检测任务中知识蒸馏如何操作,首先介绍下相关的论文。
第一篇:Learning Efficient Object Detection Models with Knowledge Distillation
这篇文章是针对Faster -RCNN
进行知识蒸馏的方法,
- Faster RCNN是一个两阶段的目标检测算法,包括
region proposal network
(RPN) 和region classification network
(RCN) ,这两阶段都用到了classifier
和bounding-box regressor
,论文使用教师网络RPN
和RCN
的输出作为蒸馏的目标,并应用了中间层的输出作为提示Hint - 图中,上半部分是教师网络,网络中间层的输出叫做Hint知道学生网络中间层输出的学习,中间层一般是feature map。我们希望student也学习到中简层特征的输出。让二者也有某种近视,这种层通过L2 Loss。
- 对于检测部分Detection,包括classification和regression输出,对于这两部分的预测输出,通过教师网络的输出指导学生网络的学习,同时学生网络也可以从ground truth中的硬标签中进行学习。
1.2 损失函数
对于Faster RCNN中的RCN 和RPN部分都有分类损失和回归损失,然后通过RPN ,RCN,Hint
的Loss三者加权求和
- 对于 L c l s L_{cls} Lcls 结合了 与
ground truth
之间的hard softmax loss
以及与soft label
之间的蒸馏损失。 L r e g L_{reg} Lreg 也结合了ground truth
之间的smooth L1 loss
和教师网络的bounded L2 regression loss
的蒸馏损失。 L h i n t L_{hint} Lhint鼓励学生网络到教师网络的特征响应。 公式中的 r r r是平衡不同损失的超参。
1.2.1 分类损失:类别不平衡
- 其中 P t P_t Pt为教师网络通过升温
T
后的预测输出, P s P_s Ps为学生网络通过升温T
后的预测输出, x x x输入数据, y y y为标签数据,分类损失的构建如下:
在 L h a r d L_{hard} Lhard和 L s o f t L_{soft} Lsoft两者之间通过超参 u u u进行平衡. - 在 L s o f t L_{soft} Lsoft中作者考虑了类别不平衡,在
Faster RCNN
这种两阶段的模型有大量的bcakground
类别,而foreground
类别相对很少。因此论文作者对bcakground
施加了较大权重 w 0 = 1.5 w_{0}=1.5 w0=1.5,其他类别 w i = 1 w_{i}=1 wi=1
1.2.2 回归损失
对于回归损失作者利用了 L 1 L_1 L1和 L b L_b Lb损失,使用的
smooth L1 loss
, L 1 L_1 L1是学生网络和真实标签之间的损失; L b L_b Lb是学生网络,教师网络,预测标签之间的损失。对于 L b L_b Lb损失,主要是个L2 Loss, 但这里考虑了一个条件 ∣ ∣ R s − y ∣ ∣ 2 2 + m > ∣ ∣ R t − y ∣ ∣ 2 2 ||R_s-y||_2^2+m >||R_t-y||_2^2 ∣∣Rs−y∣∣22+m>∣∣Rt−y∣∣22,就是说学生网络与ground truth的误差要比教师网络的误差打一个
margin
。这样设计损失的目的是鼓励学生网络在学习回归的时候接近或者比教师网络要好,但是一旦达到教师网络的性能之后,就不在要求学生网络再进一步学习。
1.2.3 Hint Loss
Hint 是中间层特征的学习,
Hint Loss
是一个V和Z的L2 损失 ,V是学生网络中间层的特征输出,Z是教师网络中间层的特征输出。作者对L1 Loss和L2 Loss都进行了实验。注意:教师网络和学生网络输出的维度并不一定相同,因此通过一个自适应层去调节网络的输出,比如1x1卷积
2. YOLO 知识蒸馏
论文:Object detection at 200 Frames Per Second,这篇论文研究对于YOLO目标检测知识蒸馏的方法
2. 1 介绍
- yolo是单阶段目标检测算法,上面是
Tiny-YOLO
作为教师网络,下面是Yolo作为教师网络,原理和上面的
Faster RCNN知识蒸馏比较相似,但它没有中间层的学习
。
2. 2 损失函数
- 损失函数包括3部分: f o b j f_{obj} fobj 目标置信度得分, f o b j f_{obj} fobj分类以及 f b b f_{bb} fbb回归的损失函数。
- 对每部分的损失函数考虑了
distillation loss
,同时考虑了object scaled,即蒸馏损失和hard 损失之间的加权。详见:知识蒸馏1:基础原理讲解及项目实战介绍
边栏推荐
- bean的生命周期
- olap——入门ClickHouse
- 阿里巴巴中国站获得1688商品分类 API
- How Google earth engine realizes the arrangement and selection of our time list
- 【综合类型第 34 篇】喜讯!喜讯!!喜讯!!!,我在 CSDN 的第一个实体铭牌
- 【Cloud Store Announcement】Notice of Help Center Update on July 30
- SYSCALL SWAPGS
- KDD 2020 | 深入浅出优势特征蒸馏在淘宝推荐中的应用
- [NCTF2019]Fake XML cookbook-1|XXE漏洞|XXE信息介绍
- JMeter笔记4 | JMeter界面介绍
猜你喜欢
[Geek Challenge 2020] Roamphp1-Welcome
Tensorflow模型量化(Quantization)原理及其实现方法
向量检索基础方法总结
Error occurred while trying to proxy request The project suddenly can't get up
论文阅读之《DeepIlluminance: Contextual IlluminanceEstimation via Deep Neural Networks》
Error EPERM operation not permitted, mkdir 'Dsoftwarenodejsnode_cache_cacach Two solutions
链表Oj练习题 纯C语言
Excel导入和导出
UE5第一人称射击游戏蓝图教程
Mongoose module
随机推荐
LeetCode318:单词长度的最大乘积
Promise入门到精通(1.5w字详解)
crontab报错,但本地执行正常
Daily practice------Generate 13-digit bar, Ean-13 code rule: The thirteenth digit is the check code obtained by the calculation of the first twelve digits.
《痞子衡嵌入式半月刊》 第 59 期
C陷阱与缺陷 第6章 预处理器 6.2 宏并不是函数
论文阅读之《Underwater scene prior inspired deep underwater image and video Enhancement (UWCNN)》
优酷视频元素内容召回系统:多级多模态引擎探索
快使用flyway管理sql脚本吧~
C# 连接SQL Sever 数据库与数据查询实例 数据仓库
数据库的三大范式
万华化学精细化工创新产品大会
Analysis and Simulation of Short Circuit Fault in Power System Based on MATLAB
C陷阱与缺陷 第7章 可移植性缺陷 7.4 字符是有符号数还是无符号数
腾讯专家献上技术干货,带你一览腾讯广告召回系统的演进
Tensorflow模型量化(Quantization)原理及其实现方法
Shell implementation based on stm32
Dive deep on Netflix‘s recommender system(Netflix推荐系统是如何实现的?)
JMeter Notes 3 | JMeter Installation and Environment Instructions
多年以后「PageHelper」又深深的给我上了一课