Commit 2d597b1

Missed extra nadam algo step for capturable path
1 parent 4790c0f commit 2d597b1

1 file changed

+5
-0
lines changed

timm/optim/nadamw.py

Lines changed: 5 additions & 0 deletions
@@ -315,6 +315,11 @@ def _multi_tensor_nadamw(
 
         bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)
 
+        # Only difference between NAdamW and AdamW in this implementation.
+        # The official PyTorch implementation of NAdam uses a different algorithm.
+        exp_avgs = torch._foreach_mul(exp_avgs, beta1)
+        torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
+
         exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
         torch._foreach_div_(
             exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)
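
The two added calls compute `beta1 * exp_avg + (1 - beta1) * grad` for every parameter. Because `torch._foreach_mul` is out-of-place, the result is bound to new tensors, so the in-place `torch._foreach_add_` that follows changes only that copy. The sketch below shows the same arithmetic in plain Python, with lists of floats standing in for the tensor lists. The helper name `nesterov_lookahead` and the sample values are hypothetical and are not taken from `timm`.

```python
def nesterov_lookahead(exp_avgs, grads, beta1):
    """Return beta1 * m + (1 - beta1) * g per element, leaving exp_avgs untouched."""
    # Out-of-place scale: builds a new list, as torch._foreach_mul returns new tensors.
    lookahead = [m * beta1 for m in exp_avgs]
    # In-place accumulate on the new list, as torch._foreach_add_(..., alpha=1 - beta1) does.
    for i, g in enumerate(grads):
        lookahead[i] += (1 - beta1) * g
    return lookahead


# Made-up values standing in for the momentum buffers and gradients.
exp_avgs = [0.5, -1.0]
grads = [1.0, 2.0]

lookahead = nesterov_lookahead(exp_avgs, grads, beta1=0.9)
```

Here `lookahead` holds the interpolated values while `exp_avgs` keeps its original contents, which matches the rebinding of the local name `exp_avgs` in the diff.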
