From 9a9633e4a60894301e387ac8c6ec02260cafeec8 Mon Sep 17 00:00:00 2001 From: Sina Hajimiri Date: Thu, 24 Oct 2024 15:02:49 -0400 Subject: [PATCH 1/2] Fix cosine LR scheduler for warmup --- timm/scheduler/cosine_lr.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/timm/scheduler/cosine_lr.py b/timm/scheduler/cosine_lr.py index 4eaaa86a81..5e7739d3a8 100644 --- a/timm/scheduler/cosine_lr.py +++ b/timm/scheduler/cosine_lr.py @@ -87,12 +87,12 @@ def _get_lr(self, t: int) -> List[float]: if self.cycle_mul != 1: i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) - t_i = self.cycle_mul ** i * self.t_initial - t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial + t_i = self.cycle_mul ** i * (self.t_initial - self.warmup_t) + t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * t_i else: i = t // self.t_initial - t_i = self.t_initial - t_curr = t - (self.t_initial * i) + t_i = self.t_initial - self.warmup_t + t_curr = t - (t_i * i) gamma = self.cycle_decay ** i lr_max_values = [v * gamma for v in self.base_values] From 772e3b7482f7c1280277a13f95b00e71bdc4ee35 Mon Sep 17 00:00:00 2001 From: Sina Hajimiri Date: Thu, 7 Nov 2024 13:00:33 -0500 Subject: [PATCH 2/2] Fix cosine LR scheduler for warmup and cycles --- timm/scheduler/cosine_lr.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/timm/scheduler/cosine_lr.py b/timm/scheduler/cosine_lr.py index 5e7739d3a8..aca01d6653 100644 --- a/timm/scheduler/cosine_lr.py +++ b/timm/scheduler/cosine_lr.py @@ -82,16 +82,16 @@ def _get_lr(self, t: int) -> List[float]: if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: - if self.warmup_prefix: - t = t - self.warmup_t + t_i = self.t_initial + if self.warmup_prefix and t < self.t_initial: + t -= self.warmup_t + t_i -= self.warmup_t if self.cycle_mul != 1: i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) - t_i = self.cycle_mul ** i * (self.t_initial - self.warmup_t) t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * t_i else: i = t // self.t_initial - t_i = self.t_initial - self.warmup_t t_curr = t - (t_i * i) gamma = self.cycle_decay ** i @@ -100,7 +100,7 @@ def _get_lr(self, t: int) -> List[float]: if i < self.cycle_limit: lrs = [ - self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / t_i ** k)) + self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / (t_i * self.cycle_mul ** i) ** k)) for lr_max in lr_max_values ] else: