 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import functools
+
 from torch.optim.lr_scheduler import LambdaLR
 from torchtitan.config_manager import JobConfig
 
-# global states for scheduling
-# these are needed as LambdaLR does not support argument passing
-_warmup_steps = 200
-_decay_steps = 0
-
 
-def linear_warmup_linear_decay(current_step: int) -> float:
+def linear_warmup_linear_decay(
+    warmup_steps: int, decay_steps: int, current_step: int
+) -> float:
     """Computes linear warmup followed by linear decay.
     Per LambdaLR requirement, this is accomplished by returning
     a multiplicative factor to adjust the learning rate to
     create the desired schedule.
     """
-    if current_step < _warmup_steps:
+    if current_step < warmup_steps:
         # linear warmup
         # 0-indexed step, hence + 1 adjustments
         current_step += 1
-        curr_adjustment = float(current_step / (_warmup_steps + 1))
+        curr_adjustment = float(current_step / (warmup_steps + 1))
 
     else:
         # linear decay
-        normalized_step = _decay_steps - (current_step - _warmup_steps)
-        curr_adjustment = 1 - (_decay_steps - normalized_step) / _decay_steps
+        normalized_step = decay_steps - (current_step - warmup_steps)
+        curr_adjustment = 1 - (decay_steps - normalized_step) / decay_steps
 
     return curr_adjustment
 
 
 def get_lr_schedulers(optimizers, job_config: JobConfig):
     def _get_lr_scheduler(optimizer):
         """Build a linear warmup and linear decay scheduler"""
-        global _warmup_steps, _decay_steps
-        _warmup_steps = int(job_config.training.warmup_steps)
-        _decay_steps = float(max(1, job_config.training.steps - _warmup_steps))
-
-        warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
+        warmup_steps = int(job_config.training.warmup_steps)
+        decay_steps = float(max(1, job_config.training.steps - warmup_steps))
+        lr_lambda = functools.partial(
+            linear_warmup_linear_decay, warmup_steps, decay_steps
+        )
+        warmup_scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
        return warmup_scheduler
 
     class SchedulersContainer:
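
For reference, a minimal standalone sketch of the pattern this diff adopts: functools.partial binds warmup_steps and decay_steps so the resulting callable takes only current_step, which is all LambdaLR passes to lr_lambda. The toy model, optimizer, and step counts below are made up for illustration; only the partial/LambdaLR wiring reflects the change above.

import functools

import torch
from torch.optim.lr_scheduler import LambdaLR


def linear_warmup_linear_decay(warmup_steps: int, decay_steps: int, current_step: int) -> float:
    # same schedule as in the diff: linear ramp for warmup_steps, then linear decay
    if current_step < warmup_steps:
        return float((current_step + 1) / (warmup_steps + 1))
    normalized_step = decay_steps - (current_step - warmup_steps)
    return 1 - (decay_steps - normalized_step) / decay_steps


# toy model and optimizer purely for illustration
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

warmup_steps, total_steps = 2, 10  # hypothetical step counts
decay_steps = float(max(1, total_steps - warmup_steps))

# bind the schedule parameters; LambdaLR will call lr_lambda(current_step)
lr_lambda = functools.partial(linear_warmup_linear_decay, warmup_steps, decay_steps)
scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

for _ in range(total_steps):
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr())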