Skip to content

Commit eefcdee

Browse files
committed
update mypy and error msg
1 parent a8662f9 commit eefcdee

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
4141
max_steps: The maximum number of steps (batches) to process
4242
"""
4343

44-
def __init__(self, min_steps: Optional[int] = 0, max_steps: int = -1):
44+
def __init__(self, min_steps: Optional[int] = 0, max_steps: int = -1) -> None:
4545
super().__init__()
4646
if max_steps is None:
4747
rank_zero_deprecation(
@@ -51,10 +51,10 @@ def __init__(self, min_steps: Optional[int] = 0, max_steps: int = -1):
5151
max_steps = -1
5252
elif max_steps < -1:
5353
raise MisconfigurationException(
54-
f"`max_steps` must be a non-negative integer or -1. You passed in {max_steps}."
54+
f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {max_steps}."
5555
)
56-
self.min_steps: Optional[int] = min_steps
57-
self.max_steps: int = max_steps
56+
self.min_steps = min_steps
57+
self.max_steps = max_steps
5858

5959
self.global_step: int = 0
6060
self.batch_progress = BatchProgress()

pytorch_lightning/loops/fit_loop.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
self,
3939
min_epochs: Optional[int] = 1,
4040
max_epochs: int = 1000,
41-
):
41+
) -> None:
4242
super().__init__()
4343
if max_epochs < -1:
4444
# Allow max_epochs to be zero, since this will be handled by fit_loop.done
@@ -105,7 +105,7 @@ def max_steps(self) -> int:
105105
return self.epoch_loop.max_steps
106106

107107
@max_steps.setter
108-
def max_steps(self, value: Optional[int]) -> None:
108+
def max_steps(self, value: int) -> None:
109109
"""Sets the maximum number of steps (forwards to epoch_loop)"""
110110
# TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
111111
if value is None:
@@ -115,7 +115,9 @@ def max_steps(self, value: Optional[int]) -> None:
115115
)
116116
value = -1
117117
elif value < -1:
118-
raise MisconfigurationException(f"`max_steps` must be a non-negative integer or -1. You passed in {value}.")
118+
raise MisconfigurationException(
119+
f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
120+
)
119121
self.epoch_loop.max_steps = value
120122

121123
@property

tests/trainer/test_trainer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -498,17 +498,17 @@ def test_trainer_max_steps_and_epochs(tmpdir):
498498

499499

500500
@pytest.mark.parametrize(
501-
"max_epochs,max_steps,incorrect_variable,incorrect_value",
501+
"max_epochs,max_steps,incorrect_variable",
502502
[
503-
(-100, -1, "max_epochs", -100),
504-
(1, -2, "max_steps", -2),
503+
(-100, -1, "max_epochs"),
504+
(1, -2, "max_steps"),
505505
],
506506
)
507-
def test_trainer_max_steps_and_epochs_validation(max_epochs, max_steps, incorrect_variable, incorrect_value):
507+
def test_trainer_max_steps_and_epochs_validation(max_epochs, max_steps, incorrect_variable):
508508
"""Don't allow max_epochs or max_steps to be less than -1 or a float."""
509509
with pytest.raises(
510510
MisconfigurationException,
511-
match=f"`{incorrect_variable}` must be a non-negative integer or -1. You passed in {incorrect_value}",
511+
match=f"`{incorrect_variable}` must be a non-negative integer or -1",
512512
):
513513
Trainer(max_epochs=max_epochs, max_steps=max_steps)
514514

0 commit comments

Comments
 (0)