当前位置:网站首页>【AI4Code】《GraphCodeBERT: Pre-Training Code Representations With DataFlow》 ICLR 2021
【AI4Code】《GraphCodeBERT: Pre-Training Code Representations With DataFlow》 ICLR 2021
2022-07-25 11:11:00 【chad_lee】
《GraphCodeBERT: Pre-Training Code Representations With DataFlow》 ICLR 2021

近年来,应用于编程语言的预训练模型得到飞速发展,相关任务比如code search, code completion, code summarization 也得到提升。但是,现有的预训练模型是将code snippet(代码片段)视为一个token序列,忽视了代码的结构。
本文的GraphCodeBERT,没有用句法级别的AST,而是用的代码的数据流(data flow )来表示源代码信息。代码的数据流是一个graph,节点表示一个变量变量(variable),边表示变量之间的依赖关系(where-the-value-comes-from)。不用AST是考虑到数据流图不像AST这么复杂,不会带来不必要的深层信息。
本文的下游任务是natural language code search代码搜索、clone detection克隆检测、code translation代码翻译、code refinement修bug。
数据流图data flow
data flow是一个graph,节点是变量,边表示where the value of each variable comes from。
**为什么要建图?**对于同一个源代码,用不同的抽象语法得到的AST是不同的,但是代码的数据流是不变的。因此数据流图可以提供重要的语义信息。

举个例子比如 v = max value − min value ,程序员并不一定总是按照规定命名变量,因此想要了解变量 v 的语义,可以考虑变量v的来源,来源于数据流中的max和min。此外数据流图还可以支持解析同一变量在不同的执行阶段所具有的不同的语义信息,比如图中的x3, x7, x9, x11虽然都是 x这个token,但是语义信息是不同的,当作token序列训练时不太合适的。
构造数据流图的方法如上所示,对于一段代码片段 C = { c 1 , c 2 , … , c n } C = \left\{c_{1}, c_{2}, \ldots, c_{n}\right\} C={ c1,c2,…,cn},先用编译工具(Tree-sitter)将其解析成 AST,AST包含了代码段的句法信息,将AST的叶子节点识别为变量序列 V = { v 1 , v 2 , … , v k } V=\left\{v_{1}, v_{2}, \ldots, v_{k}\right\} V={ v1,v2,…,vk}。然后将每个变量作为一个节点,有向边 ε = * v i , v j * \varepsilon=\left\langle v_{i}, v_{j}\right\rangle ε=*vi,vj* 表示变量 j 的值依赖于 变量 i 的值。举例如代码 x = expr,x依赖于等号右侧表达式中的所有变量,所以数据流图是有向图,a指向x意味着x依赖于a。有向边的集合是 E = { ε 1 , ε 2 , … , ε l } E=\left\{\varepsilon_{1}, \varepsilon_{2}, \ldots, \varepsilon_{l}\right\} E={ ε1,ε2,…,εl},代码C的数据流图表示为 G ( C ) = ( V , E ) \mathcal{G}(C)=(V, E) G(C)=(V,E)。
模型
模型架构用的就是标准BERT,一些模型结构参数就不细讲了。唯一的区别是在Attention模块里有一个基于图 G ( C ) = ( V , E ) \mathcal{G}(C)=(V, E) G(C)=(V,E)的mask(毕竟图结构信息得用)

