@@ -56,16 +56,6 @@ def __init__(
         self._is_fresh_start_epoch: bool = True
         self._outputs: _EPOCH_OUTPUTS_TYPE = []
 
-    @property
-    def current_epoch(self) -> int:
-        """Return the current epoch."""
-        return self.epoch_progress.current.completed
-
-    @current_epoch.setter
-    def current_epoch(self, value: int) -> None:
-        """Setter for the current epoch."""
-        self.epoch_progress.current.completed = value
-
     @property
     def global_step(self) -> int:
         """Returns the global step."""
@@ -129,6 +119,18 @@ def running_loss(self) -> TensorRunningAccum:
129119 """Returns the running loss."""
130120 return self .epoch_loop .batch_loop .running_loss
131121
+    @Loop.restarting.setter
+    def restarting(self, restarting: bool) -> None:
+        # if the last epoch completely finished, we are not actually restarting, we can check this to see if all
+        # current values are equal
+        values = (
+            self.epoch_progress.current.ready,
+            self.epoch_progress.current.started,
+            self.epoch_progress.current.processed,
+        )
+        restarting &= any(v != self.epoch_progress.current.completed for v in values)
+        Loop.restarting.fset(self, restarting)  # call the parent setter
+
     @property
     def _skip_backward(self) -> bool:
         """Determines whether the loop will skip backward during automatic optimization."""
@@ -152,11 +154,11 @@ def done(self) -> bool:
152154 """Evaluates when to leave the loop."""
153155 # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
154156 stop_steps = _is_max_limit_reached (self .global_step , self .max_steps )
155- stop_epochs = _is_max_limit_reached (self .current_epoch , self .max_epochs )
157+ stop_epochs = _is_max_limit_reached (self .epoch_progress . current . processed , self .max_epochs )
156158
157159 should_stop = self .trainer .should_stop
158160 if should_stop :
159- should_stop = self .current_epoch >= self .min_epochs if self .min_epochs else True
161+ should_stop = self .epoch_progress . current . processed >= self .min_epochs if self .min_epochs else True
160162 if not should_stop :
161163 log .info (
162164 f"Trainer was signaled to stop but required minimum epochs ({ self .min_epochs } ) has not been met."
@@ -169,7 +171,7 @@ def skip(self) -> bool:
169171 """Whether we should skip the training and immediately return from the call to :meth:`run`."""
170172 # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
171173 # until `on_run_start`, we use `limit_train_batches` instead
172- return self .done or self . trainer .limit_train_batches == 0
174+ return self .trainer .limit_train_batches == 0
173175
174176 def connect (self , epoch_loop : TrainingEpochLoop ) -> None : # type: ignore[override]
175177 """Connects a training epoch loop to this fit loop."""
@@ -207,7 +209,7 @@ def on_advance_start(self) -> None: # type: ignore[override]
             getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
         ):
             # set seed for distributed sampler (enables shuffling for each epoch)
-            self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)
+            self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.processed)
 
         # changing gradient according accumulation_scheduler
         self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
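For context on why the epoch counter is forwarded to the sampler: PyTorch's DistributedSampler reseeds its shuffle from the value passed to set_epoch, so each epoch yields a different ordering across ranks. A minimal standalone sketch in plain PyTorch (illustrative only, not part of this change):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8, dtype=torch.float32))
# passing num_replicas/rank explicitly avoids needing an initialized process group here
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=2)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reseeds the shuffle so every epoch sees a new order
    for (batch,) in loader:
        pass  # training step would go here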