在SGD优化的前提下,L2正则和Weight Decay是等价的。当我们考虑为损失函数引入L2正则项时,Pytorch是这样实现的:
torch.optim.SGD(...,weight_decay=0.001)
而当我们使用Adam作为优化器时,是否可以通过下面的方式实现L2正则呢?
torch.optim.Adam(...,weight_decay=0.001)
答案是否定的。引用Bert原文语句:
Just adding the square of the weights to the loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact with the m and v parameters in strange ways.
在Adam优化器中,weight decay与L2正则并不等价,除此之外,Adam+L2的方案会导致不理想的优化过程。论文《Decoupled Weight Decay Regularization》指出了这一点,并提出了AdamW优化器,实现了Adam与weight dacay共同使用时的解耦。实现方式(Pytorch):
torch.optim.AdamW(...,weight_decay=0.001)
我们不妨简单模拟一下Adam+L2的优化过程。
首先在损失函数中引入正则项,系数为 ,求梯度结果为:
计算一阶动量 与二阶动量
:
更新参数, 为学习率,
是一个为了防止分母为0而引入的极小数:
以上是Adam+L2方案的简化优化过程。将 与
代入上述公式:
我们知道L2正则化的目的在于使参数 有更小的值。从上述公式分析:分子发挥正常效果,对于大的
加大惩罚;而分母使得在梯度快速变化的方向(
较大)更新的更少,从而削弱了L2正则的惩罚。正是由于
中分子与分母的相互作用,使得L2正则效果变得模糊。
AdamW将L2正则与Adam进行简单的解耦,从而获得了效果的提升。既然 这一项会引入无效的优化计算,那么只需要删除这一项,在更新时额外加入Weight Decay即可。
很容易发现,在理论上可以认为AdamW=Adam+Weight Decay。那回到开头,我们是不是可以通过下面方式实现AdamW相同的效果:
torch.optim.Adam(...,weight_decay=0.001)
答案是否定的。这是由于在大多数库中实现Weight Decay的方式并不是正确的,在Adam中,Weight Decay通常以第一种的方式实现,而不是直接将权重进行衰减:
# I st: Adam weight decay implementation (L2 regularization)
final_loss=loss + wd * all_weights.pow(2).sum() / 2
# II nd: equivalent to this in SGD
w=w - lr *w.grad - lr *wd * w
而AdamW采用第二种方式实现。
参考资料:
https://arxiv.org/pdf/1711.05101.pdf
https://stackoverflow.com/questions/64621585/adamw-and-adam-with-weight-decay