
Commit 73482d7

use warmup steps for lr scheduler, ban steps == -1
As titled: we don't want to allow the steps == -1 case, as it would blow up the lr scheduler.
1 parent 96d1cb1 commit 73482d7


4 files changed, +10 -15 lines


torchtrain/config_manager.py

Lines changed: 5 additions & 5 deletions
@@ -151,10 +151,10 @@ def init_args_from_command_line(
         "--training.seq_len", type=int, default=2048, help="sequence length"
     )
     parser.add_argument(
-        "--training.warmup_pct",
-        type=float,
-        default=0.20,
-        help="percentage of total training steps to use for warmup",
+        "--training.warmup_steps",
+        type=int,
+        default=200,
+        help="steps for lr scheduler warmup",
     )
     parser.add_argument(
         "--training.max_norm",
@@ -163,7 +163,7 @@ def init_args_from_command_line(
         help="max norm for gradient clipping",
     )
     parser.add_argument(
-        "--training.steps", type=int, default=-1, help="how many train steps to run"
+        "--training.steps", type=int, default=10000, help="how many train steps to run"
     )
     parser.add_argument(
         "--training.data_parallel_degree",

torchtrain/lr_scheduling.py

Lines changed: 2 additions & 4 deletions
@@ -6,7 +6,7 @@
 
 # global states for scheduling
 # these are needed as LambdaLR does not support argument passing
-_warmup_steps = 2
+_warmup_steps = 200
 _decay_steps = 0
 
 
@@ -33,9 +33,7 @@ def linear_warmup_linear_decay(current_step: int) -> float:
 def get_lr_scheduler(optimizer, job_config: JobConfig):
     """Build a linear warmup and linear decay scheduler"""
     global _warmup_steps, _decay_steps
-    _warmup_steps = max(
-        int(job_config.training.steps * job_config.training.warmup_pct), 2
-    )
+    _warmup_steps = int(job_config.training.warmup_steps)
     _decay_steps = float(max(1, job_config.training.steps - _warmup_steps))
 
     warmup_scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_linear_decay)
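
The lambda itself is unchanged by this commit and not shown in the hunk. As a rough sketch of the shape those two globals feed, not necessarily the repo's exact formula, a linear-warmup-then-linear-decay multiplier could look like this:

# module-level globals, as in the diff above (placeholder values)
_warmup_steps = 200
_decay_steps = 9800.0

def linear_warmup_linear_decay(current_step: int) -> float:
    """Return an lr multiplier: ramp up over _warmup_steps, then decay
    linearly toward zero over _decay_steps. Sketch only."""
    if current_step < _warmup_steps:
        # +1 so the very first step gets a non-zero learning rate
        return float(current_step + 1) / float(max(1, _warmup_steps))
    progress = float(current_step - _warmup_steps) / _decay_steps
    return max(0.0, 1.0 - progress)

Note that with the old default of training.steps == -1, the _decay_steps line in the hunk above clamps to 1, so any multiplier of this shape drops to zero one step after warmup; banning -1 keeps the decay horizon meaningful.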

train.py

Lines changed: 2 additions & 5 deletions
@@ -187,10 +187,7 @@ def main(job_config: JobConfig):
     losses_since_last_log: List[float] = []
     nwords_since_last_log = 0
     time_last_log = timer()
-    while (
-        train_state.step < job_config.training.steps
-        or job_config.training.steps == -1
-    ):
+    while train_state.step < job_config.training.steps:
         train_state.step += 1
         # get batch
         data_load_start = timer()
@@ -220,7 +217,7 @@
 
         # clip gradients (after unscaling gradients of the optimizer's params)
         scaler.unscale_(optimizer)
-        model.clip_grad_norm_(job_config.training.max_norm)
+        # model.clip_grad_norm_(job_config.training.max_norm)
 
         # optimizer step
         # If gradients don't contain infs/NaNs, optimizer.step() is then called;
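
For orientation, here is a bare-bones sketch of how the simplified loop condition, gradient clipping, and the warmup/decay schedule fit together. The toy model, optimizer, and numbers are placeholders; the real loop also handles data loading, mixed-precision scaling, logging, and checkpointing.

import torch
from torch.optim.lr_scheduler import LambdaLR

# Placeholder model/optimizer for illustration only.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=8e-4)

warmup_steps, total_steps, max_norm = 200, 10000, 1.0
decay_steps = max(1, total_steps - warmup_steps)

def lr_lambda(step: int) -> float:
    # linear warmup then linear decay, in the spirit of get_lr_scheduler above
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    return max(0.0, 1.0 - (step - warmup_steps) / decay_steps)

scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

step = 0
while step < total_steps:  # finite budget; steps == -1 is no longer allowed
    step += 1
    optimizer.zero_grad()
    loss = model(torch.randn(4, 16)).square().mean()
    loss.backward()
    # plain eager clipping for the toy model; the commit above comments out
    # the model-level clip_grad_norm_ call in the real loop
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    scheduler.step()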

train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ lr = 8e-4
 [training]
 batch_size = 8
 seq_len = 2048
-warmup_pct = 0.20 # lr scheduler warm up
+warmup_steps = 5 # lr scheduler warm up
 max_norm = 1.0 # grad norm clipping
 steps = 10
 data_parallel_degree = -1
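
The [training] keys mirror the --training.* flags shown above, so the debug config just needs warmup_steps to stay within its tiny 10-step budget. A quick sketch of reading that table with Python's stdlib tomllib (3.11+); the loading code here is hypothetical and not necessarily how the repo consumes its configs.

import tomllib  # stdlib in Python 3.11+

with open("train_configs/debug_model.toml", "rb") as f:
    cfg = tomllib.load(f)

training = cfg["training"]
# For the debug config above: warmup_steps=5, steps=10
assert training["warmup_steps"] <= training["steps"]
print(training["warmup_steps"], training["steps"])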
