@@ -38,7 +38,7 @@ class FitLoop(Loop[None]):
 
     def __init__(
        self,
-        min_epochs: Optional[int] = 1,
+        min_epochs: int = 0,
        max_epochs: int = 1000,
    ) -> None:
        super().__init__()
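
Note: `min_epochs` drops the `Optional[int]` sentinel, so `0` now plainly means "no minimum" and the truthiness checks around `None` disappear (see the simplified `done` logic below). A minimal sketch of the resulting guard, using a hypothetical `met_min_epochs` helper that is not part of Lightning's API:

    def met_min_epochs(current_epoch: int, min_epochs: int = 0) -> bool:
        # with a plain-int default of 0, the comparison is trivially true when no
        # minimum was requested, so no `if min_epochs else True` fallback is needed
        return current_epoch >= min_epochs
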
@@ -119,6 +119,21 @@ def running_loss(self) -> TensorRunningAccum:
        """Returns the running loss."""
        return self.epoch_loop.batch_loop.running_loss
 
+    @Loop.restarting.setter
+    def restarting(self, restarting: bool) -> None:
+        # if the last epoch completely finished, we are not actually restarting, we can check this to see if all
+        # current values are equal
+        values = (
+            self.epoch_progress.current.ready,
+            self.epoch_progress.current.started,
+            self.epoch_progress.current.processed,
+        )
+        finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values)
+        if finished_before_on_train_end:
+            self.epoch_progress.current.completed = self.epoch_progress.current.processed
+        restarting &= finished_before_on_train_end
+        Loop.restarting.fset(self, restarting)  # call the parent setter
+
    @property
    def _skip_backward(self) -> bool:
        """Determines whether the loop will skip backward during automatic optimization."""
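
Why the counter comparison above works: a checkpoint saved before the `on_train_end` hook leaves `completed` behind the other counters, whereas after a fully finished run all four are equal, so resuming is not really a restart. A self-contained sketch of the same check, where the `Progress` dataclass is only a stand-in for Lightning's `epoch_progress.current`:

    from dataclasses import dataclass

    @dataclass
    class Progress:
        ready: int = 0
        started: int = 0
        processed: int = 0
        completed: int = 0

    def finished_before_on_train_end(p: Progress) -> bool:
        # any counter ahead of `completed` means the save happened before the final bump
        return any(v != p.completed for v in (p.ready, p.started, p.processed))

    print(finished_before_on_train_end(Progress(3, 3, 3, 2)))  # True: saved before on_train_end
    print(finished_before_on_train_end(Progress(3, 3, 3, 3)))  # False: run fully completed
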
@@ -142,31 +157,23 @@ def done(self) -> bool:
        """Evaluates when to leave the loop."""
        # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
        stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
-        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.completed, self.max_epochs)
-
-        should_stop = False
-        if self.trainer.should_stop:
-            # early stopping
-            met_min_epochs = self.epoch_progress.current.completed >= self.min_epochs if self.min_epochs else True
-            met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
-            if met_min_epochs and met_min_steps:
-                should_stop = True
-            else:
+        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
+
+        if self.trainer.should_stop and self.min_epochs:
+            self.trainer.should_stop = self.epoch_progress.current.processed >= self.min_epochs
+            if not self.trainer.should_stop:
                log.info(
-                    "Trainer was signaled to stop but required minimum epochs"
-                    f" ({self.min_epochs}) or minimum steps ({self.min_steps}) has"
-                    " not been met. Training will continue..."
+                    f"Trainer was signaled to stop but required minimum epochs ({self.min_epochs}) has not been met."
+                    " Training will continue..."
                )
-        self.trainer.should_stop = should_stop
-
-        return stop_steps or should_stop or stop_epochs or self.trainer.num_training_batches == 0
+        return stop_steps or self.trainer.should_stop or stop_epochs or self.trainer.num_training_batches == 0
 
    @property
    def skip(self) -> bool:
        """Whether we should skip the training and immediately return from the call to :meth:`run`."""
        # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
        # until `on_run_start`, we use `limit_train_batches` instead
-        return self.done or self.trainer.limit_train_batches == 0
+        return self.trainer.limit_train_batches == 0
 
    def connect(self, epoch_loop: TrainingEpochLoop) -> None:  # type: ignore[override]
        """Connects a training epoch loop to this fit loop."""
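
The rewritten `done` folds the minimum-epoch check straight into `trainer.should_stop`: an early-stop signal only sticks once `min_epochs` has been met, and `min_steps` is no longer consulted here. A condensed, self-contained sketch of the new control flow, with plain arguments standing in for the trainer and progress objects (not Lightning's API):

    def fit_loop_done(global_step: int, max_steps: int, processed_epochs: int, max_epochs: int,
                      should_stop: bool, min_epochs: int, num_training_batches: int) -> bool:
        # mirrors `_is_max_limit_reached`, where -1 disables the limit
        stop_steps = max_steps != -1 and global_step >= max_steps
        stop_epochs = max_epochs != -1 and processed_epochs >= max_epochs
        if should_stop and min_epochs:
            # an external stop request is honored only after the minimum epochs ran
            should_stop = processed_epochs >= min_epochs
        return stop_steps or should_stop or stop_epochs or num_training_batches == 0

Dropping `self.done` from `skip` also means an empty `limit_train_batches` is now the only reason to bypass `run` before the dataloaders exist.
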
@@ -205,7 +212,7 @@ def on_advance_start(self) -> None:  # type: ignore[override]
            getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
        ):
            # set seed for distributed sampler (enables shuffling for each epoch)
-            self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.completed)
+            self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.processed)
 
        # changing gradient according accumulation_scheduler
        self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
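
Passing the `processed` count into `set_epoch` matters because `DistributedSampler` derives its shuffle seed from that number, so a counter that stays consistent across restarts keeps every rank on the same permutation. A minimal plain-PyTorch usage sketch (single process, hence the explicit `num_replicas`/`rank`):

    import torch
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.arange(8))
    sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
    loader = DataLoader(dataset, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # reseeds the permutation so each epoch shuffles differently
        for (batch,) in loader:
            pass
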
@@ -289,11 +296,6 @@ def on_advance_end(self) -> None:
    def on_run_end(self) -> None:
        """Calls the ``on_train_end`` hook."""
        log.detail(f"{self.__class__.__name__}: train run ended")
-        # NOTE: the current_epoch is already incremented
-        # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
-        # To simulate that current behavior, we decrement here.
-        # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
-        self.epoch_progress.current.completed = max(self.epoch_progress.current.completed - 1, 0)
 
        # hook
        self.trainer._call_callback_hooks("on_train_end")
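
With the `restarting` setter above now keeping `completed` in sync, the removed end-of-run decrement is obsolete. A toy illustration of the off-by-one it used to paper over (numbers assumed):

    completed = 3                         # counter value once a 3-epoch run finishes
    old_reported = max(completed - 1, 0)  # the removed workaround hid the final bump
    print(old_reported, completed)        # 2 3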