diff --git a/timm/scheduler/cosine_lr.py b/timm/scheduler/cosine_lr.py index 4eaaa86a81..aca01d6653 100644 --- a/timm/scheduler/cosine_lr.py +++ b/timm/scheduler/cosine_lr.py @@ -82,17 +82,17 @@ def _get_lr(self, t: int) -> List[float]: if t < self.warmup_t: lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] else: - if self.warmup_prefix: - t = t - self.warmup_t + t_i = self.t_initial + if self.warmup_prefix and t < self.t_initial: + t -= self.warmup_t + t_i -= self.warmup_t if self.cycle_mul != 1: i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul)) - t_i = self.cycle_mul ** i * self.t_initial - t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial + t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * t_i else: i = t // self.t_initial - t_i = self.t_initial - t_curr = t - (self.t_initial * i) + t_curr = t - (t_i * i) gamma = self.cycle_decay ** i lr_max_values = [v * gamma for v in self.base_values] @@ -100,7 +100,7 @@ def _get_lr(self, t: int) -> List[float]: if i < self.cycle_limit: lrs = [ - self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / t_i ** k)) + self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / (t_i * self.cycle_mul ** i) ** k)) for lr_max in lr_max_values ] else: