当前位置:网站首页>从正负样本解耦看对比学习为何需要large batch size训练Ddcoupled Contrastive learning (DCT)
从正负样本解耦看对比学习为何需要large batch size训练Ddcoupled Contrastive learning (DCT)
2022-08-04 16:15:00 【想搞钱的小陈】
转载自从正负样本解耦看对比学习为何需要large batch size训练 - 知乎
欢迎大家从原文观看,此篇只做转载加上自己的一些理解。
一、目的
解决对比学习中常常用到的大batch size问题,这种固有缺陷往往是由于InfoNCE损失函数带来的,本文试图修改该损失函数来降低batch size对对比学习过程中模型权重变更的影响。
-------------------------------
这篇文章主要是在对梯度方面入手,从梯度中有一个qb乘子,这个qb乘子很难训练,因为正负样本对很容易分开,所以改进loss,直接舍弃了这个qb乘子。
二、方法
传统对比学习如SimCLR采用InfoNCE作为损失函数,其表达形式为:
这里 k∈{1,2} ,表示正例样本的两种不同视角, i,j 表示batch内的样本。对z求导之后,可以得到梯度如下表达式:
其中, qB,i(1) 为负例-正例耦合项NPC(negtive-positive coulpling),其表达式为:
由于每个梯度的系数都包含 qB,i ,且后面乘上的是变量 z ,所以梯度很大程度上由系数 qB,i 决定,而在一般CL任务中,正正,正负样本较容易区分开,这种情况下 qB,i 是趋近于0的,所以导致梯度很小,参数更新不动。
如何解决?我们回到InfoNCE损失中,去掉分母中正正样本的相关项,即:
这样求导后就没了 qB,i 项,梯度不至于变得很小。
进一步,化简上式,得到损失:
对正例部分(即第一项)做个权重加成得到最终版modified InfoNCE loss:
三、效果
这里列一个最关键的结果:
ImageNet-1K上能达到67%,虽然比MAE差了一大截,但这个方法还是管用的。
四、思考
文章证明比较详细,我也没有细看,只是从一个更容易理解的思路去论证了一下。我们知道InfoNCE的表达形式为 lossInfoNCE=−logesiesi+∑i≠jesj ,其中 si 为正例对的得分, sj 为负例对得分。进一步该损失可以写成 log(1+∑i≠jesjesi) ,该式可近似改写成 log(1+(n−1)∗ϵ) ,这里 ϵ 为负例得分比正例得分的均值。由于正负例对区别很明显,神经网络往往很快就能将 si 与 sj 的值拉大,导致 ϵ 的值很小,如果n也很小的话,那么loss就会很小,导致反向传播的梯度很小,参数难以更新。这就是对比学习往往需要大batch size的原因。
那么问题来了,既然 ϵ 很小导致 log(1+(n−1)ϵ) 趋近于0,那么我们干脆把1去掉,变成 log((n−1)ϵ) ,这样就算 ϵ 小,n小,梯度也会很大,不会导致模型学不出来的问题。具体来说就是把InfoNCE中分母的正例相关项去掉,这就是这篇文章做的核心工作。
另外还有一个问题,原始的InfoNCE是为了保证所有的样本对的概率小于1,所以分母是正正+正负样本对,现在去掉正正样本对,如何保证整个分式的值小于1,使loss不至于变成负数?我们以余弦相似度为例,则 s∈[−1,1] ,现在考虑极端情况:正正样本相似度全为1,正负样本相似度全为-1,那么满足 e1nneg∗e−1<1 ,可以得到 nneg>e2≈8 ,只要负例大于8个就能满足,这也就不是问题了。
扩展
在做dense retrieval过程中,往往也会引入batch in的负例,结合这个改进,应该能使模型学习到更好的文本表征。
边栏推荐
猜你喜欢
【TA-霜狼_may-《百人计划》】美术2.7 Metallic 与 Speculer流程
"Research Report on the Development of Global Unicorn Enterprises in the First Half of 2022" released - DEMO WORLD World Innovation Summit ended successfully
Many merchants mall system function and dismantling 24 - ping the strength distribution of members
不需要服务器,教你仅用30行代码搞定实时健康码识别
花了半个月,终于把一线大厂高频面试题做成合集了
To ensure that the communication mechanism
软考 --- 软件工程(2)软件开发方法
What is an artifact library in a DevOps platform?What's the use?
招募 | 香港理工大学Georg Kranz 博士诚招博士
Visual Studio 2022创建项目没有CUDA模板的解决方法
随机推荐
MySQL 性能调优和优化技巧
Roslyn 在 msbuild 的 target 判断文件存在
软考 --- 软件工程(2)软件开发方法
吴恩达机器学习[11]-机器学习性能评估、机器学习诊断
花了半个月,终于把一线大厂高频面试题做成合集了
张乐:研发效能的黄金三角及需求与敏捷协作领域的实践
跟我学 UML 系统建模
HyperBDR云容灾深度解析一:云原生跨平台容灾,让数据流转更灵活
成功 解决 @keyup.enter=“search()“ 在el-input 组件中不生效的问题
Check which user permissions are assigned to each database, is there an interface for this?
招募 | 香港理工大学Georg Kranz 博士诚招博士
JVM Tuning-GC Fundamentals and Tuning Key Analysis
皕杰报表配置文件report_config.xml里都配置了什么?
MySQL学习之运算符
越来越火的图数据库到底能做什么?
“敏捷欺骗了开发人员”
Redis的主从复制和集群
邮差"头":{“retCode”:“999999”
inter-process communication
C#命令行解析工具