当前位置:网站首页>【论文介绍】R-Drop: Regularized Dropout for Neural Networks
【论文介绍】R-Drop: Regularized Dropout for Neural Networks
2022-07-02 06:25:00 【lwgkzl】
总述
本文的出发点在于:之前的dropout存在训练与测试时模型不一致的问题。
基于这个出发点,本文提出了R-Dropout的方式来解决这个问题。
实验证明,R-Dropout在多个数据集上均有效(都略有提升)
对Dropout的思考
首先我们要理解,为什么之前的dropout存在训练与测试时不一致的问题。在训练的时候,dropout会随机mask模型的一些节点,然后利用剩余的网络去拟合数据(防止过拟合)。在不同batch的数据训练的过程中,由于mask是随机变动的,因此不同的数据可能会经过不同的网络处理。因此整个训练过程可以看成是多个不同的网络的集成学习。而在测试的时候,由于不会随机mask掉节点,因此可以看做是完整的模型在测试集上做预测,因此这里就存在不一致性。
因为训练的时候,学习的是一个子模型,而测试的时候,使用的是完整的模型来做预测。
作者的思路非常清奇,可能直观的方法是直接想办法减小子模型与完整模型的差距,而本文的想法没有那么直观,而是说:如果所有子模型的输入都差不多的话,那么完整模型的输出与子模型的输出应该也相差不大。, 因此本文的优化目标是对于同一组输入,经过相同架构,但是以不同mask dropout后的模型,得到的输出应该要保持一致。
R-Dropout 介绍
其主要思路如上一节所示,这张图也可以直观的展示他的思路,如右图所示,对于同一个输入X,经过两个相同的Transformer encoder结构,不过这两个结构会以不同的mask进行dropout然后得到两个输出P1(y|x)以及P2(y|x),R-dropout要求这两个输出要尽可能一致。因此以这两个输出的KL散度作为模型的损失函数之一来进行优化。
实验与结论
标准的实验证明R-Dropout在18个数据集上都有一点提升(1%–2%)。
其中消融实验比较有意思,会验证几个想法。
想法1:
每一次进行R-dropout的时候,可以不只是两个模块产生的输出求KL散度,可以同时使用多个模块来对照。
结论: 作者实验了同时做三个不同dropout的模块,效果比两个模块要稍微好一丢丢,但是没有太多意义。
想法2:这两个模块,dropout的概率可以不一样,因此可以尝试以不同的概率进行mask。得到如下的矩阵。结论: 两个模块的dropout概率在0.3-0.5之间时,结果都相差不多。
代码
import torch.nn.functional as F
# define your task model, which outputs the classifier logits
model = TaskModel()
def compute_kl_loss(self, p, q, pad_mask=None):
p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')
q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')
# pad_mask is for seq-level tasks
if pad_mask is not None:
p_loss.masked_fill_(pad_mask, 0.)
q_loss.masked_fill_(pad_mask, 0.)
# You can choose whether to use function "sum" and "mean" depending on your task
p_loss = p_loss.sum()
q_loss = q_loss.sum()
loss = (p_loss + q_loss) / 2
return loss
# keep dropout and forward twice
logits = model(x)
logits2 = model(x)
# cross entropy loss for classifier
ce_loss = 0.5 * (cross_entropy_loss(logits, label) + cross_entropy_loss(logits2, label))
kl_loss = compute_kl_loss(logits, logits2)
# carefully choose hyper-parameters
loss = ce_loss + α * kl_loss
实际上在使用的时候,可以把X给copy一份,然后统一输入到model中,故无需把X两次经过model。即:
double_x = torch.stack([x,x],0).view(-1,x.size(-1)
tot_logits = model(double_x).view(2, x.size(0), -1)
logits = tot_logits[0]
logits2 = tot_logits[1]
# .....
边栏推荐
- 外币记账及重估总账余额表变化(下)
- Wechat applet Foundation
- 如何高效开发一款微信小程序
- 使用 Compose 实现可见 ScrollBar
- Pratique et réflexion sur l'entrepôt de données hors ligne et le développement Bi
- Oracle RMAN semi automatic recovery script restore phase
- SQLI-LABS通关(less18-less20)
- Sqli-labs customs clearance (less15-less17)
- 架构设计三原则
- User login function: simple but difficult
猜你喜欢
Illustration of etcd access in kubernetes
腾讯机试题
ssm人事管理系统
SSM学生成绩信息管理系统
Only the background of famous universities and factories can programmers have a way out? Netizen: two, big factory background is OK
IDEA2020中测试PySpark的运行出错
类加载器及双亲委派机制
User login function: simple but difficult
软件开发模式之敏捷开发(scrum)
Write a thread pool by hand, and take you to learn the implementation principle of ThreadPoolExecutor thread pool
随机推荐
使用Matlab实现:Jacobi、Gauss-Seidel迭代
Error in running test pyspark in idea2020
Changes in foreign currency bookkeeping and revaluation general ledger balance table (Part 2)
Oracle RMAN automatic recovery script (migration of production data to test)
Message queue fnd in Oracle EBS_ msg_ pub、fnd_ Application of message in pl/sql
DNS攻击详解
ORACLE APEX 21.2安装及一键部署
JSP intelligent community property management system
ORACLE EBS DATAGUARD 搭建
@Transational踩坑
Data warehouse model fact table model design
ORACLE 11G SYSAUX表空间满处理及move和shrink区别
php中树形结构转数组(拉平树结构,保留上下级排序)
ssm人事管理系统
Oracle 11g uses ords+pljson to implement JSON_ Table effect
图解Kubernetes中的etcd的访问
php中根据数字月份返回月份的英文缩写
Ingress Controller 0.47.0的Yaml文件
Yaml file of ingress controller 0.47.0
SQLI-LABS通关(less6-less14)