当前位置:网站首页>【调参Tricks】WhiteningBERT: An Easy Unsupervised Sentence Embedding Approach
【调参Tricks】WhiteningBERT: An Easy Unsupervised Sentence Embedding Approach
2022-07-02 06:25:00 【lwgkzl】
总述
该文主要介绍了三种使用BERT做Sentence Embedding的小Trick,分别为:
- 应该使用所有token embedding的average作为句子表示,而非只使用[CLS]对应位置的表示。
- 在BERT中应该使用多层的句向量叠加,而非只使用最后一层。
- 在通过余弦相似度做句子相似度判定的时候,可以使用Whitening操作来统一sentence embedding的向量分布,从而可以获得更好的句子表示。
模型
文中介绍的前两点均不涉及到模型,只有第三点Whitening操作可以做简要介绍。
出发点: 以余弦相似度作为向量相似度衡量的指标的是建立在“标准正交基”的基础上的,基向量不同,向量中各个数值所代表的的意义也变不一样。然后经过BERT抽取之后的句向量所处的坐标系可能并非基于同一个“标准正交基”的坐标系。
解决方案: 将各个向量归一化到同一个标准正交基的坐标系中。一个猜测是,预训练语言模型生成的各个句向量应该在坐标系中的各个位置是相对均匀的,即表现出各项同性。基于这个猜测,我们可以将所有句向量做归一化,使之满足各向同性。一个可行的方案是将句向量的分布规约成正态分布,因为正态分布满足各项同性(数学定理)。
做法:
内容截图自苏神的博客: 链接
实验及结论
- 应该使用所有token embedding的average作为句子表示,而非只使用[CLS]对应位置的表示。
- 叠加BERT的1,2,12层这三层的向量效果表现最好。
- Whiten操作对于多数预训练语言模型而言均有效果。
代码
def whitening_torch_final(embeddings):
# For torch < 1.10
mu = torch.mean(embeddings, dim=0, keepdim=True)
cov = torch.mm((embeddings - mu).t(), (embeddings - mu))
# For torch >= 1.10
cov = torch.cov(embedding)
u, s, vt = torch.svd(cov)
W = torch.mm(u, torch.diag(1/torch.sqrt(s)))
embeddings = torch.mm(embeddings - mu, W)
return embeddings
在经过bert encoder之后的向量,送入whitening_torch_final函数中即可完成whitening的操作。
优化
根据苏神的博客,只保留SVD提取出来的前N个特征值可以提升进一步的效果。并且,由于只保留了前N个特征,故与PCA的原理类似,相当于对句向量做了一步降维的操作。
代码修改为:
def whitening_torch_final(embeddings, keep_dim=256):
# For torch >= 1.10
cov = torch.cov(embedding) # emb_dim * emb_dim
u, s, vt = torch.svd(cov)
# u : emb_dim * emb_dim, s: emb_dim
W = torch.mm(u, torch.diag(1/torch.sqrt(s))) # W: emb_dim * emb_dim
embeddings = torch.mm(embeddings - mu, W[:,:keep_dim]) # 截断
return embeddings # bs * keep_dim
边栏推荐
- SQLI-LABS通关(less1)
- 離線數倉和bi開發的實踐和思考
- IDEA2020中测试PySpark的运行出错
- SQLI-LABS通关(less6-less14)
- php中判断版本号是否连续
- The boss said: whoever wants to use double to define the amount of goods, just pack up and go
- Two table Association of pyspark in idea2020 (field names are the same)
- Take you to master the formatter of visual studio code
- Oracle段顾问、怎么处理行链接行迁移、降低高水位
- Tool grass welfare post
猜你喜欢
随机推荐
MySQL无order by的排序规则因素
php中的二维数组去重
Oracle 11.2.0.3 handles the problem of continuous growth of sysaux table space without downtime
架构设计三原则
ORACLE 11G SYSAUX表空间满处理及move和shrink区别
php中生成随机的6位邀请码
Oracle EBS数据库监控-Zabbix+zabbix-agent2+orabbix
CRP实施方法论
Error in running test pyspark in idea2020
Use of interrupt()
Tool grass welfare post
JSP智能小区物业管理系统
RMAN增量恢复示例(1)-不带未备份的归档日志
JS delete the last character of the string
Common prototype methods of JS array
Ceaspectuss shipping company shipping artificial intelligence products, anytime, anywhere container inspection and reporting to achieve cloud yard, shipping company intelligent digital container contr
Laravel8中的find_in_set、upsert的使用方法
使用 Compose 实现可见 ScrollBar
php中判断版本号是否连续
如何高效开发一款微信小程序