当前位置:网站首页>【多任务优化】DWA、DTP、Gradnorm(CVPR 2019、ECCV 2018、 ICML 2018)
【多任务优化】DWA、DTP、Gradnorm(CVPR 2019、ECCV 2018、 ICML 2018)
2022-08-01 19:53:00 【chad_lee】
多任务学习模型的优化
有多个task就有多个loss,常见的MTL模型loss可以直接简单的对多个任务的loss相加:
L = ∑ i L i L=\sum_{i} L_{i} L=i∑Li
显然这种做法有很大问题,因为不同task的label分布不同,同时不同task的loss量级也不同,整个模型很可能被一些loss特别大的任务主导。最简单的方法是加权loss,人工设计权重:
L = ∑ i w i ∗ L i L=\sum_{i} w_{i} * L_{i} L=i∑wi∗Li
但是这样这个权重在整个训练周期中都是固定的,不同训练阶段权重可能变化,动态权重则为:
L = ∑ i w i ( t , θ ) ∗ L i L=\sum_{i} w_{i}(t, \theta) * L_{i} L=i∑wi(t,θ)∗Li
t是训练的step,theta是模型其他参数。但是这种做法也不一定有人工设计权重好。
一些设计 w i ( t , θ ) w_{i}(t, \theta) wi(t,θ) 的方法:
《End-to-End Multi-Task Learning with Attention》 CVPR 2019
CVPR 2019的《End-to-End Multi-Task Learning with Attention》提出的Dynamic Weight Averaging(DWA),核心公式如下所示:
r n ( t − 1 ) = L n ( t − 1 ) L n ( t − 2 ) w i ( t ) = N exp ( r i ( t − 1 ) / T ) ∑ n exp ( r n ( t − 1 ) / T ) \begin{gathered} r_{n}(t-1)=\frac{L_{n}(t-1)}{L_{n}(t-2)} \\ w_{i}(t)=\frac{N \exp \left(r_{i}(t-1) / T\right)}{\sum_{n} \exp \left(r_{n}(t-1) / T\right)} \end{gathered} rn(t−1)=Ln(t−2)Ln(t−1)wi(t)=∑nexp(rn(t−1)/T)Nexp(ri(t−1)/T)
$L_{n}(t-1) 是任务 n 在 t − 1 时的训练 l o s s ,因此 是任务 n 在 t-1 时的训练loss,因此 是任务n在t−1时的训练loss,因此r_{n}(t-1) $ 是此时loss的下降速度,$r_{n}(t-1) $越小,训练速度越快。(已经开始收敛,loss=0时结束了)
w i ( t ) w_i(t) wi(t)代表不同任务loss的权重,直观理解就是loss收敛越快的任务,权重越小,权重的平均程度由温度系数T控制
《Dynamic task prioritization for multitask learning》 ECCV 2018
DTP(Dynamic Task Prioritization):
w i ( t ) = − ( 1 − k i ( t ) ) γ i log ( k i ( t ) ) w_{i}(t)=-\left(1-k_{i}(t)\right)^{\gamma_{i}} \log \left(k_{i}(t)\right) wi(t)=−(1−ki(t))γilog(ki(t))
$k_i(t) $ 表示第t步的某衡量kpi值,取值维0~1之间,比如在分类任务中KPI可以是训练集上的准确率等,可以反应模型在这个任务上的拟合程度,γ是人工调节的温度系数。直观理解就类似focal loss,优化的越好的任务,获得的权重越小。
《Gradnorm: Gradient normalization for adaptive loss balancing in deep multitask networks》 ICML 2018
影响力最大的是GradNorm。其核心思想相对前述的DWA和DTP更为复杂,核心观点为
不仅考虑loss收敛的速度,进一步希望loss本身的量级能尽量接近
不同的任务以相近的速度训练(与gradient相关)
从带参数的动态权重 $L=\sum_{i} w_{i}(t, \theta) * L_{i} $出发,作者还定义了一个训练权重 $w_i(t,\theta) $ 相关的 gradient loss。(定一个loss用来优化训练loss的权重)
w i ( t ) w_i(t) wi(t) 刚开始初始化为1 或者超参,然后用gradient loss来优化。
首先得到任务 i 在 t 时刻的 梯度的2范数,以及所有任务的平均值:
G W ( i ) ( t ) = ∥ ∇ W ( w i ( t ) L i ( t ) ) ∥ 2 G ˉ W ( t ) = A V G ( G W ( i ) ( t ) ) \begin{gathered} G_{W}^{(i)}(t)=\left\|\nabla_{W}\left(w_{i}(t) L_{i}(t)\right)\right\|_{2} \\ \bar{G}_{W}(t)=A V G\left(G_{W}^{(i)}(t)\right) \end{gathered} GW(i)(t)=∥∇W(wi(t)Li(t))∥2GˉW(t)=AVG(GW(i)(t))
其中W是模型参数的子集,也是需要应用Gradient Normalization的参数集,一般是选择模型中共享参数的最后一层。然后得到不同任务loss的训练速度:
L ~ i ( t ) = L i ( t ) / L i ( 0 ) r i ( t ) = L ~ i ( t ) / A V G ( L ~ i ( t ) ) \begin{gathered} \tilde{L}_{i}(t)=L_{i}(t) / L_{i}(0) \\ r_{i}(t)=\tilde{L}_{i}(t) / A V G\left(\tilde{L}_{i}(t)\right) \end{gathered} L~i(t)=Li(t)/Li(0)ri(t)=L~i(t)/AVG(L~i(t))
$r_{i}(t) $衡量任务训练的速度, $r_{i}(t) $ 越大,表明任务训练的越慢。这点和DWA的思想接近,但是这里使用的是第一步的loss,而不是DWA中的前一步loss。
最终gradient loss为:
L g r a d ( t ; w i ( t ) ) = ∑ i ∣ G W ( i ) ( t ) − G ˉ W ( t ) ∗ [ r i ( t ) ] α ∣ 1 L_{g r a d}\left(t ; w_{i}(t)\right)=\sum_{i}\left|G_{W}^{(i)}(t)-\bar{G}_{W}(t) *\left[r_{i}(t)\right]^{\alpha}\right|_{1} Lgrad(t;wi(t))=i∑∣∣GW(i)(t)−GˉW(t)∗[ri(t)]α∣∣1
- $\bar{G}{W}(t) *\left[r{i}(t)\right]^{\alpha} $表示理想的梯度标准化后的值。**这里的gradient loss只用于更新 $w_{i}(t) ∗ ∗ 。 **。 ∗∗。w_{i}(t) $还会经过最终的重normalize,使得 $\sum_{i} w_{i}(t)=N $,N是任务的数量。
- α是设定恢复力强度的超参数,即将任务的训练速度调节到平均水准的强度。如果任务的复杂程度很不一样,大致人物之间的学习速率大不相同,就应该使用较高的alpha来进行较强的训练速率平衡;反之对于多个相似的任务,应该使用较小的α。
- 从gradient loss的定义来看, $r_i(t) $ 越大,表明训练越快,gradient loss越大;$\left|G_{W}^{(i)}(t)-\bar{G}{W}(t)\right| 表明 l o s s 量级的变化,不论 表明loss量级的变化,不论 表明loss量级的变化,不论G{W}^{(i)}(t) $过大或者过小都会导致gradient loss变大。
- 所以gradient loss 希望:1、不同任务的loss的量级接近;2、不同任务以相近的速度训练(收敛速度)
边栏推荐
猜你喜欢
Intranet penetration lanproxy deployment
Try compiling QT test on Allwinner V853 development board
Find the sum of two numbers
30-day question brushing plan (5)
如何看待腾讯云数据库负责人林晓斌借了一个亿炒股?
【七夕特别篇】七夕已至,让爱闪耀
第59章 ApplicationPart内置依赖注入中间件
明日盛会|ApacheCon Asia 2022 Pulsar 技术议题一览
30天刷题计划(五)
Redis启动时提示Creating Server TCP listening socket *:6379: bind: No error
随机推荐
PHP 安全最佳实践
研究生新同学,牛人看英文文献的经验,值得你收藏
MySQL开发技巧——存储过程
Ruijie switch basic configuration
When installing the GBase 8c database, the error message "Resource: gbase8c already in use" is displayed. How to deal with this?
安全作业7.25
easyUI中datagrid中的formatter里面向后台发送请求获取数据
ARTS_202207W2
ssh & scp
专利检索常用的网站有哪些?
Find the sum of two numbers
kubernetes - deploy nfs storage class
小白系统初始化配置资源失败怎么办
Ha ha!A print function, quite good at playing!
不要再使用MySQL online DDL了
To drive efficient upstream and downstream collaboration, how can cross-border B2B e-commerce platforms release the core value of the LED industry supply chain?
JS数组过滤
Choosing the right DevOps tool starts with understanding DevOps
MySQL你到底都加了什么锁?
57:第五章:开发admin管理服务:10:开发【从MongoDB的GridFS中,获取文件,接口】;(从GridFS中,获取文件的SOP)(不使用MongoDB的服务,可以排除其自动加载类)