
Commit 55c87fa

Changed max_steps to default to -1
1 parent b294c57 commit 55c87fa

8 files changed: +28 -40 lines

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -38,7 +38,7 @@ def __init__(self, min_steps: int, max_steps: int):
         super().__init__()
         self.min_steps: int = min_steps
 
-        if max_steps and max_steps < -1:
+        if max_steps < -1:
             raise MisconfigurationException(f"`max_steps` must be a positive integer or -1. You passed in {max_steps}.")
         self.max_steps: int = max_steps
 
@@ -75,7 +75,7 @@ def done(self) -> bool:
         The criteria are that the number of steps reached the max steps, the last batch is reached or the trainer
         signals to stop (e.g. by early stopping).
         """
-        max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps
+        max_steps_reached = self.max_steps != -1 and self.global_step >= self.max_steps
         return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch)
 
     def connect(
```
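With `max_steps` now always an int, the old truthiness guard (which only existed to skip the comparison when `max_steps` was `None`) can go. A minimal standalone sketch of the resulting check; `validate_max_steps` is a hypothetical helper name, the commit inlines the check in `TrainingEpochLoop.__init__`:

```python
def validate_max_steps(max_steps: int) -> int:
    # `max_steps` can no longer be None, so no `if max_steps and ...` guard:
    # -1 disables the limit, 0 and positive values are handled later by `done`.
    if max_steps < -1:
        raise ValueError(f"`max_steps` must be a positive integer or -1. You passed in {max_steps}.")
    return max_steps

assert validate_max_steps(-1) == -1  # sentinel: no step limit
assert validate_max_steps(0) == 0    # valid: `done` stops immediately
```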

pytorch_lightning/loops/fit_loop.py

Lines changed: 3 additions & 22 deletions

```diff
@@ -33,14 +33,8 @@ class FitLoop(Loop):
         max_epochs: The maximum number of epochs
     """
 
-    def __init__(self, min_epochs: Optional[int] = None, max_epochs: Optional[int] = None):
+    def __init__(self, min_epochs: Optional[int] = None, max_epochs: int = None):
         super().__init__()
-        # Allow max_epochs or max_steps to be zero, since this will be handled by fit_loop.done
-        if max_epochs and max_epochs < -1:
-            raise MisconfigurationException(
-                f"`max_epochs` must be a positive integer or -1. You passed in {max_epochs}."
-            )
-
         self.max_epochs = max_epochs
         self.min_epochs = min_epochs
         self.epoch_loop: Optional[TrainingEpochLoop] = None
@@ -135,19 +129,6 @@ def _results(self) -> ResultCollection:
             return self.epoch_loop.val_loop._results
         raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")
 
-    @staticmethod
-    def _is_max_limit_enabled(max_value: Optional[int]) -> bool:
-        """Checks whether the max_value is enabled. This can be used for checking whether max_epochs or max_steps
-        is enabled.
-
-        Args:
-            max_value: the value to check
-
-        Returns:
-            whether the limit for this value should be enabled
-        """
-        return max_value not in (None, -1)
-
     @property
     def done(self) -> bool:
         """Evaluates when to leave the loop.
@@ -156,8 +137,8 @@ def done(self) -> bool:
         is reached.
         """
         # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
-        stop_steps = FitLoop._is_max_limit_enabled(self.max_steps) and self.global_step >= self.max_steps
-        stop_epochs = FitLoop._is_max_limit_enabled(self.max_epochs) and self.current_epoch >= self.max_epochs
+        stop_steps = self.max_steps != -1 and self.global_step >= self.max_steps
+        stop_epochs = self.max_epochs != -1 and self.current_epoch >= self.max_epochs
 
         should_stop = False
         if self.trainer.should_stop:
```
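With `-1` as the single "disabled" sentinel, the removed `_is_max_limit_enabled` helper collapses into a plain comparison at its call sites. A hedged before/after sketch, using free functions as stand-ins for the loop methods:

```python
# Before: two sentinels (None and -1) both meant "no limit", hence the helper.
def is_max_limit_enabled(max_value) -> bool:
    return max_value not in (None, -1)

# After: -1 is the only sentinel, so each call site reduces to `!= -1`.
def stop_steps(global_step: int, max_steps: int) -> bool:
    return max_steps != -1 and global_step >= max_steps

assert stop_steps(10, 10)      # limit reached
assert not stop_steps(10, -1)  # -1 disables the step limit
assert stop_steps(0, 0)        # 0 means "stop immediately"
```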

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 1 addition & 4 deletions

```diff
@@ -191,10 +191,7 @@ def restore_loops(self) -> None:
         self.trainer.fit_loop.current_epoch = self._loaded_checkpoint["epoch"]
 
         # crash if max_epochs is lower then the current epoch from the checkpoint
-        if (
-            FitLoop._is_max_limit_enabled(self.trainer.max_epochs)
-            and self.trainer.current_epoch > self.trainer.max_epochs
-        ):
+        if self.trainer.max_epochs != -1 and self.trainer.current_epoch > self.trainer.max_epochs:
             raise MisconfigurationException(
                 f"You restored a checkpoint with current_epoch={self.trainer.current_epoch},"
                 f" but you have set Trainer(max_epochs={self.trainer.max_epochs})."
```

pytorch_lightning/trainer/trainer.py

Lines changed: 14 additions & 5 deletions

```diff
@@ -122,7 +122,7 @@ def __init__(
         accumulate_grad_batches: Union[int, Dict[int, int]] = 1,
         max_epochs: Optional[int] = None,
         min_epochs: Optional[int] = None,
-        max_steps: Optional[int] = None,
+        max_steps: int = -1,
         min_steps: Optional[int] = None,
         max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
         limit_train_batches: Union[int, float] = 1.0,
@@ -273,9 +273,9 @@ def __init__(
         min_epochs: Force training for at least these many epochs. Disabled by default (None).
             If both min_epochs and min_steps are not specified, defaults to ``min_epochs = 1``.
 
-        max_steps: Stop training after this number of steps. Disabled by default (None). If ``max_steps = None``
+        max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
             and ``max_epochs = None``, will default to ``max_epochs = 1000``. To disable this default, set
-            ``max_steps`` to ``-1``.
+            ``max_epochs`` to ``-1``.
 
         min_steps: Force training for at least these number of steps. Disabled by default (None).
 
@@ -382,10 +382,19 @@ def __init__(
         self.slurm_connector = SLURMConnector(self)
         self.tuner = Tuner(self)
 
-        # max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1).
+        if max_epochs is None:
+            # max_epochs won't default to 1000 if max_steps/max_time are non-default values.
+            max_epochs = 1000 if (max_steps == -1 and max_time is None) else -1
+
+        elif max_epochs < -1:
+            # Allow max_epochs to be zero, since this will be handled by fit_loop.done
+            raise MisconfigurationException(
+                f"`max_epochs` must be a positive integer or -1. You passed in {max_epochs}."
+            )
+
         fit_loop = FitLoop(
             min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs),
-            max_epochs=(1000 if (max_epochs is None and max_steps is None and max_time is None) else max_epochs),
+            max_epochs=max_epochs,
         )
         training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
         training_batch_loop = TrainingBatchLoop()
```
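The behavioural core of the commit is this default resolution, now done eagerly in `Trainer.__init__` instead of inside `FitLoop`: `max_epochs=None` resolves to 1000 only when no other stopping condition is configured, otherwise to `-1`. A standalone sketch of that logic (the function name and `ValueError` are stand-ins for the inlined code and `MisconfigurationException`):

```python
from typing import Optional

def resolve_max_epochs(max_epochs: Optional[int], max_steps: int = -1, max_time=None) -> int:
    if max_epochs is None:
        # Default to the 1000-epoch cap only when no other limit is set.
        return 1000 if (max_steps == -1 and max_time is None) else -1
    if max_epochs < -1:
        # 0 stays allowed; fit_loop.done handles it by stopping immediately.
        raise ValueError(f"`max_epochs` must be a positive integer or -1. You passed in {max_epochs}.")
    return max_epochs

assert resolve_max_epochs(None) == 1000               # nothing set -> 1000-epoch cap
assert resolve_max_epochs(None, max_steps=100) == -1  # step limit set -> unlimited epochs
assert resolve_max_epochs(5, max_steps=100) == 5      # explicit value always wins
```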

tests/callbacks/test_timer.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -49,7 +49,7 @@ def on_fit_start(self):
     timer = [c for c in trainer.callbacks if isinstance(c, Timer)][0]
    assert timer._duration == 1
    assert trainer.max_epochs == -1
-    assert trainer.max_steps is None
+    assert trainer.max_steps == -1
 
 
 @pytest.mark.parametrize(
```

tests/trainer/flags/test_env_vars.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -21,7 +21,8 @@ def test_passing_no_env_variables():
     """Testing overwriting trainer arguments."""
     trainer = Trainer()
     assert trainer.logger is not None
-    assert trainer.max_steps is None
+    assert trainer.max_steps == -1
+    assert trainer.max_epochs == 1000
     trainer = Trainer(False, max_steps=42)
     assert trainer.logger is None
     assert trainer.max_steps == 42
```

tests/trainer/optimization/test_optimizers.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -373,7 +373,7 @@ def test_lr_scheduler_strict(step_mock, tmpdir, complete_epoch):
     optimizer = optim.Adam(model.parameters())
     scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
     max_epochs = 1 if complete_epoch else None
-    max_steps = None if complete_epoch else 1
+    max_steps = -1 if complete_epoch else 1
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=max_epochs, max_steps=max_steps)
 
     model.configure_optimizers = lambda: {
```

tests/trainer/test_trainer.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -504,7 +504,7 @@ def test_trainer_max_steps_and_epochs(tmpdir):
 @pytest.mark.parametrize(
     "max_epochs,max_steps,incorrect_variable,incorrect_value",
     [
-        (-100, None, "max_epochs", -100),
+        (-100, -1, "max_epochs", -100),
         (1, -2, "max_steps", -2),
     ],
 )
@@ -520,13 +520,13 @@ def test_trainer_max_steps_and_epochs_validation(max_epochs, max_steps, incorrect_variable, incorrect_value):
 @pytest.mark.parametrize(
     "max_epochs,max_steps,is_done,correct_trainer_epochs",
     [
-        (None, None, False, 1000),
-        (-1, None, False, -1),
+        (None, -1, False, 1000),
+        (-1, -1, False, -1),
         (None, -1, False, None),
         (5, -1, False, 5),
         (-1, 10, False, -1),
         (None, 0, True, None),
-        (0, None, True, 0),
+        (0, -1, True, 0),
         (-1, 0, True, -1),
         (0, -1, True, 0),
     ],
```
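Taken together, the updated tests pin down the user-visible behaviour. A hedged usage sketch, assuming this commit's defaults:

```python
from pytorch_lightning import Trainer

trainer = Trainer()                # max_steps now defaults to -1 (disabled) ...
assert trainer.max_steps == -1
assert trainer.max_epochs == 1000  # ... so the 1000-epoch default cap still applies

trainer = Trainer(max_steps=100)   # an explicit step limit suppresses the epoch cap
assert trainer.max_epochs == -1
```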
