当前位置:网站首页>知识蒸馏2:目标检测中的知识蒸馏
知识蒸馏2:目标检测中的知识蒸馏
2022-07-30 17:21:00 【@BangBang】
1. 目标检测知识蒸馏介绍
1. Faster-RCNN 知识蒸馏
1.1 原理介绍
本博客介绍目标检测任务中知识蒸馏如何操作,首先介绍下相关的论文。
第一篇:Learning Efficient Object Detection Models with Knowledge Distillation
这篇文章是针对Faster -RCNN进行知识蒸馏的方法,
- Faster RCNN是一个两阶段的目标检测算法,包括
region proposal network(RPN) 和region classification network(RCN) ,这两阶段都用到了classifier和bounding-box regressor,论文使用教师网络RPN和RCN的输出作为蒸馏的目标,并应用了中间层的输出作为提示Hint - 图中,上半部分是教师网络,网络中间层的输出叫做Hint知道学生网络中间层输出的学习,中间层一般是feature map。我们希望student也学习到中简层特征的输出。让二者也有某种近视,这种层通过L2 Loss。
- 对于检测部分Detection,包括classification和regression输出,对于这两部分的预测输出,通过教师网络的输出指导学生网络的学习,同时学生网络也可以从ground truth中的硬标签中进行学习。
1.2 损失函数
对于Faster RCNN中的RCN 和RPN部分都有分类损失和回归损失,然后通过RPN ,RCN,Hint的Loss三者加权求和
- 对于 L c l s L_{cls} Lcls 结合了 与
ground truth之间的hard softmax loss以及与soft label之间的蒸馏损失。 L r e g L_{reg} Lreg 也结合了ground truth之间的smooth L1 loss和教师网络的bounded L2 regression loss的蒸馏损失。 L h i n t L_{hint} Lhint鼓励学生网络到教师网络的特征响应。 公式中的 r r r是平衡不同损失的超参。
1.2.1 分类损失:类别不平衡

- 其中 P t P_t Pt为教师网络通过升温
T后的预测输出, P s P_s Ps为学生网络通过升温T后的预测输出, x x x输入数据, y y y为标签数据,分类损失的构建如下:
在 L h a r d L_{hard} Lhard和 L s o f t L_{soft} Lsoft两者之间通过超参 u u u进行平衡. - 在 L s o f t L_{soft} Lsoft中作者考虑了类别不平衡,在
Faster RCNN这种两阶段的模型有大量的bcakground类别,而foreground类别相对很少。因此论文作者对bcakground施加了较大权重 w 0 = 1.5 w_{0}=1.5 w0=1.5,其他类别 w i = 1 w_{i}=1 wi=1
1.2.2 回归损失
对于回归损失作者利用了 L 1 L_1 L1和 L b L_b Lb损失,使用的
smooth L1 loss, L 1 L_1 L1是学生网络和真实标签之间的损失; L b L_b Lb是学生网络,教师网络,预测标签之间的损失。
对于 L b L_b Lb损失,主要是个L2 Loss, 但这里考虑了一个条件 ∣ ∣ R s − y ∣ ∣ 2 2 + m > ∣ ∣ R t − y ∣ ∣ 2 2 ||R_s-y||_2^2+m >||R_t-y||_2^2 ∣∣Rs−y∣∣22+m>∣∣Rt−y∣∣22,就是说学生网络与ground truth的误差要比教师网络的误差打一个
margin。这样设计损失的目的是鼓励学生网络在学习回归的时候接近或者比教师网络要好,但是一旦达到教师网络的性能之后,就不在要求学生网络再进一步学习。
1.2.3 Hint Loss
Hint 是中间层特征的学习,
Hint Loss是一个V和Z的L2 损失 ,V是学生网络中间层的特征输出,Z是教师网络中间层的特征输出。作者对L1 Loss和L2 Loss都进行了实验。
注意:教师网络和学生网络输出的维度并不一定相同,因此通过一个自适应层去调节网络的输出,比如1x1卷积
2. YOLO 知识蒸馏
论文:Object detection at 200 Frames Per Second,这篇论文研究对于YOLO目标检测知识蒸馏的方法
2. 1 介绍

- yolo是单阶段目标检测算法,上面是
Tiny-YOLO作为教师网络,下面是Yolo作为教师网络,原理和上面的Faster RCNN知识蒸馏比较相似,但它没有中间层的学习。
2. 2 损失函数
- 损失函数包括3部分: f o b j f_{obj} fobj 目标置信度得分, f o b j f_{obj} fobj分类以及 f b b f_{bb} fbb回归的损失函数。

- 对每部分的损失函数考虑了
distillation loss,同时考虑了object scaled,即蒸馏损失和hard 损失之间的加权。详见:知识蒸馏1:基础原理讲解及项目实战介绍
边栏推荐
- How Google earth engine realizes the arrangement and selection of our time list
- Dive deep on Netflix‘s recommender system(Netflix推荐系统是如何实现的?)
- 多年以后「PageHelper」又深深的给我上了一课
- 你是这样的volatile,出乎意料
- 【AAAI2020】阿里DMR:融合Matching思想的深度排序模型
- LeetCode318: Maximum product of word lengths
- un7.30:Linux——如何在docker容器中显示MySQL的中文字符?
- Win11如何把d盘空间分给c盘?Win11d盘分盘出来给c盘的方法
- 论文阅读之《Quasi-Unsupervised Color Constancy 》
- Insert data into MySQL in C language
猜你喜欢
随机推荐
bert-base调试心得
Analysis and Simulation of Short Circuit Fault in Power System Based on MATLAB
中文字符集编码Unicode ,gb2312 , cp936 ,GBK,GB18030
Microsoft Office 2019 软件下载安装详细教程!
将 APACHE 日志解析到 SQL 数据库中
Dive deep on Netflix‘s recommender system(Netflix推荐系统是如何实现的?)
leetcode:1488. 避免洪水泛滥【二分 + 贪心】
fast shell porting
bean的生命周期
torch.optim.Adam() 函数用法
简易的命令行入门教程
论文阅读之《Quasi-Unsupervised Color Constancy 》
Promise入门到精通(1.5w字详解)
MySQL 8.0.29 解压版安装教程(亲测有效)
17.机器学习系统的设计
Mongoose module
有没有并发系统设计的经验,我该怎么说?
Oracle动态监听与静态监听详解
C陷阱与缺陷 第6章 预处理器
SYSCALL SWAPGS









