11 changes: 8 additions & 3 deletions CHANGELOG.md
@@ -243,9 +243,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- DeepSpeed does not require lightning module zero 3 partitioning ([#10655](https://github.com/PyTorchLightning/pytorch-lightning/pull/10655))


- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))


- Moved `Strategy` classes to the `strategies` directory ([#11226](https://github.com/PyTorchLightning/pytorch-lightning/pull/11226))


@@ -266,6 +263,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed `MisconfigurationException` to `ModuleNotFoundError` when `rich` isn't available ([#11360](https://github.com/PyTorchLightning/pytorch-lightning/pull/11360))


- The `trainer.current_epoch` value is now increased by 1 during and after `on_train_end` ([#8578](https://github.com/PyTorchLightning/pytorch-lightning/pull/8578))


- Inherit from `ABC` for `Accelerator`: Users need to implement `auto_device_count` ([#11521](https://github.com/PyTorchLightning/pytorch-lightning/pull/11521))


@@ -288,8 +289,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated `training_type_plugin` property in favor of `strategy` in `Trainer` and updated the references ([#11141](https://github.com/PyTorchLightning/pytorch-lightning/pull/11141))


- Deprecated `Trainer.{validated,tested,predicted}_ckpt_path` and replaced with read-only property `Trainer.ckpt_path` set when checkpoints loaded via `Trainer.{fit,validate,test,predict}` ([#11696](https://github.com/PyTorchLightning/pytorch-lightning/pull/11696))


- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/pull/10103))


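As a quick illustration of the `current_epoch` entry above: inside and after `on_train_end`, the counter now reflects the number of epochs that were run rather than the index of the last one. A minimal sketch assuming the post-change behavior; `EpochCountCheck` is a made-up callback name, not part of this PR:

```python
import pytorch_lightning as pl


class EpochCountCheck(pl.Callback):
    """Prints the epoch counter seen by `on_train_end` (illustration only)."""

    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # With max_epochs=3: previously this was 2 (index of the last epoch),
        # after this change it is 3 (number of epochs processed).
        print(f"current_epoch in on_train_end: {trainer.current_epoch}")


# usage: pl.Trainer(max_epochs=3, callbacks=[EpochCountCheck()]).fit(model)
```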
5 changes: 3 additions & 2 deletions pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -221,13 +221,14 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", *args):
trainer.fit_loop._skip_backward = False

def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1:
# the trainer increases the current epoch before this hook is called
if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
# BatchNorm epoch update. Reset state
trainer.accumulate_grad_batches = self._accumulate_grad_batches
trainer.num_training_batches -= 1
trainer.fit_loop.max_epochs -= 1
self.reset_momenta()
elif trainer.current_epoch == self.swa_end:
elif trainer.current_epoch - 1 == self.swa_end:
# Last SWA epoch. Transfer weights from average model to pl_module
self.transfer_weights(self._average_model, pl_module)

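The `- 1` in the updated comparisons follows from the same change: by the time `on_train_end` runs, the counter has already moved past the last processed epoch. A rough walk-through with made-up numbers (the real values come from the trainer and the callback's configuration):

```python
# Made-up example: 10 regular epochs, SWA active until the last one, plus the extra
# epoch the callback schedules to refresh BatchNorm statistics.
max_epochs = 10
swa_end = max_epochs - 1           # index of the last SWA epoch: 9
epochs_processed = max_epochs + 1  # 10 regular epochs + 1 BatchNorm epoch

# `trainer.current_epoch` in `on_train_end` now equals the number of processed epochs,
# so the index of the epoch that just finished is `current_epoch - 1`.
current_epoch = epochs_processed
assert current_epoch - 1 == swa_end + 1  # 10 == 10 -> the BatchNorm update epoch ran last
```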
5 changes: 1 addition & 4 deletions pytorch_lightning/loops/base.py
@@ -21,7 +21,6 @@
import pytorch_lightning as pl
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException

T = TypeVar("T") # the output type of `run`
@@ -288,11 +287,9 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Di

destination[prefix + "state_dict"] = self.on_save_checkpoint()

# do not get the mode from `self.trainer` because it might not have been attached yet
ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
for k, v in self.__dict__.items():
key = prefix + k
if ft_enabled and isinstance(v, BaseProgress):
if isinstance(v, BaseProgress):
destination[key] = v.state_dict()
elif isinstance(v, Loop):
v.state_dict(destination, key + ".")
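With the fault-tolerant gate removed, loop progress is serialized unconditionally. A toy model of the traversal pattern `state_dict` uses (stand-in classes, not the Lightning ones):

```python
from typing import Dict, Optional


class _Progress:
    """Stand-in for `BaseProgress`: exposes a `state_dict()` of its counters."""

    def __init__(self) -> None:
        self.completed = 0

    def state_dict(self) -> Dict:
        return {"completed": self.completed}


class _MiniLoop:
    """Toy loop: progress objects and child loops are walked recursively."""

    def __init__(self) -> None:
        self.progress = _Progress()
        self.child: Optional["_MiniLoop"] = None

    def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
        if destination is None:
            destination = {}
        for name, value in self.__dict__.items():
            key = prefix + name
            if isinstance(value, _Progress):   # now serialized whether or not fault tolerance is on
                destination[key] = value.state_dict()
            elif isinstance(value, _MiniLoop):
                value.state_dict(destination, key + ".")
        return destination


outer = _MiniLoop()
outer.child = _MiniLoop()
print(outer.state_dict())  # {'progress': {'completed': 0}, 'child.progress': {'completed': 0}}
```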
17 changes: 12 additions & 5 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import defaultdict
from typing import Any, Dict, Generator, Iterator, List, Optional, overload, Tuple, Union

@@ -118,6 +119,17 @@ def reset(self) -> None:
self.batch_progress.reset_on_restart()
self.scheduler_progress.reset_on_restart()
self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()

trainer = self.trainer
if not trainer.state._fault_tolerant_mode.is_enabled and trainer.num_training_batches != float("inf"):
expected_steps = math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches)
if self.global_step % expected_steps != 0:
rank_zero_warn(
"You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable"
" results if further training is done. Consider using an end-of-epoch checkpoint or enabling"
" fault-tolerant training:"
" https://pytorch-lightning.readthedocs.io/en/stable/advanced/fault_tolerant_training.html"
)
else:
self.batch_progress.reset_on_run()
self.scheduler_progress.reset_on_run()
@@ -479,11 +491,6 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:

# TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
is_val_check_batch = is_last_batch

# while restarting with no fault-tolerant, batch_progress.current.ready is -1
if batch_idx == -1:
return False

if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
elif self.trainer.val_check_batch != float("inf"):
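The new warning in `reset` boils down to a divisibility check on the recorded optimizer steps. A worked example with assumed numbers (the real code reads these values off the trainer):

```python
import math

# Assumed example values -- not taken from the diff.
num_training_batches = 10      # training batches per epoch
accumulate_grad_batches = 4    # gradient accumulation factor
global_step = 7                # optimizer steps recorded in the checkpoint being resumed

# One optimizer step happens every `accumulate_grad_batches` batches, so a full epoch
# contributes ceil(10 / 4) = 3 steps.
expected_steps = math.ceil(num_training_batches / accumulate_grad_batches)

if global_step % expected_steps != 0:
    # 7 % 3 == 1 -> the checkpoint was taken mid-epoch; Lightning now warns that results
    # may be unreliable unless an end-of-epoch checkpoint or fault-tolerant training is used.
    print("resuming from a mid-epoch checkpoint")
```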
49 changes: 21 additions & 28 deletions pytorch_lightning/loops/fit_loop.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import os
from functools import partial
from typing import Optional, Type
@@ -26,7 +25,6 @@
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
AbstractDataFetcher,
@@ -51,7 +49,7 @@ class FitLoop(Loop[None]):

def __init__(
self,
min_epochs: Optional[int] = 1,
min_epochs: int = 0,
max_epochs: int = 1000,
) -> None:
super().__init__()
@@ -133,6 +131,21 @@ def running_loss(self) -> TensorRunningAccum:
"""Returns the running loss."""
return self.epoch_loop.batch_loop.running_loss

@Loop.restarting.setter
def restarting(self, restarting: bool) -> None:
# if the last epoch completely finished, we are not actually restarting, we can check this to see if all
# current values are equal
values = (
self.epoch_progress.current.ready,
self.epoch_progress.current.started,
self.epoch_progress.current.processed,
)
finished_before_on_train_end = any(v != self.epoch_progress.current.completed for v in values)
if finished_before_on_train_end:
self.epoch_progress.current.completed = self.epoch_progress.current.processed
restarting &= finished_before_on_train_end
Loop.restarting.fset(self, restarting) # call the parent setter

@property
def _skip_backward(self) -> bool:
"""Determines whether the loop will skip backward during automatic optimization."""
@@ -156,12 +169,14 @@ def done(self) -> bool:
"""Evaluates when to leave the loop."""
# TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
stop_epochs = _is_max_limit_reached(self.epoch_progress.current.completed, self.max_epochs)
# `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
# we use it here because the checkpoint data won't have `completed` increased yet
stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)

should_stop = False
if self.trainer.should_stop:
# early stopping
met_min_epochs = self.epoch_progress.current.completed >= self.min_epochs if self.min_epochs else True
met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
if met_min_epochs and met_min_steps:
should_stop = True
@@ -198,23 +213,6 @@ def on_run_start(self) -> None:  # type: ignore[override]
data_fetcher_cls = _select_data_fetcher(self.trainer)
self._data_fetcher = data_fetcher_cls()

ft_enabled = _FaultTolerantMode.detect_current_mode().is_enabled
if not ft_enabled and self.restarting and self.trainer.num_training_batches not in (0, float("inf")):
self.trainer.accumulate_grad_batches = self.trainer.accumulation_scheduler.get_accumulate_grad_batches(
self.trainer.current_epoch
)
expected_steps = math.ceil(self.trainer.num_training_batches / self.trainer.accumulate_grad_batches)

# global_step is incremented during checkpointing (#11555)
if (self.trainer.global_step - 1) % expected_steps != 0:
rank_zero_warn(
"You're resuming from a checkpoint that ended mid-epoch."
" Training will start from the beginning of the next epoch."
" This can cause unreliable results if further training is done,"
" consider using an end of epoch checkpoint or use fault-tolerant training"
" to restart as if training did not stop."
)

self._is_fresh_start_epoch = True
self._results.to(device=self.trainer.lightning_module.device)

@@ -240,7 +238,7 @@ def on_advance_start(self) -> None:  # type: ignore[override]
getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
):
# set seed for distributed sampler (enables shuffling for each epoch)
self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.completed)
self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.processed)

# changing gradient according accumulation_scheduler
self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
@@ -325,11 +323,6 @@ def on_advance_end(self) -> None:
def on_run_end(self) -> None:
"""Calls the ``on_train_end`` hook."""
log.detail(f"{self.__class__.__name__}: train run ended")
# NOTE: the current_epoch is already incremented
# Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
# To simulate that current behavior, we decrement here.
# TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
self.epoch_progress.current.completed = max(self.epoch_progress.current.completed - 1, 0)

# hook
self.trainer._call_callback_hooks("on_train_end")
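The new `restarting` setter compares the per-epoch progress counters: roughly, a run only counts as restarting if the previous epoch did not get all the way through `on_train_end`. A condensed sketch of that check (field names follow the diff; the dataclass here is a stand-in for `epoch_progress.current`):

```python
from dataclasses import dataclass


@dataclass
class _EpochCounters:
    """Stand-in for `epoch_progress.current` (ready/started/processed/completed)."""

    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0


def finished_before_on_train_end(current: _EpochCounters) -> bool:
    # If the checkpoint was written before `on_train_end`, `completed` lags behind the
    # other counters; if all four match, the previous run finished cleanly and the
    # loop does not need to treat the new run as a mid-training restart.
    values = (current.ready, current.started, current.processed)
    return any(v != current.completed for v in values)


print(finished_before_on_train_end(_EpochCounters(3, 3, 3, 3)))  # False: clean finish
print(finished_before_on_train_end(_EpochCounters(3, 3, 3, 2)))  # True: stopped before on_train_end
```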
9 changes: 7 additions & 2 deletions pytorch_lightning/loops/utilities.py
@@ -71,7 +71,7 @@ def _parse_loop_limits(
min_epochs: Optional[int],
max_epochs: int,
max_time: Optional[Union[str, timedelta, Dict[str, int]]],
) -> Tuple[Optional[int], int, Optional[int], int, Optional[Union[str, timedelta, Dict[str, int]]]]:
) -> Tuple[Optional[int], int, int, int, Optional[Union[str, timedelta, Dict[str, int]]]]:
"""This utility computes the default values for the minimum and maximum number of steps and epochs given the
values the user has selected.

@@ -95,7 +95,12 @@ def _parse_loop_limits(
max_epochs = 1000
else:
max_epochs = -1
min_epochs = 1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs
if min_epochs is None and min_steps is not None:
# setting this allows FitLoop.done to re-evaluate should_stop when it gets triggered `on_fit_start`
min_epochs = 1
if min_epochs is None:
# the default value is 0 so no training will be done when should_stop is triggered `on_fit_start`
min_epochs = 0
return min_steps, max_steps, min_epochs, max_epochs, max_time


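Condensed, the new `min_epochs` defaulting behaves like the stand-alone helper below (simplified; the real `_parse_loop_limits` also resolves `max_epochs`, `max_steps`, and `max_time`):

```python
from typing import Optional


def _default_min_epochs(min_epochs: Optional[int], min_steps: Optional[int]) -> int:
    if min_epochs is None and min_steps is not None:
        # a min_steps budget implies at least one epoch, so FitLoop.done can re-evaluate
        # `should_stop` raised in `on_fit_start`
        return 1
    if min_epochs is None:
        # default is now 0: if `should_stop` is set in `on_fit_start`, no training runs
        return 0
    return min_epochs


print(_default_min_epochs(None, None))  # 0
print(_default_min_epochs(None, 100))   # 1
print(_default_min_epochs(5, None))     # 5
```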
18 changes: 4 additions & 14 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -21,7 +21,6 @@
from torchmetrics import Metric

import pytorch_lightning as pl
from pytorch_lightning.loops.utilities import _is_max_limit_reached
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
@@ -230,7 +229,7 @@ def restore_loops(self) -> None:
assert self.trainer.state.fn is not None
state_dict = self._loaded_checkpoint.get("loops")
if state_dict is not None:
if self.trainer.state.fn == TrainerFn.FITTING:
if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"])
elif self.trainer.state.fn == TrainerFn.VALIDATING:
self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"])
@@ -329,21 +328,12 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
LightningDataModule.__class__.__qualname__: pl DataModule's state
}
"""

# dump epoch/global_step/pytorch-lightning_version
current_epoch = self.trainer.current_epoch
global_step = self.trainer.global_step
has_reached_max_steps = _is_max_limit_reached(global_step, self.trainer.max_steps)

global_step += 1
if not has_reached_max_steps:
current_epoch += 1

model = self.trainer.lightning_module

checkpoint = {
"epoch": current_epoch,
"global_step": global_step,
# the epoch is saved for compatibility but it's not relevant for restoration
"epoch": self.trainer.current_epoch,
"global_step": self.trainer.global_step + 1,
"pytorch-lightning_version": pl.__version__,
"state_dict": self._get_lightning_module_state_dict(),
"loops": self._get_loops_state_dict(),
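For checkpoints written after this change, the top-level counters are stored raw. A small sketch of inspecting one (the file name is a placeholder, and the key layout follows the docstring and `_get_loops_state_dict` naming above):

```python
import torch

# Hypothetical checkpoint path -- substitute your own file.
ckpt = torch.load("example.ckpt", map_location="cpu")

# "epoch" is now the trainer's current_epoch as-is (kept for backward compatibility;
# restoration goes through the "loops" state instead), and "global_step" is stored as
# the recorded step + 1.
print(ckpt["epoch"], ckpt["global_step"], ckpt["pytorch-lightning_version"])
print(ckpt["loops"]["fit_loop"].keys())  # nested loop state used for actual restoration
```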
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -2438,7 +2438,7 @@ def max_epochs(self) -> int:
return self.fit_loop.max_epochs

@property
def min_epochs(self) -> Optional[int]:
def min_epochs(self) -> int:
return self.fit_loop.min_epochs

@property
2 changes: 0 additions & 2 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -60,10 +60,8 @@ def scale_batch_size(

# Save initial model, that is loaded after batch size is found
ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
trainer.fit_loop.epoch_progress.current.completed -= 1
trainer.fit_loop.global_step -= 1
trainer.save_checkpoint(ckpt_path)
trainer.fit_loop.epoch_progress.current.completed += 1
trainer.fit_loop.global_step += 1
params = __scale_batch_dump_params(trainer)

2 changes: 0 additions & 2 deletions pytorch_lightning/tuner/lr_finder.py
@@ -204,10 +204,8 @@ def lr_find(

# Save initial model, that is loaded after learning rate is found
ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
trainer.fit_loop.epoch_progress.current.completed -= 1
trainer.fit_loop.global_step -= 1
trainer.save_checkpoint(ckpt_path)
trainer.fit_loop.epoch_progress.current.completed += 1
trainer.fit_loop.global_step += 1
params = __lr_finder_dump_params(trainer)

14 changes: 7 additions & 7 deletions tests/callbacks/test_early_stopping.py
@@ -80,7 +80,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
# ensure state is persisted properly
checkpoint = torch.load(checkpoint_filepath)
# the checkpoint saves "epoch + 1"
early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"]]
assert 4 == len(early_stop_callback.saved_states)
es_name = "EarlyStoppingTestRestore{'monitor': 'train_loss', 'mode': 'min'}"
assert checkpoint["callbacks"][es_name] == early_stop_callback_state
@@ -143,7 +143,7 @@ def validation_epoch_end(self, outputs):
enable_progress_bar=False,
)
trainer.fit(model)
assert trainer.current_epoch == expected_stop_epoch
assert trainer.current_epoch - 1 == expected_stop_epoch


@pytest.mark.parametrize("validation_step_none", [True, False])
@@ -179,7 +179,7 @@ def training_epoch_end(self, outputs):
enable_progress_bar=False,
)
trainer.fit(model)
assert trainer.current_epoch == expected_stop_epoch
assert trainer.current_epoch - 1 == expected_stop_epoch


def test_pickling(tmpdir):
@@ -236,7 +236,7 @@ def validation_epoch_end(self, outputs):
max_epochs=20,
)
trainer.fit(model)
assert trainer.current_epoch == expected_epoch, "early_stopping failed"
assert trainer.current_epoch - 1 == expected_epoch, "early_stopping failed"


@pytest.mark.parametrize("stop_value", [torch.tensor(np.inf), torch.tensor(np.nan)])
@@ -260,7 +260,7 @@ def validation_epoch_end(self, outputs):
max_epochs=10,
)
trainer.fit(model)
assert trainer.current_epoch == expected_stop_epoch
assert trainer.current_epoch - 1 == expected_stop_epoch
assert early_stopping.stopped_epoch == expected_stop_epoch


@@ -388,7 +388,7 @@ def validation_epoch_end(self, outputs):
self._epoch_end()

def on_train_end(self) -> None:
assert self.trainer.current_epoch == self.expected_end_epoch, "Early Stopping Failed"
assert self.trainer.current_epoch - 1 == self.expected_end_epoch, "Early Stopping Failed"


_ES_CHECK = dict(check_on_train_epoch_end=True)
@@ -481,7 +481,7 @@ def validation_step(self, batch, batch_idx):
if case == "val_check_interval":
assert trainer.global_step == len(side_effect) * int(trainer.limit_train_batches * trainer.val_check_interval)
else:
assert trainer.current_epoch == len(side_effect) * trainer.check_val_every_n_epoch - 1
assert trainer.current_epoch == len(side_effect) * trainer.check_val_every_n_epoch


def test_early_stopping_squeezes():