当前位置:网站首页>【调参Tricks】WhiteningBERT: An Easy Unsupervised Sentence Embedding Approach
【调参Tricks】WhiteningBERT: An Easy Unsupervised Sentence Embedding Approach
2022-07-02 06:25:00 【lwgkzl】
总述
该文主要介绍了三种使用BERT做Sentence Embedding的小Trick,分别为:
- 应该使用所有token embedding的average作为句子表示,而非只使用[CLS]对应位置的表示。
- 在BERT中应该使用多层的句向量叠加,而非只使用最后一层。
- 在通过余弦相似度做句子相似度判定的时候,可以使用Whitening操作来统一sentence embedding的向量分布,从而可以获得更好的句子表示。
模型
文中介绍的前两点均不涉及到模型,只有第三点Whitening操作可以做简要介绍。
出发点: 以余弦相似度作为向量相似度衡量的指标的是建立在“标准正交基”的基础上的,基向量不同,向量中各个数值所代表的的意义也变不一样。然后经过BERT抽取之后的句向量所处的坐标系可能并非基于同一个“标准正交基”的坐标系。
解决方案: 将各个向量归一化到同一个标准正交基的坐标系中。一个猜测是,预训练语言模型生成的各个句向量应该在坐标系中的各个位置是相对均匀的,即表现出各项同性。基于这个猜测,我们可以将所有句向量做归一化,使之满足各向同性。一个可行的方案是将句向量的分布规约成正态分布,因为正态分布满足各项同性(数学定理)。
做法:
内容截图自苏神的博客: 链接
实验及结论
- 应该使用所有token embedding的average作为句子表示,而非只使用[CLS]对应位置的表示。
- 叠加BERT的1,2,12层这三层的向量效果表现最好。


- Whiten操作对于多数预训练语言模型而言均有效果。
代码
def whitening_torch_final(embeddings):
# For torch < 1.10
mu = torch.mean(embeddings, dim=0, keepdim=True)
cov = torch.mm((embeddings - mu).t(), (embeddings - mu))
# For torch >= 1.10
cov = torch.cov(embedding)
u, s, vt = torch.svd(cov)
W = torch.mm(u, torch.diag(1/torch.sqrt(s)))
embeddings = torch.mm(embeddings - mu, W)
return embeddings
在经过bert encoder之后的向量,送入whitening_torch_final函数中即可完成whitening的操作。
优化
根据苏神的博客,只保留SVD提取出来的前N个特征值可以提升进一步的效果。并且,由于只保留了前N个特征,故与PCA的原理类似,相当于对句向量做了一步降维的操作。
代码修改为:
def whitening_torch_final(embeddings, keep_dim=256):
# For torch >= 1.10
cov = torch.cov(embedding) # emb_dim * emb_dim
u, s, vt = torch.svd(cov)
# u : emb_dim * emb_dim, s: emb_dim
W = torch.mm(u, torch.diag(1/torch.sqrt(s))) # W: emb_dim * emb_dim
embeddings = torch.mm(embeddings - mu, W[:,:keep_dim]) # 截断
return embeddings # bs * keep_dim
边栏推荐
- ORACLE EBS接口开发-json格式数据快捷生成
- Illustration of etcd access in kubernetes
- Oracle general ledger balance table GL for foreign currency bookkeeping_ Balance change (Part 1)
- Oracle segment advisor, how to deal with row link row migration, reduce high water level
- CRP implementation methodology
- 如何高效开发一款微信小程序
- MySQL中的正则表达式
- view的绘制机制(三)
- php中通过集合collect的方法来实现把某个值插入到数组中指定的位置
- MapReduce与YARN原理解析
猜你喜欢

Illustration of etcd access in kubernetes

Sqli-labs customs clearance (less2-less5)

ORACLE EBS中消息队列fnd_msg_pub、fnd_message在PL/SQL中的应用

MapReduce concepts and cases (Shang Silicon Valley Learning Notes)

ORACLE EBS 和 APEX 集成登录及原理分析

Sqli labs customs clearance summary-page2

view的绘制机制(一)

SSM实验室设备管理

How to call WebService in PHP development environment?

中年人的认知科普
随机推荐
JSP intelligent community property management system
sparksql数据倾斜那些事儿
php中通过集合collect的方法来实现把某个值插入到数组中指定的位置
Yaml file of ingress controller 0.47.0
Tool grass welfare post
SSM实验室设备管理
外币记账及重估总账余额表变化(下)
php中生成随机的6位邀请码
Thinkphp5中一个字段对应多个模糊查询
sqli-labs通關匯總-page2
MySQL组合索引加不加ID
Principle analysis of spark
MapReduce与YARN原理解析
JS countdown case
SQL注入闭合判断
2021-07-05c /cad secondary development create arc (4)
PM2 simple use and daemon
Sqli labs customs clearance summary-page2
pySpark构建临时表报错
sqli-labs通关汇总-page1