@@ -122,7 +122,7 @@ def __init__(
122122 accumulate_grad_batches : Union [int , Dict [int , int ]] = 1 ,
123123 max_epochs : Optional [int ] = None ,
124124 min_epochs : Optional [int ] = None ,
125- max_steps : Optional [ int ] = None ,
125+ max_steps : int = - 1 ,
126126 min_steps : Optional [int ] = None ,
127127 max_time : Optional [Union [str , timedelta , Dict [str , int ]]] = None ,
128128 limit_train_batches : Union [int , float ] = 1.0 ,
@@ -273,9 +273,9 @@ def __init__(
273273 min_epochs: Force training for at least these many epochs. Disabled by default (None).
274274 If both min_epochs and min_steps are not specified, defaults to ``min_epochs = 1``.
275275
276- max_steps: Stop training after this number of steps. Disabled by default (None ). If ``max_steps = None ``
276+ max_steps: Stop training after this number of steps. Disabled by default (-1 ). If ``max_steps = -1 ``
277277 and ``max_epochs = None``, will default to ``max_epochs = 1000``. To disable this default, set
278- ``max_steps `` to ``-1``.
278+ ``max_epochs `` to ``-1``.
279279
280280 min_steps: Force training for at least these number of steps. Disabled by default (None).
281281
@@ -382,10 +382,19 @@ def __init__(
382382 self .slurm_connector = SLURMConnector (self )
383383 self .tuner = Tuner (self )
384384
385- # max_epochs won't default to 1000 if max_steps/max_time are specified (including being set to -1).
385+ if max_epochs is None :
386+ # max_epochs won't default to 1000 if max_steps/max_time are non-default values.
387+ max_epochs = 1000 if (max_steps == - 1 and max_time is None ) else - 1
388+
389+ elif max_epochs < - 1 :
390+ # Allow max_epochs to be zero, since this will be handled by fit_loop.done
391+ raise MisconfigurationException (
392+ f"`max_epochs` must be a positive integer or -1. You passed in { max_epochs } ."
393+ )
394+
386395 fit_loop = FitLoop (
387396 min_epochs = (1 if (min_epochs is None and min_steps is None and max_time is None ) else min_epochs ),
388- max_epochs = ( 1000 if ( max_epochs is None and max_steps is None and max_time is None ) else max_epochs ) ,
397+ max_epochs = max_epochs ,
389398 )
390399 training_epoch_loop = TrainingEpochLoop (min_steps , max_steps )
391400 training_batch_loop = TrainingBatchLoop ()
0 commit comments