1 parent 4790c0f commit 2d597b1
timm/optim/nadamw.py
@@ -315,6 +315,11 @@ def _multi_tensor_nadamw(
 
     bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)
 
+    # Only difference between NAdamW and AdamW in this implementation.
+    # The official PyTorch implementation of NAdam uses a different algorithm.
+    exp_avgs = torch._foreach_mul(exp_avgs, beta1)
+    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
+
     exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
     torch._foreach_div_(
         exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)
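
Note that the change uses the out-of-place torch._foreach_mul (no trailing underscore), which rebinds the local exp_avgs list to fresh tensors. The persistent momentum buffers therefore still hold the plain AdamW exponential moving average; only this step's update blends the momentum with the current gradient a second time, which is the Nesterov lookahead that turns AdamW into NAdamW. Below is a minimal single-tensor sketch of the resulting update rule; the function name, argument names, and the exact placement of the decoupled weight decay are illustrative assumptions, not code from this file.

import torch

def nadamw_step_sketch(param, grad, exp_avg, exp_avg_sq, step,
                       lr=1e-3, beta1=0.9, beta2=0.999,
                       eps=1e-8, weight_decay=1e-2):
    # Decoupled weight decay, as in AdamW (placement assumed).
    param.mul_(1 - lr * weight_decay)

    # Standard EMA updates; these buffers persist across steps.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    # Nesterov lookahead: blend the momentum with the current gradient
    # once more, out-of-place, so exp_avg itself is left untouched.
    # This mirrors the non-in-place torch._foreach_mul in the diff above.
    lookahead = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)

    denom = (exp_avg_sq.sqrt() / bias_correction2 ** 0.5).add_(eps)
    param.addcdiv_(lookahead, denom, value=-lr / bias_correction1)

Dropping the lookahead line and using exp_avg directly in the final update recovers plain AdamW, which is exactly the "only difference" the added comment describes.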