diff --git a/test/test_job_config.py b/test/test_job_config.py
index 0e3d9c6302..23571f7de5 100644
--- a/test/test_job_config.py
+++ b/test/test_job_config.py
@@ -9,7 +9,7 @@ class TestJobConfig:
     def test_command_line_args(self):
         config = JobConfig()
         config.parse_args([])
-        assert config.training.steps == -1
+        assert config.training.steps == 10000
 
     def test_job_config_file(self):
         config = JobConfig()
diff --git a/torchtrain/config_manager.py b/torchtrain/config_manager.py
index 613f941130..9bfaa94ab7 100644
--- a/torchtrain/config_manager.py
+++ b/torchtrain/config_manager.py
@@ -151,10 +151,10 @@ def init_args_from_command_line(
         "--training.seq_len", type=int, default=2048, help="sequence length"
     )
     parser.add_argument(
-        "--training.warmup_pct",
-        type=float,
-        default=0.20,
-        help="percentage of total training steps to use for warmup",
+        "--training.warmup_steps",
+        type=int,
+        default=200,
+        help="steps for lr scheduler warmup",
     )
     parser.add_argument(
         "--training.max_norm",
@@ -163,7 +163,10 @@ def init_args_from_command_line(
         help="max norm for gradient clipping",
     )
     parser.add_argument(
-        "--training.steps", type=int, default=-1, help="how many train steps to run"
+        "--training.steps",
+        type=int,
+        default=10000,
+        help="how many train steps to run",
     )
     parser.add_argument(
         "--training.data_parallel_degree",
diff --git a/torchtrain/lr_scheduling.py b/torchtrain/lr_scheduling.py
index b76bd6d3f9..5961cecf73 100644
--- a/torchtrain/lr_scheduling.py
+++ b/torchtrain/lr_scheduling.py
@@ -1,3 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 
@@ -6,7 +9,7 @@
 
 # global states for scheduling
 # these are needed as LambdaLR does not support argument passing
-_warmup_steps = 2
+_warmup_steps = 200
 _decay_steps = 0
 
 
@@ -33,9 +36,7 @@ def linear_warmup_linear_decay(current_step: int) -> float:
 def get_lr_scheduler(optimizer, job_config: JobConfig):
     """Build a linear warmup and linear decay scheduler"""
     global _warmup_steps, _decay_steps
-    _warmup_steps = max(
-        int(job_config.training.steps * job_config.training.warmup_pct), 2
-    )
+    _warmup_steps = int(job_config.training.warmup_steps)
     _decay_steps = float(max(1, job_config.training.steps - _warmup_steps))
 
     warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
diff --git a/train.py b/train.py
index 95d4222656..1cfe835c0d 100644
--- a/train.py
+++ b/train.py
@@ -187,10 +187,7 @@ def main(job_config: JobConfig):
     losses_since_last_log: List[float] = []
     nwords_since_last_log = 0
     time_last_log = timer()
-    while (
-        train_state.step < job_config.training.steps
-        or job_config.training.steps == -1
-    ):
+    while train_state.step < job_config.training.steps:
         train_state.step += 1
         # get batch
         data_load_start = timer()
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 1cca38b093..8b6c91c344 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -26,7 +26,7 @@ lr = 8e-4
 [training]
 batch_size = 8
 seq_len = 2048
-warmup_pct = 0.20  # lr scheduler warm up
+warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
 max_norm = 1.0  # grad norm clipping
 steps = 10
 data_parallel_degree = -1
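
For reviewers, here is a minimal, self-contained sketch of the schedule that results from the new defaults (`warmup_steps = 200`, `steps = 10000`). The body of `linear_warmup_linear_decay` is not part of this diff, so the lambda below is an assumption inferred from its name and the `get_lr_scheduler` docstring, not the repo's exact implementation:

```python
# Sketch only: approximates the warmup/decay schedule wired up in
# torchtrain/lr_scheduling.py after this change, using the new defaults.
import torch
from torch.optim.lr_scheduler import LambdaLR

warmup_steps = 200   # new --training.warmup_steps default
total_steps = 10000  # new --training.steps default
decay_steps = float(max(1, total_steps - warmup_steps))


def linear_warmup_linear_decay(current_step: int) -> float:
    # Multiplier applied to the base lr by LambdaLR (assumed formula).
    if current_step < warmup_steps:
        # ramp linearly up to 1.0 over the warmup window
        return float(current_step + 1) / warmup_steps
    # then decay linearly toward 0 over the remaining steps
    return max(0.0, float(total_steps - current_step) / decay_steps)


model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=8e-4)
scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)

for step in range(total_steps):
    optimizer.step()
    scheduler.step()
```

Note that `debug_model.toml` keeps a tiny `warmup_steps = 2` to match its `steps = 10` budget, mirroring the old 20% heuristic.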