当前位置:网站首页>从正负样本解耦看对比学习为何需要large batch size训练Ddcoupled Contrastive learning (DCT)
从正负样本解耦看对比学习为何需要large batch size训练Ddcoupled Contrastive learning (DCT)
2022-08-04 16:15:00 【想搞钱的小陈】
转载自从正负样本解耦看对比学习为何需要large batch size训练 - 知乎
欢迎大家从原文观看,此篇只做转载加上自己的一些理解。
一、目的
解决对比学习中常常用到的大batch size问题,这种固有缺陷往往是由于InfoNCE损失函数带来的,本文试图修改该损失函数来降低batch size对对比学习过程中模型权重变更的影响。
-------------------------------
这篇文章主要是在对梯度方面入手,从梯度中有一个qb乘子,这个qb乘子很难训练,因为正负样本对很容易分开,所以改进loss,直接舍弃了这个qb乘子。
二、方法
传统对比学习如SimCLR采用InfoNCE作为损失函数,其表达形式为:
这里 k∈{1,2} ,表示正例样本的两种不同视角, i,j 表示batch内的样本。对z求导之后,可以得到梯度如下表达式:
其中, qB,i(1) 为负例-正例耦合项NPC(negtive-positive coulpling),其表达式为:
由于每个梯度的系数都包含 qB,i ,且后面乘上的是变量 z ,所以梯度很大程度上由系数 qB,i 决定,而在一般CL任务中,正正,正负样本较容易区分开,这种情况下 qB,i 是趋近于0的,所以导致梯度很小,参数更新不动。
如何解决?我们回到InfoNCE损失中,去掉分母中正正样本的相关项,即:
这样求导后就没了 qB,i 项,梯度不至于变得很小。
进一步,化简上式,得到损失:
对正例部分(即第一项)做个权重加成得到最终版modified InfoNCE loss:
三、效果
这里列一个最关键的结果:
ImageNet-1K上能达到67%,虽然比MAE差了一大截,但这个方法还是管用的。
四、思考
文章证明比较详细,我也没有细看,只是从一个更容易理解的思路去论证了一下。我们知道InfoNCE的表达形式为 lossInfoNCE=−logesiesi+∑i≠jesj ,其中 si 为正例对的得分, sj 为负例对得分。进一步该损失可以写成 log(1+∑i≠jesjesi) ,该式可近似改写成 log(1+(n−1)∗ϵ) ,这里 ϵ 为负例得分比正例得分的均值。由于正负例对区别很明显,神经网络往往很快就能将 si 与 sj 的值拉大,导致 ϵ 的值很小,如果n也很小的话,那么loss就会很小,导致反向传播的梯度很小,参数难以更新。这就是对比学习往往需要大batch size的原因。
那么问题来了,既然 ϵ 很小导致 log(1+(n−1)ϵ) 趋近于0,那么我们干脆把1去掉,变成 log((n−1)ϵ) ,这样就算 ϵ 小,n小,梯度也会很大,不会导致模型学不出来的问题。具体来说就是把InfoNCE中分母的正例相关项去掉,这就是这篇文章做的核心工作。
另外还有一个问题,原始的InfoNCE是为了保证所有的样本对的概率小于1,所以分母是正正+正负样本对,现在去掉正正样本对,如何保证整个分式的值小于1,使loss不至于变成负数?我们以余弦相似度为例,则 s∈[−1,1] ,现在考虑极端情况:正正样本相似度全为1,正负样本相似度全为-1,那么满足 e1nneg∗e−1<1 ,可以得到 nneg>e2≈8 ,只要负例大于8个就能满足,这也就不是问题了。
扩展
在做dense retrieval过程中,往往也会引入batch in的负例,结合这个改进,应该能使模型学习到更好的文本表征。
边栏推荐
- 攻防视角下,初创企业安全实战经验分享
- [TA-Frost Wolf_may-"Hundred Talents Project"] Art 2.7 Metallic and Speculer Process
- 07-输入输出系统
- HCIP笔记(7)
- GPS卫星同步时钟,NTP网络同步时钟,北斗时钟服务器(京准)
- JVM Tuning-GC Fundamentals and Tuning Key Analysis
- 面渣逆袭:MySQL六十六问,两万字+五十图详解
- Win10 无线网卡驱动感叹号,显示错误代码56
- B站回应HR称核心用户是Loser;微博回应宕机原因;Go 1.19 正式发布|极客头条
- shell中当basename和dirname无法满足你的需求时你一定要想到的命令
猜你喜欢
【Idea设置运行参数无效】可能是...
饿了么智能头盔专利获授权 进一步提升骑手安全保障
现代 ABAP 编程语言中的正则表达式
Typora收费?搭建VS Code MarkDown写作环境
In action: 10 ways to implement delayed tasks, with code!
How to monitor code cyclomatic complexity by refactoring indicators
Task Computing【动态规划_牛客】
Jenkins 如何玩转接口自动化测试?
贝叶斯优化核极限学习机KELM用于回归预测
Matlab计算随模拟时间变化的热导率
随机推荐
In-depth analysis of HyperBDR cloud disaster recovery 1: Cloud-native cross-platform disaster recovery, making data flow more flexible
【打卡】广告-信息流跨域ctr预估(待更新)
开源一夏 | 请你谈谈网站是如何进行访问的?【web领域面试题】
shell中当basename和dirname无法满足你的需求时你一定要想到的命令
NFT blind box mining system dapp development NFT chain game construction
jasmine.any(Function) 的一个使用场景
第三章 Scala运算符
皕杰报表配置文件report_config.xml里都配置了什么?
Pulsar消费者处理不当导致的消息积压问题
Request method ‘POST‘ not supported。 Failed to load resource: net::ERR_FAILED
测试零基础如何进入大厂?一场面试教会你(附面试题解析)
8年软件测试感悟,送给刚入测试行业的小伙伴
在Markdown文件中快速插入本地图片
软考 --- 软件工程(2)软件开发方法
Check which user permissions are assigned to each database, is there an interface for this?
大家有没有遇到过 cdc mysql to doris只能单task,看不到具体数据流。怎么回事?
农产品期货开户哪家好??
06-总线
【Idea设置运行参数无效】可能是...
Steady Development | Data and Insights on Mobile Game Players in Western Europe