误区! Adam+L2并不能发挥效果!

  新闻资讯     |      2024-06-24 15:17

在SGD优化的前提下,L2正则和Weight Decay是等价的。当我们考虑为损失函数引入L2正则项时,Pytorch是这样实现的:

torch.optim.SGD(...,weight_decay=0.001)

而当我们使用Adam作为优化器时,是否可以通过下面的方式实现L2正则呢?

torch.optim.Adam(...,weight_decay=0.001)

答案是否定的。引用Bert原文语句:

Just adding the square of the weights to the loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact with the m and v parameters in strange ways.

在Adam优化器中,weight decay与L2正则并不等价,除此之外,Adam+L2的方案会导致不理想的优化过程。论文《Decoupled Weight Decay Regularization》指出了这一点,并提出了AdamW优化器,实现了Adam与weight dacay共同使用时的解耦。实现方式(Pytorch):

torch.optim.AdamW(...,weight_decay=0.001)

我们不妨简单模拟一下Adam+L2的优化过程。

首先在损失函数中引入正则项,系数为 \\lambda ,求梯度结果为:

g_{t}=\
abla f_{t}(\	heta_{t-1}) +\\lambda \	heta_{t-1}\\\\

计算一阶动量 m_t 与二阶动量 v_t

m_t=\\beta_1 m_{t-1}+(1-\\beta_1)g_t  \\\\v_t=\\beta_2 v_{t-1}+(1-\\beta_2)g^2_t

更新参数, \\alpha 为学习率, \\varepsilon 是一个为了防止分母为0而引入的极小数:

\	heta_t=\	heta_{t-1}+\\frac{\\alpha m_t}{\\sqrt{v_t}+\\varepsilon}\\\\

以上是Adam+L2方案的简化优化过程。将 m_tg_t 代入上述公式:

\	heta_t=\	heta_{t-1}-\\frac{\\alpha[\\beta_1 m_{t-1}+(1-\\beta_1)(\
abla f_{t}(\	heta_{t-1}) +\\lambda \	heta_{t-1})]}{\\sqrt{v_t}+\\varepsilon}\\\\

我们知道L2正则化的目的在于使参数 \	heta 有更小的值。从上述公式分析:分子发挥正常效果,对于大的 \	heta 加大惩罚;而分母使得在梯度快速变化的方向( \	heta 较大)更新的更少,从而削弱了L2正则的惩罚。正是由于 \\frac{\\lambda\	heta}{\\sqrt{v_t}} 中分子与分母的相互作用,使得L2正则效果变得模糊。

AdamW将L2正则与Adam进行简单的解耦,从而获得了效果的提升。既然 \\frac{\\lambda\	heta}{\\sqrt{v_t}} 这一项会引入无效的优化计算,那么只需要删除这一项,在更新时额外加入Weight Decay即可。

\	heta_t=\	heta_{t-1}-\\frac{\\alpha[\\beta_1 m_{t-1}+(1-\\beta_1)(\
abla f_{t}(\	heta_{t-1}) ]}{\\sqrt{v_t}+\\varepsilon}-\\alpha\\lambda \	heta_{t-1}\\\\

很容易发现,在理论上可以认为AdamW=Adam+Weight Decay。那回到开头,我们是不是可以通过下面方式实现AdamW相同的效果:

torch.optim.Adam(...,weight_decay=0.001)

答案是否定的。这是由于在大多数库中实现Weight Decay的方式并不是正确的,在Adam中,Weight Decay通常以第一种的方式实现,而不是直接将权重进行衰减:

# I st: Adam weight decay implementation (L2 regularization)
final_loss=loss + wd * all_weights.pow(2).sum() / 2

# II nd: equivalent to this in SGD
w=w - lr *w.grad - lr *wd * w

而AdamW采用第二种方式实现。


参考资料:

arxiv.org/pdf/1711.0510

stackoverflow.com/quest

fast.ai/2018/07/02/adam

paperplanet:都9102年了,别再用Adam + L2 regularization了