当前位置:网站首页>动态权重之多任务不平衡论文 (二) MetaBalance

动态权重之多任务不平衡论文 (二) MetaBalance

2022-08-02 15:05:00 枫桦

以下文章来源于秋枫学习笔记 ,作者秋枫学习笔记

title:MetaBalance: Improving Multi-Task Recommendations via Adapting Gradient Magnitudes of Auxiliary Tasks link:https://arxiv.org/pdf/2203.06801v1.pdf code:https://github.com/facebookresearch/MetaBalance from:WWW 2022

1. 导读

在推荐场景中使用多任务学习,通常会遇到严重的优化不平衡问题。

  • 一方面,一个或多个辅助任务可能比目标任务具有更大的影响,甚至主导网络权重,导致目标任务的推荐精度降低。
  • 另一方面,一个或多个辅助任务的影响可能太弱,无法帮助目标任务。
  • 这种不平衡在整个训练过程中动态变化,并在同一网络的不同部分发生变化。

本文提出了一种新的方法:MetaBalance,该方法对辅助任务的梯度依据目标任务的梯度进行缩放,并且在缩放的同时保留一部分自身的梯度,从而缓解辅助任务梯度过大过小的问题。

2. 问题定义

令θ表示共享参数,这类参数是被目标任务和辅助任务共同优化的,损失函数可以表示为下式,

\mathcal{L}_{t o t a l}=\mathcal{L}_{t a r}+\sum_{i=1}^{K} \mathcal{L}_{a u x, i}

然后利用损失函数的梯度来更新参数θ,表示如下,用

G_{tar}

表示目标网络的梯度,用

G_{aux,i}

表示辅助任务的梯度,

||G||

表示正则项的梯度。

\theta^{t+1}=\theta^{t}-\alpha * G_{total}^t
\mathrm{G}_{\text {total }}^{t}=\nabla_{\theta} \mathcal{L}_{\text {total }}^{t}=\nabla_{\theta} \mathcal{L}_{t a r}^{t}+\sum_{i=1}^{K} \nabla_{\theta} \mathcal{L}_{a u x, i}^{t}

3. 方法

3.1 调整辅助任务梯度幅度

主任务和辅助任务梯度幅度的不平衡会对整体任务带来负面影响,MetaBalance通过三种策略和放松因子对梯度进行动态的、自适应的调整。

基础版伪代码如下,主要包括四个步骤:

  • 分别计算主任务和辅助任务的梯度,
G_{tar}^t

,

G_{aux,i}^t

  • 在第 5 行中,可以选择减小幅度大于目标梯度的辅助梯度,或者放大幅度较小的辅助梯度,或者同时应用这两种策略。可以根据目标任务的验证性能来选择策略。
  • 将辅助梯度标准化为单位向量,然后和目标梯度相乘得到新的辅助梯度
  • 更新参数

image.png

  • 优点:通过标准化后与目标梯度相乘使得目标任务和辅助任务的梯度能够在相同的量级上,缓解辅助任务梯度过大或过小的问题。
  • 缺点:辅助梯度是依据主任务梯度生成的,但是主任务的梯度未必是准确的或最优的,因此定义了一个放松因子来控制辅助梯度向主梯度的靠近程度。

3.2 调整幅度接近度

本文设置了一个放松因子r来控制辅助梯度向主梯度的靠近程度,r为超参数,公式如下,

\mathrm{G}_{a u x, i}^{t} \leftarrow\left(\mathrm{G}_{a u x, i}^{t} * \frac{\left\|\mathrm{G}_{t a r}^{t}\right\|}{\left\|\mathrm{G}_{a u x, i}^{t}\right\|}\right) * r+\mathrm{G}_{a u x, i}^{t} *(1-r)

上式可以改写为下式,当

||G_{tar}^t|| > ||G_{aux,i}^t||

时,r越大,w越大;反之,r越大,w越小。

\mathrm{G}_{a u x, i}^{t} \leftarrow \mathrm{G}_{a u x, i}^{t} * w_{a u x, i}^{t}
w_{a u x, i}^{t}=\left(\frac{\left\|\mathrm{G}_{t a r}^{t}\right\|}{\left\|\mathrm{G}_{a u x, i}^{t}\right\|}-1\right) * r+1

并且利用梯度的移动平均来替代原来的即时梯度,从而考虑梯度之间的方差,公式如下,

\begin{array}{c} m_{\text {tar }}^{t}=\beta * m_{\text {tar }}^{t-1}+(1-\beta) *\left\|\mathrm{G}_{\text {tar }}^{t}\right\| \\ m_{a u x, i}^{t}=\beta * m_{a u x, i}^{t-1}+(1-\beta) *\left\|\mathrm{G}_{a u x, i}^{t}\right\|, \forall i=1, \ldots, K \end{array}

伪代码如下,

4. 结果

image.png

原网站

版权声明
本文为[枫桦]所创,转载请带上原文链接,感谢
https://cloud.tencent.com/developer/article/2064431