输入输出
有三种序列:代码片段 C = { c 1 , c 2 , … , c n } C=\left\{c_{1}, c_{2}, \ldots, c_{n}\right\} C={ c1,c2,…,cn},该段代码的注释文本片段 W = { w 1 , w 2 , … , w m } W=\left\{w_{1}, w_{2}, \ldots, w_{m}\right\} W={ w1,w2,…,wm} 以及 变量节点序列 V = { v 1 , v 2 , … , v k } V=\left\{v_{1}, v_{2}, \ldots, v_{k}\right\} V={ v1,v2,…,vk}。 输入 X 由三段序列拼接起来: X = { [ C L S ] , W , [ S E P ] , C , [ S E P ] , V } X=\{[C L S], W,[S E P], C,[S E P], V\} X={[CLS],W,[SEP],C,[SEP],V}
输出则是每个token的向量表示,用于完成各种预训练任务。
Graph-Guided Masked Attention
这里主要在BERT中对multi-head attention做了一个设计,multi-head的输出为:
h e a d i = softmax ( Q i ⋅ K i T d k + M ) ⋅ V i G ^ n = [ head 1 ; … ; head u ] ⋅ W n O \begin{gathered} h e a d_{i}=\operatorname{softmax}\left(\frac{Q_{i} \cdot K_{i}^{T}}{\sqrt{d_{k}}}+M\right) \cdot V_{i} \\ \hat{G}^{n}=\left[\text { head }_{1} ; \ldots ; \text { head }_{u}\right] \cdot W_{n}^{O} \end{gathered} headi=softmax(dkQi⋅KiT+M)⋅ViG^n=[ head 1;…; head u]⋅WnO
其中 M 是Graph-Guided Masked Attention 矩阵(GraphCodeBert相比于Bert的特色之处),是
∣ X ∣ × ∣ X ∣ |X| \times|X| ∣X∣×∣X∣ 维度的向量, M M M有两个作用:1、 第 i 个变量如果和第 j 个变量在数据流图中没有边连接的话( * v j , v i * ∈ E ) \left.\left\langle v_{j}, v_{i}\right\rangle \in E\right) *vj,vi*∈E)),softmax的权重 $ M_{i j}$ 为负无穷,即不允许变量 i 去关注 变量j,有边连接就是0,允许i 注意 j;2、如果变量节点 v i v_i vi 是从代码token c j c_j cj 识别来的,就允许 i和j 互相注意,否则也是负无穷。
M i j = { 0 i f ( q i ∈ [ C L S ] , [ S E P ] ) or ( q i , k j ∈ W ∪ C ) or ( * q i , k j * ∈ E ∪ E ′ ) − ∞ otherwise M_{i j}=\left\{\begin{array}{rl} 0 & i f\left(q_{i} \in[C L S],[S E P]\right) \operatorname{or}\left(q_{i}, k_{j} \in W \cup C\right) \operatorname{or}\left(\left\langle q_{i}, k_{j}\right\rangle \in E \cup E^{\prime}\right) \\ -\infty & \text { otherwise } \end{array}\right. Mij={ 0−∞if(qi∈[CLS],[SEP])or(qi,kj∈W∪C)or(*qi,kj*∈E∪E′) otherwise 
- 白色部分的值均为0
- 橙色部分,如果代码token c i c_i ci 与变量 v j v_j vj 有对应关系,比如return x中的token x和x11就有对应关系,那么 M c i v j = 0 M_{c_{i} v_{j}}=0 Mcivj=0; 其他token(包括其他的x)和 x11就没有对应关系,就设置为-∞。
- 蓝色部分,如果变量 v i v_i vi 和变量 v j v_j vj 有数据流关系, M c i v j = 0 M_{c_{i} v_{j}}=0 Mcivj=0就是0,否则负无穷。
这里更体现了Multi-head attention和图卷积的关系,只在有边连接的节点之间计算attention。
预训练任务
三个预训练任务,分别是MLM、Edge Prediction 和 Node Alignment
Masked Language Modeling
只在代码序列和注释文本序列上做MLM
Edge Prediction
数据流图的边预测,目的是让模型学习"where-the-value-comes-from"的信息,对应架构图中蓝色部分。随机从数据流图中采样20%的节点记为 V s V_{s} Vs ,然后mask掉这20%节点所设计的边,mask做法就是将 边Mask矩阵中的值设为负无穷。然后用BERT的输出带入BCE二分类loss中做Edge Prediction:
loss E d g e P r e d = − ∑ e i j ∈ E c [ δ ( e i j ∈ E m a s k ) log p e i j + ( 1 − δ ( e i j ∈ E m a s k ) ) log ( 1 − p e i j ) ] \operatorname{loss}_{E d g e P r e d}=-\sum_{e_{i j} \in E_{c}}\left[\delta\left(e_{i j} \in E_{m a s k}\right) \log p_{e_{i j}}+\left(1-\delta\left(e_{i j} \in E_{m a s k}\right)\right) \log \left(1-p_{e_{i j}}\right)\right] lossEdgePred=−eij∈Ec∑[δ(eij∈Emask)logpeij+(1−δ(eij∈Emask))log(1−peij)]
这里 δ ( e i j ∈ E ) \delta\left(e_{i j} \in E\right) δ(eij∈E) is 1 if * v i , v j * ∈ E \left\langle v_{i}, v_{j}\right\rangle \in E *vi,vj*∈E otherwise 0 就是BCE的label, p e i j p_{e_{i j}} peij 就是BERT输出的embedding的内积。这里还考虑负采样。
Node Alignment

为了对齐代码序列的表征和数据流图的表征,代码序列中出现了 4个相同的token “x”,但是数据流图中的 x11 应该对应的是代码序列中最后一个表达式 “return x”的x。
基于这种思想,具体做法是先将 mask 矩阵M中 “x”和 x11的边 mask掉(从0变为-∞),对BERT的输出做BCE二分类,这里的负样本就是代码序列中其他的token x:
loss N o d e A l i g n = − ∑ e i j ∈ E c ′ [ δ ( e i j ∈ E m a s k ′ ) log p e i j + ( 1 − δ ( e i j ∈ E m a s k ′ ) ) log ( 1 − p e i j ) ] \operatorname{loss}_{N o d e A l i g n}=-\sum_{e_{i j} \in E_{c}^{\prime}}\left[\delta\left(e_{i j} \in E_{m a s k}^{\prime}\right) \log p_{e_{i j}}+\left(1-\delta\left(e_{i j} \in E_{m a s k}^{\prime}\right)\right) \log \left(1-p_{e_{i j}}\right)\right] lossNodeAlign=−∑eij∈Ec′[δ(eij∈Emask′)logpeij+(1−δ(eij∈Emask′))log(1−peij)]
实验
4个下游任务:搜代码、代码克隆检测、代码翻译、修bug
NATURAL LANGUAGE CODE SEARCH
代码搜索任务的含义是给定一种自然语言输入,要求从一组候选代码中找到语义最相关的代码,使用的数据集是CodeSearchNet的数据集,使用代码文档的第一段作为query,用GraphCodeBERT分别encode query 和 代码序列+数据流图,然后用 [CLS]输出的表征来计算相似度。也可以fine-tuning,就是双塔。

Code Clone Detection
给定两个代码片段,要求度量其相似性,用的是BigCloneBench数据集。输入是代码片段及数据流图,还是用 CLS的表征。

Code Translation
代码翻译的含义是将一种编程语言翻译成另一种编程语言,其目的是将遗留软件从平台的一种编程语言迁移到另一种编程语言,以Lucene、POI等开源项目为数据集,这些项目都有Java和C#的实现,任务中模型输入Java(C#)代码,输出与之对应的C#(Java)代码。
做法是将预训练的GraphCodeBERT作为Encoder,然后随机初始化一个decoder,然后fine-tuning。

Code Redinement
修bug,输入带bug的JAVA代码,输出正确的JAVA代码,流程和代码翻译类似。

边栏推荐
- 【多模态】《HiT: Hierarchical Transformer with Momentum Contrast for Video-Text Retrieval》ICCV 2021
- Various controls ==pyqt5
- LeetCode 50. Pow(x,n)
- 【IMX6ULL笔记】--内核底层驱动初步探究
- 奉劝那些刚参加工作的学弟学妹们:要想进大厂,这些并发编程知识是你必须要掌握的!完整学习路线!!(建议收藏)
- [electronic device notes 5] diode parameters and selection
- Learning to Pre-train Graph Neural Networks(图预训练与微调差异)
- Power Bi -- these skills make the report more "compelling"“
- 'C:\xampp\php\ext\php_ zip. Dll'-%1 is not a valid Win32 Application Solution
- Brpc source code analysis (VIII) -- detailed explanation of the basic class eventdispatcher
猜你喜欢

Start with the development of wechat official account

LeetCode 50. Pow(x,n)

dirReader. Readentries compatibility issues. Exception error domexception

Varest blueprint settings JSON

异构图神经网络用于推荐系统问题(ACKRec,HFGN)

brpc源码解析(八)—— 基础类EventDispatcher详解
![[MySQL learning 08]](/img/9e/6e5f0c4c956ca8dc31d82560262013.png)
[MySQL learning 08]
![[imx6ull notes] - a preliminary exploration of the underlying driver of the kernel](/img/0f/a0139be99c61fde08e73a5be6d6b4c.png)
[imx6ull notes] - a preliminary exploration of the underlying driver of the kernel

JVM performance tuning methods

Teach you how to configure S2E as the working mode of TCP client through MCU
随机推荐
winddows 计划任务执行bat 执行PHP文件 失败的解决办法
Qin long, a technical expert of Alibaba cloud: a prerequisite for reliability assurance - how to carry out chaos engineering on the cloud
Hardware connection server TCP communication protocol gateway
Management of software defects
【高并发】我用10张图总结出了这份并发编程最佳学习路线!!(建议收藏)
【高并发】高并发场景下一种比读写锁更快的锁,看完我彻底折服了!!(建议收藏)
Chapter 4 linear equations
软件测试阶段的风险
任何时间,任何地点,超级侦探,认真办案!
Intelligent information retrieval(智能信息检索综述)
session和cookie有什么区别??小白来告诉你
brpc源码解析(一)—— rpc服务添加以及服务器启动主要过程
toString()与new String()用法区别
Layout management ==pyqt5
Start with the development of wechat official account
异构图神经网络用于推荐系统问题(ACKRec,HFGN)
Teach you how to configure S2E as the working mode of TCP client through MCU
PHP 上传ftp路径文件到外网服务器上 curl base64图片
已解决 Files‘ name is invalid or does not exist (1205)
Brpc source code analysis (VII) -- worker bthread scheduling based on parkinglot