From 277b0b811fb1419d6c06e7953941d6f6076eaf6d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 21 Oct 2022 13:44:35 +0200 Subject: [PATCH 01/24] migration --- src/pytorch_lightning/core/saving.py | 7 +- .../connectors/checkpoint_connector.py | 28 +-- src/pytorch_lightning/utilities/migration.py | 159 +++++++++++++++++- .../utilities/upgrade_checkpoint.py | 33 +--- .../utilities/test_upgrade_checkpoint.py | 14 +- 5 files changed, 169 insertions(+), 72 deletions(-) diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py index 46f1663ed705c..247d864438d87 100644 --- a/src/pytorch_lightning/core/saving.py +++ b/src/pytorch_lightning/core/saving.py @@ -31,7 +31,7 @@ from lightning_lite.utilities.cloud_io import load as pl_load from lightning_lite.utilities.types import _MAP_LOCATION_TYPE, _PATH from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE -from pytorch_lightning.utilities.migration import pl_legacy_patch +from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch from pytorch_lightning.utilities.parsing import AttributeDict, parse_class_init_keys from pytorch_lightning.utilities.rank_zero import rank_zero_warn @@ -156,6 +156,9 @@ def _load_from_checkpoint( with pl_legacy_patch(): checkpoint = pl_load(checkpoint_path, map_location=map_location) + # convert legacy checkpoints to the new format + checkpoint = migrate_checkpoint(checkpoint) + if hparams_file is not None: extension = str(hparams_file).split(".")[-1] if extension.lower() == "csv": @@ -168,6 +171,7 @@ def _load_from_checkpoint( # overwrite hparams by the given file checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams + # TODO: make this a migration: # for past checkpoint need to add the new key checkpoint.setdefault(cls.CHECKPOINT_HYPER_PARAMS_KEY, {}) # override the hparams with values that were passed in @@ -197,6 +201,7 @@ def _load_state( if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: if issubclass(cls, pl.LightningModule): + # TODO: make this a migration: # 1. 
(backward compatibility) Try to restore model hparams from checkpoint using old/past keys for _old_hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS: cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {})) diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 50480e769b4ce..01d5a6a7e14ed 100644 --- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -32,9 +32,8 @@ from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training -from pytorch_lightning.utilities.migration import pl_legacy_patch +from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn -from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS if _OMEGACONF_AVAILABLE: from omegaconf import Container @@ -86,13 +85,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: def _load_and_validate_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: with pl_legacy_patch(): loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path) - if any(key in loaded_checkpoint for key in DEPRECATED_CHECKPOINT_KEYS): - raise ValueError( - "The checkpoint you're attempting to load follows an" - " outdated schema. You can upgrade to the current schema by running" - " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`" - " where `model.ckpt` is your checkpoint file." - ) + loaded_checkpoint = migrate_checkpoint(loaded_checkpoint) return loaded_checkpoint def _set_ckpt_path( @@ -348,23 +341,6 @@ def restore_loops(self) -> None: return fit_loop = self.trainer.fit_loop - pl_module = self.trainer.lightning_module - assert pl_module is not None - - # set the `global_step` value for checkpoints before v1.6 without the progress tracking state. - # it will be overwritten by the loop's state if it was also saved - batch_loop = fit_loop.epoch_loop.batch_loop - if pl_module.automatic_optimization: - batch_loop.optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint[ - "global_step" - ] - else: - batch_loop.manual_loop.optim_step_progress.total.completed = self._loaded_checkpoint["global_step"] - - # set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. - # it will be overwritten by the loop's state if it was also saved - fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"] - assert self.trainer.state.fn is not None state_dict = self._loaded_checkpoint.get("loops") if state_dict is not None: diff --git a/src/pytorch_lightning/utilities/migration.py b/src/pytorch_lightning/utilities/migration.py index ed71f25a571f7..7f400fc56efa4 100644 --- a/src/pytorch_lightning/utilities/migration.py +++ b/src/pytorch_lightning/utilities/migration.py @@ -11,16 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations +"""Contains migration functions to upgrade legacy checkpoints to the format of the current Lightning version. 
+ +When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary sequentially, +see :func:`migrate_checkpoint`. +""" import sys -import threading +from distutils.version import LooseVersion from types import ModuleType, TracebackType +from typing import Any, Dict, Optional, Type +import pytorch_lightning as pl import pytorch_lightning.utilities.argparse -# Create a global lock to ensure no race condition with deleting sys modules -_lock = threading.Lock() +_CHECKPOINT = Dict[str, Any] class pl_legacy_patch: @@ -28,7 +33,7 @@ class pl_legacy_patch: unpickling old checkpoints. The following patches apply. 1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to - version 1.2.8. See: https://github.com/Lightning-AI/lightning/pull/6898 + version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4, but still needs to be available for import for legacy checkpoints. @@ -38,8 +43,7 @@ class pl_legacy_patch: torch.load("path/to/legacy/checkpoint.ckpt") """ - def __enter__(self) -> None: - _lock.acquire() + def __enter__(self) -> "pl_legacy_patch": # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils") sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module @@ -47,11 +51,148 @@ def __enter__(self) -> None: # `_gpus_arg_default` used to be imported from these locations legacy_argparse_module._gpus_arg_default = lambda x: x pytorch_lightning.utilities.argparse._gpus_arg_default = lambda x: x + return self def __exit__( - self, exc_type: type[BaseException] | None, exc_value: BaseException | None, exc_traceback: TracebackType | None + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + exc_traceback: Optional[TracebackType], ) -> None: if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"): delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") del sys.modules["pytorch_lightning.utilities.argparse_utils"] - _lock.release() + + +def get_version(checkpoint: _CHECKPOINT) -> str: + """Get the version of a Lightning checkpoint.""" + return checkpoint["pytorch-lightning_version"] + + +def set_version(checkpoint: _CHECKPOINT, version: str) -> None: + """Set the version of a Lightning checkpoint.""" + checkpoint["pytorch-lightning_version"] = version + + +def should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool: + """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" + return LooseVersion(get_version(checkpoint)) < LooseVersion(target) + + +def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Applies all migrations below in order.""" + if should_upgrade(checkpoint, "0.10.0"): + _migrate_model_checkpoint_early_stopping(checkpoint) + if should_upgrade(checkpoint, "1.6.0"): + _migrate_loop_global_step_to_progress_tracking(checkpoint) + _migrate_loop_current_epoch_to_progress_tracking(checkpoint) + + set_version(checkpoint, pl.__version__) + + # TODO: If any migrations apply, log a message. Suggest to run upgrade_checkpoint script to convert + # checkpoints permanently + return checkpoint + + +def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """The checkpoint and early stopping keys were renamed. 
+ + Version: 0.10.0 + Commit: + """ + from pytorch_lightning.callbacks.early_stopping import EarlyStopping + from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint + + keys_mapping = { + "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), + "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"), + "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"), + "early_stop_callback_wait": (EarlyStopping, "wait_count"), + "early_stop_callback_patience": (EarlyStopping, "patience"), + } + checkpoint["callbacks"] = checkpoint.get("callbacks") or {} + + for key, new_path in keys_mapping.items(): + if key in checkpoint: + value = checkpoint[key] + callback_type, callback_key = new_path + checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {} + checkpoint["callbacks"][callback_type][callback_key] = value + del checkpoint[key] + return checkpoint + + +def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Set the `global_step` value for checkpoints before v1.6 without the progress tracking state. + It will be overwritten by the loop's state if it was also saved. + + Version: 1.6.0 + Commit: + """ + global_step = checkpoint["global_step"] + checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) + checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0) + # for automatic optimization + optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"] + optim_progress["optimizer"]["step"]["total"]["completed"] = global_step + # for manual optimization + optim_step_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"] + optim_step_progress["total"]["completed"] = global_step + return checkpoint + + +def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. + It will be overwritten by the loop's state if it was also saved. 
+ + Version: 1.6.0 + Commit: + """ + epoch = checkpoint["epoch"] + checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) + checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0) + checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch + + +_FIT_LOOP_INITIAL_STATE_1_6_0 = { + "epoch_loop.batch_loop.manual_loop.optim_step_progress": { + "current": {"completed": 0, "ready": 0}, + "total": {"completed": 0, "ready": 0}, + }, + "epoch_loop.batch_loop.manual_loop.state_dict": {}, + "epoch_loop.batch_loop.optimizer_loop.optim_progress": { + "optimizer": { + "step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}}, + "zero_grad": { + "current": {"completed": 0, "ready": 0, "started": 0}, + "total": {"completed": 0, "ready": 0, "started": 0}, + }, + }, + "optimizer_position": 0, + }, + "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, + "epoch_loop.batch_loop.state_dict": {}, + "epoch_loop.batch_progress": { + "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + "is_last_batch": False, + "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + }, + "epoch_loop.scheduler_progress": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}}, + "epoch_loop.state_dict": {"_batches_that_stepped": 0}, + "epoch_loop.val_loop.dataloader_progress": { + "current": {"completed": 0, "ready": 0}, + "total": {"completed": 0, "ready": 0}, + }, + "epoch_loop.val_loop.epoch_loop.batch_progress": { + "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + "is_last_batch": False, + "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + }, + "epoch_loop.val_loop.epoch_loop.state_dict": {}, + "epoch_loop.val_loop.state_dict": {}, + "epoch_progress": { + "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + }, + "state_dict": {}, +} diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py index 6f4dd5ca938dd..46038701b1e52 100644 --- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -17,36 +17,11 @@ import torch -from lightning_lite.utilities.types import _PATH -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.utilities.migration import pl_legacy_patch - -KEYS_MAPPING = { - "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), - "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"), - "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"), - "early_stop_callback_wait": (EarlyStopping, "wait_count"), - "early_stop_callback_patience": (EarlyStopping, "patience"), -} +from pytorch_lightning.utilities.migration.base import pl_legacy_patch +from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint log = logging.getLogger(__name__) - -def upgrade_checkpoint(filepath: _PATH) -> None: - checkpoint = torch.load(filepath) - checkpoint["callbacks"] = checkpoint.get("callbacks") or {} - - for key, new_path in KEYS_MAPPING.items(): - if key in checkpoint: - value = checkpoint[key] - callback_type, callback_key = new_path - checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {} - checkpoint["callbacks"][callback_type][callback_key] = value - del checkpoint[key] - - 
torch.save(checkpoint, filepath) - - if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -61,4 +36,6 @@ def upgrade_checkpoint(filepath: _PATH) -> None: log.info("Creating a backup of the existing checkpoint file before overwriting in the upgrade process.") copyfile(args.file, args.file + ".bak") with pl_legacy_patch(): - upgrade_checkpoint(args.file) + checkpoint = torch.load(args.file) + migrate_checkpoint(checkpoint) + torch.save(checkpoint, args.file) diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index a58bdb5721bc7..c01fcf7eb249d 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -11,13 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os - import pytest -import torch +import pytorch_lightning as pl from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.utilities.upgrade_checkpoint import upgrade_checkpoint +from pytorch_lightning.utilities.migration import get_version, migrate_checkpoint, set_version @pytest.mark.parametrize( @@ -42,8 +40,8 @@ ], ) def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): - filepath = os.path.join(tmpdir, "model.ckpt") - torch.save(old_checkpoint, filepath) - upgrade_checkpoint(filepath) - updated_checkpoint = torch.load(filepath) + set_version(old_checkpoint, "0.9.0") + set_version(new_checkpoint, pl.__version__) + updated_checkpoint = migrate_checkpoint(old_checkpoint) assert updated_checkpoint == new_checkpoint + assert get_version(updated_checkpoint) == pl.__version__ From cc110a3d1c7f21728ff74938317101667544d948 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Oct 2022 11:49:28 +0000 Subject: [PATCH 02/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/utilities/migration.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/utilities/migration.py b/src/pytorch_lightning/utilities/migration.py index 7f400fc56efa4..8df43b83dad48 100644 --- a/src/pytorch_lightning/utilities/migration.py +++ b/src/pytorch_lightning/utilities/migration.py @@ -123,8 +123,8 @@ def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKP def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Set the `global_step` value for checkpoints before v1.6 without the progress tracking state. - It will be overwritten by the loop's state if it was also saved. + """Set the `global_step` value for checkpoints before v1.6 without the progress tracking state. It will be + overwritten by the loop's state if it was also saved. Version: 1.6.0 Commit: @@ -142,8 +142,8 @@ def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _ def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. - It will be overwritten by the loop's state if it was also saved. + """Set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. 
It will be + overwritten by the loop's state if it was also saved. Version: 1.6.0 Commit: From e13fcc61ddc9dfa5b01c7e3d3d70066be5cd65d5 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 21 Oct 2022 14:13:58 +0200 Subject: [PATCH 03/24] import --- src/pytorch_lightning/utilities/upgrade_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py index 46038701b1e52..03705c5287e8d 100644 --- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -17,8 +17,8 @@ import torch -from pytorch_lightning.utilities.migration.base import pl_legacy_patch -from pytorch_lightning.utilities.migration.migrations import migrate_checkpoint +from pytorch_lightning.utilities.migration import pl_legacy_patch +from pytorch_lightning.utilities.migration import migrate_checkpoint log = logging.getLogger(__name__) From 9838f008c61ad7a50d9ba7e7344ec16bd67111b6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Oct 2022 12:16:07 +0000 Subject: [PATCH 04/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/utilities/upgrade_checkpoint.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/pytorch_lightning/utilities/upgrade_checkpoint.py b/src/pytorch_lightning/utilities/upgrade_checkpoint.py index 03705c5287e8d..4bcfb4a86f5bd 100644 --- a/src/pytorch_lightning/utilities/upgrade_checkpoint.py +++ b/src/pytorch_lightning/utilities/upgrade_checkpoint.py @@ -17,8 +17,7 @@ import torch -from pytorch_lightning.utilities.migration import pl_legacy_patch -from pytorch_lightning.utilities.migration import migrate_checkpoint +from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch log = logging.getLogger(__name__) From ed1ab3fce307e3b7f8e9dbdf5e79157ec821c98d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 12:58:36 +0200 Subject: [PATCH 05/24] refactor --- .../utilities/migration/__init__.py | 16 +++ .../{migration.py => migration/migrations.py} | 98 +++++-------------- .../utilities/migration/utils.py | 91 +++++++++++++++++ 3 files changed, 131 insertions(+), 74 deletions(-) create mode 100644 src/pytorch_lightning/utilities/migration/__init__.py rename src/pytorch_lightning/utilities/{migration.py => migration/migrations.py} (59%) create mode 100644 src/pytorch_lightning/utilities/migration/utils.py diff --git a/src/pytorch_lightning/utilities/migration/__init__.py b/src/pytorch_lightning/utilities/migration/__init__.py new file mode 100644 index 0000000000000..8e0f79f6904cb --- /dev/null +++ b/src/pytorch_lightning/utilities/migration/__init__.py @@ -0,0 +1,16 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from pytorch_lightning.utilities.migration.utils import pl_legacy_patch # noqa: F401 +from pytorch_lightning.utilities.migration.utils import migrate_checkpoint # noqa: F401 diff --git a/src/pytorch_lightning/utilities/migration.py b/src/pytorch_lightning/utilities/migration/migrations.py similarity index 59% rename from src/pytorch_lightning/utilities/migration.py rename to src/pytorch_lightning/utilities/migration/migrations.py index 8df43b83dad48..a148f79ca5d7c 100644 --- a/src/pytorch_lightning/utilities/migration.py +++ b/src/pytorch_lightning/utilities/migration/migrations.py @@ -14,84 +14,36 @@ """Contains migration functions to upgrade legacy checkpoints to the format of the current Lightning version. When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary sequentially, -see :func:`migrate_checkpoint`. -""" - -import sys -from distutils.version import LooseVersion -from types import ModuleType, TracebackType -from typing import Any, Dict, Optional, Type - -import pytorch_lightning as pl -import pytorch_lightning.utilities.argparse - -_CHECKPOINT = Dict[str, Any] - - -class pl_legacy_patch: - """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for - unpickling old checkpoints. The following patches apply. - - 1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to - version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 - 2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4, - but still needs to be available for import for legacy checkpoints. - - Example: - - with pl_legacy_patch(): - torch.load("path/to/legacy/checkpoint.ckpt") - """ +see :func:`~pytorch_lightning.utilities.migration.utils.migrate_checkpoint`. - def __enter__(self) -> "pl_legacy_patch": - # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` - legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils") - sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module +How to add a new migration? - # `_gpus_arg_default` used to be imported from these locations - legacy_argparse_module._gpus_arg_default = lambda x: x - pytorch_lightning.utilities.argparse._gpus_arg_default = lambda x: x - return self +1. Create a new function with a descriptive name and docstring that explains the details of this migration. Include + version informatin as well as the specific commit or PR where the breaking change happened. +2. Add the function to the `migration_index()` below. The key in the index is the version of Lightning in which the + change happened. Any checkpoint with a version greater or equal to that version will apply the given function. + Multiple migrations per version get executed in the provided list order. +3. 
You can test the migration on a checkpoint (backup your files first) by running: - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - exc_traceback: Optional[TracebackType], - ) -> None: - if hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default"): - delattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") - del sys.modules["pytorch_lightning.utilities.argparse_utils"] - - -def get_version(checkpoint: _CHECKPOINT) -> str: - """Get the version of a Lightning checkpoint.""" - return checkpoint["pytorch-lightning_version"] - - -def set_version(checkpoint: _CHECKPOINT, version: str) -> None: - """Set the version of a Lightning checkpoint.""" - checkpoint["pytorch-lightning_version"] = version + cp model.ckpt model.ckpt.backup + python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt +""" -def should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool: - """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" - return LooseVersion(get_version(checkpoint)) < LooseVersion(target) +from typing import Any, Dict, Callable, List +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Applies all migrations below in order.""" - if should_upgrade(checkpoint, "0.10.0"): - _migrate_model_checkpoint_early_stopping(checkpoint) - if should_upgrade(checkpoint, "1.6.0"): - _migrate_loop_global_step_to_progress_tracking(checkpoint) - _migrate_loop_current_epoch_to_progress_tracking(checkpoint) +_CHECKPOINT = Dict[str, Any] - set_version(checkpoint, pl.__version__) - # TODO: If any migrations apply, log a message. Suggest to run upgrade_checkpoint script to convert - # checkpoints permanently - return checkpoint +def migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: + """Migration functions returned here will get executed in the order they are listed.""" + return { + "0.10.0": [_migrate_model_checkpoint_early_stopping], + "1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking] + } def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKPOINT: @@ -100,9 +52,6 @@ def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKP Version: 0.10.0 Commit: """ - from pytorch_lightning.callbacks.early_stopping import EarlyStopping - from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint - keys_mapping = { "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"), @@ -123,7 +72,7 @@ def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKP def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Set the `global_step` value for checkpoints before v1.6 without the progress tracking state. It will be + """Sets the `global_step` value for checkpoints before v1.6 without the progress tracking state. It will be overwritten by the loop's state if it was also saved. 
Version: 1.6.0 @@ -142,7 +91,7 @@ def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _ def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Set the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. It will be + """Sets the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. It will be overwritten by the loop's state if it was also saved. Version: 1.6.0 @@ -152,6 +101,7 @@ def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0) checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch + return checkpoint _FIT_LOOP_INITIAL_STATE_1_6_0 = { diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py new file mode 100644 index 0000000000000..a7e8443c78467 --- /dev/null +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -0,0 +1,91 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import pytorch_lightning as pl +from distutils.version import LooseVersion +from types import ModuleType, TracebackType +from typing import Optional, Type, Dict, Any + +from pytorch_lightning.utilities.migration.migrations import _migrate_model_checkpoint_early_stopping, \ + _migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking, migration_index + +_CHECKPOINT = Dict[str, Any] + + +class pl_legacy_patch: + """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for + unpickling old checkpoints. The following patches apply. + + 1. ``pytorch_lightning.utilities.argparse._gpus_arg_default``: Applies to all checkpoints saved prior to + version 1.2.8. See: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898 + 2. ``pytorch_lightning.utilities.argparse_utils``: A module that was deprecated in 1.2 and removed in 1.4, + but still needs to be available for import for legacy checkpoints. 
+ + Example: + + with pl_legacy_patch(): + torch.load("path/to/legacy/checkpoint.ckpt") + """ + + def __enter__(self) -> "pl_legacy_patch": + # `pl.utilities.argparse_utils` was renamed to `pl.utilities.argparse` + legacy_argparse_module = ModuleType("pytorch_lightning.utilities.argparse_utils") + sys.modules["pytorch_lightning.utilities.argparse_utils"] = legacy_argparse_module + + # `_gpus_arg_default` used to be imported from these locations + legacy_argparse_module._gpus_arg_default = lambda x: x + pl.utilities.argparse._gpus_arg_default = lambda x: x + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + exc_traceback: Optional[TracebackType], + ) -> None: + if hasattr(pl.utilities.argparse, "_gpus_arg_default"): + delattr(pl.utilities.argparse, "_gpus_arg_default") + del sys.modules["pytorch_lightning.utilities.argparse_utils"] + + +def get_version(checkpoint: _CHECKPOINT) -> str: + """Get the version of a Lightning checkpoint.""" + return checkpoint["pytorch-lightning_version"] + + +def set_version(checkpoint: _CHECKPOINT, version: str) -> None: + """Set the version of a Lightning checkpoint.""" + checkpoint["pytorch-lightning_version"] = version + + +def should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool: + """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" + return LooseVersion(get_version(checkpoint)) < LooseVersion(target) + + +def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Applies all migrations below in order.""" + index = migration_index() + for migration_version, migration_functions in index.items(): + if not should_upgrade(checkpoint, migration_version): + continue + for migration_function in migration_functions: + checkpoint = migration_function(checkpoint) + + set_version(checkpoint, pl.__version__) + + # TODO: If any migrations apply, log a message. 
Suggest to run upgrade_checkpoint script to convert + # checkpoints permanently + return checkpoint From 756b2e7ee153c39957c70bf5eb4d986eabb064d2 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 13:05:21 +0200 Subject: [PATCH 06/24] protected --- .../utilities/migration/utils.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index a7e8443c78467..15769a3efa267 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -18,8 +18,7 @@ from types import ModuleType, TracebackType from typing import Optional, Type, Dict, Any -from pytorch_lightning.utilities.migration.migrations import _migrate_model_checkpoint_early_stopping, \ - _migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking, migration_index +from pytorch_lightning.utilities.migration.migrations import migration_index _CHECKPOINT = Dict[str, Any] @@ -60,31 +59,31 @@ def __exit__( del sys.modules["pytorch_lightning.utilities.argparse_utils"] -def get_version(checkpoint: _CHECKPOINT) -> str: +def _get_version(checkpoint: _CHECKPOINT) -> str: """Get the version of a Lightning checkpoint.""" return checkpoint["pytorch-lightning_version"] -def set_version(checkpoint: _CHECKPOINT, version: str) -> None: +def _set_version(checkpoint: _CHECKPOINT, version: str) -> None: """Set the version of a Lightning checkpoint.""" checkpoint["pytorch-lightning_version"] = version -def should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool: +def _should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool: """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" - return LooseVersion(get_version(checkpoint)) < LooseVersion(target) + return LooseVersion(_get_version(checkpoint)) < LooseVersion(target) def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Applies all migrations below in order.""" + """Applies Lightning version migrations to a checkpoint.""" index = migration_index() for migration_version, migration_functions in index.items(): - if not should_upgrade(checkpoint, migration_version): + if not _should_upgrade(checkpoint, migration_version): continue for migration_function in migration_functions: checkpoint = migration_function(checkpoint) - set_version(checkpoint, pl.__version__) + _set_version(checkpoint, pl.__version__) # TODO: If any migrations apply, log a message. 
Suggest to run upgrade_checkpoint script to convert # checkpoints permanently From d64b5edd467049bb5e6bd29b8fbd6e204e8adba9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Oct 2022 11:07:07 +0000 Subject: [PATCH 07/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/utilities/migration/__init__.py | 2 +- src/pytorch_lightning/utilities/migration/migrations.py | 5 ++--- src/pytorch_lightning/utilities/migration/utils.py | 4 ++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/pytorch_lightning/utilities/migration/__init__.py b/src/pytorch_lightning/utilities/migration/__init__.py index 8e0f79f6904cb..199541c19034f 100644 --- a/src/pytorch_lightning/utilities/migration/__init__.py +++ b/src/pytorch_lightning/utilities/migration/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.utilities.migration.utils import pl_legacy_patch # noqa: F401 from pytorch_lightning.utilities.migration.utils import migrate_checkpoint # noqa: F401 +from pytorch_lightning.utilities.migration.utils import pl_legacy_patch # noqa: F401 diff --git a/src/pytorch_lightning/utilities/migration/migrations.py b/src/pytorch_lightning/utilities/migration/migrations.py index a148f79ca5d7c..e9ff88cc36f01 100644 --- a/src/pytorch_lightning/utilities/migration/migrations.py +++ b/src/pytorch_lightning/utilities/migration/migrations.py @@ -27,10 +27,9 @@ cp model.ckpt model.ckpt.backup python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt - """ -from typing import Any, Dict, Callable, List +from typing import Any, Callable, Dict, List from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint @@ -42,7 +41,7 @@ def migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: """Migration functions returned here will get executed in the order they are listed.""" return { "0.10.0": [_migrate_model_checkpoint_early_stopping], - "1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking] + "1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking], } diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index 15769a3efa267..fc580471d463a 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -13,11 +13,11 @@ # limitations under the License. 
import sys -import pytorch_lightning as pl from distutils.version import LooseVersion from types import ModuleType, TracebackType -from typing import Optional, Type, Dict, Any +from typing import Any, Dict, Optional, Type +import pytorch_lightning as pl from pytorch_lightning.utilities.migration.migrations import migration_index _CHECKPOINT = Dict[str, Any] From 954b4d3c9cc1308ddff8b8a51713f47d04800e13 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 13:08:22 +0200 Subject: [PATCH 08/24] typo --- src/pytorch_lightning/utilities/migration/migrations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/migration/migrations.py b/src/pytorch_lightning/utilities/migration/migrations.py index a148f79ca5d7c..a3dc604a8294d 100644 --- a/src/pytorch_lightning/utilities/migration/migrations.py +++ b/src/pytorch_lightning/utilities/migration/migrations.py @@ -19,7 +19,7 @@ How to add a new migration? 1. Create a new function with a descriptive name and docstring that explains the details of this migration. Include - version informatin as well as the specific commit or PR where the breaking change happened. + version information as well as the specific commit or PR where the breaking change happened. 2. Add the function to the `migration_index()` below. The key in the index is the version of Lightning in which the change happened. Any checkpoint with a version greater or equal to that version will apply the given function. Multiple migrations per version get executed in the provided list order. From 0f988d1746be1672e3c94f77e6329cd830a3b40c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Oct 2022 13:50:04 +0200 Subject: [PATCH 09/24] tests --- .../utilities/migration/utils.py | 32 ++++---- .../utilities/migration/__init__.py | 0 .../utilities/migration/test_utils.py | 75 +++++++++++++++++++ .../utilities/test_upgrade_checkpoint.py | 19 +++-- 4 files changed, 102 insertions(+), 24 deletions(-) create mode 100644 tests/tests_pytorch/utilities/migration/__init__.py create mode 100644 tests/tests_pytorch/utilities/migration/test_utils.py diff --git a/src/pytorch_lightning/utilities/migration/utils.py b/src/pytorch_lightning/utilities/migration/utils.py index fc580471d463a..a52aca9944e2c 100644 --- a/src/pytorch_lightning/utilities/migration/utils.py +++ b/src/pytorch_lightning/utilities/migration/utils.py @@ -23,6 +23,22 @@ _CHECKPOINT = Dict[str, Any] +def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Applies Lightning version migrations to a checkpoint dictionary.""" + index = migration_index() + for migration_version, migration_functions in index.items(): + if not _should_upgrade(checkpoint, migration_version): + continue + for migration_function in migration_functions: + checkpoint = migration_function(checkpoint) + + _set_version(checkpoint, pl.__version__) + + # TODO: If any migrations apply, log a message. Suggest to run upgrade_checkpoint script to convert + # checkpoints permanently + return checkpoint + + class pl_legacy_patch: """Registers legacy artifacts (classes, methods, etc.) that were removed but still need to be included for unpickling old checkpoints. The following patches apply. 
@@ -72,19 +88,3 @@ def _set_version(checkpoint: _CHECKPOINT, version: str) -> None: def _should_upgrade(checkpoint: _CHECKPOINT, target: str) -> bool: """Returns whether a checkpoint qualifies for an upgrade when the version is lower than the given target.""" return LooseVersion(_get_version(checkpoint)) < LooseVersion(target) - - -def migrate_checkpoint(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Applies Lightning version migrations to a checkpoint.""" - index = migration_index() - for migration_version, migration_functions in index.items(): - if not _should_upgrade(checkpoint, migration_version): - continue - for migration_function in migration_functions: - checkpoint = migration_function(checkpoint) - - _set_version(checkpoint, pl.__version__) - - # TODO: If any migrations apply, log a message. Suggest to run upgrade_checkpoint script to convert - # checkpoints permanently - return checkpoint diff --git a/tests/tests_pytorch/utilities/migration/__init__.py b/tests/tests_pytorch/utilities/migration/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py new file mode 100644 index 0000000000000..227e6e0b590cd --- /dev/null +++ b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -0,0 +1,75 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytorch_lightning as pl +from pytorch_lightning.utilities.migration import migrate_checkpoint + + +def test_migrate_checkpoint(monkeypatch): + """Test that the correct migration function gets executed given the current version of the checkpoint.""" + # A checkpoint that is older than any migration point in the index + old_checkpoint = { + "pytorch-lightning_version": "0.0.0", + "content": 123 + } + new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) + assert call_order == ["one", "two", "three", "four"] + assert new_checkpoint == { + "pytorch-lightning_version": pl.__version__, + "content": 123 + } + + # A checkpoint that is newer, but not the newest + old_checkpoint = { + "pytorch-lightning_version": "1.0.3", + "content": 123 + } + new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) + assert call_order == ["four"] + assert new_checkpoint == { + "pytorch-lightning_version": pl.__version__, + "content": 123 + } + + # A checkpoint newer than any migration point in the index + old_checkpoint = { + "pytorch-lightning_version": "2.0", + "content": 123 + } + new_checkpoint, call_order = _run_simple_migration(monkeypatch, old_checkpoint) + assert call_order == [] + assert new_checkpoint == { + "pytorch-lightning_version": pl.__version__, + "content": 123 + } + + +def _run_simple_migration(monkeypatch, old_checkpoint): + call_order = [] + + def dummy_upgrade(tag): + def upgrade(ckpt): + call_order.append(tag) + return ckpt + + return upgrade + + index = { + "0.0.1": [dummy_upgrade("one")], + "0.0.2": [dummy_upgrade("two"), dummy_upgrade("three")], + "1.2.3": [dummy_upgrade("four")], + } + monkeypatch.setattr(pl.utilities.migration.utils, "migration_index", lambda: index) + new_checkpoint = migrate_checkpoint(old_checkpoint) + return new_checkpoint, call_order diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index c01fcf7eb249d..2429067ed0557 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -11,11 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from unittest.mock import ANY + import pytest import pytorch_lightning as pl from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.utilities.migration import get_version, migrate_checkpoint, set_version +from pytorch_lightning.utilities.migration import migrate_checkpoint +from pytorch_lightning.utilities.migration.utils import _set_version, _get_version @pytest.mark.parametrize( @@ -23,25 +26,25 @@ [ ( {"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}, "loops": ANY}, ), ( {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}, "loops": ANY}, ), ( {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}}, + {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}, "loops": ANY}, ), ( {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4}, - {"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}}, + {"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}, "loops": ANY}, ), ], ) def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): - set_version(old_checkpoint, "0.9.0") - set_version(new_checkpoint, pl.__version__) + _set_version(old_checkpoint, "0.9.0") + _set_version(new_checkpoint, pl.__version__) updated_checkpoint = migrate_checkpoint(old_checkpoint) assert updated_checkpoint == new_checkpoint - assert get_version(updated_checkpoint) == pl.__version__ + assert _get_version(updated_checkpoint) == pl.__version__ From b3069a919b9ef003ad4cacfe652ecc3a64e72475 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 3 Nov 2022 11:31:36 +0100 Subject: [PATCH 10/24] update --- .../utilities/migration/migration.py | 78 ++++++++++ .../utilities/migration/migrations.py | 147 ------------------ .../test_migration.py} | 2 +- .../tests_pytorch/utilities/test_migration.py | 36 ----- 4 files changed, 79 insertions(+), 184 deletions(-) delete mode 100644 src/pytorch_lightning/utilities/migration/migrations.py rename tests/tests_pytorch/utilities/{test_upgrade_checkpoint.py => migration/test_migration.py} (98%) delete mode 100644 tests/tests_pytorch/utilities/test_migration.py diff --git a/src/pytorch_lightning/utilities/migration/migration.py b/src/pytorch_lightning/utilities/migration/migration.py index 3431c01709b89..caccd21d55d19 100644 --- a/src/pytorch_lightning/utilities/migration/migration.py +++ b/src/pytorch_lightning/utilities/migration/migration.py @@ -41,6 +41,7 @@ def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: """Migration functions returned here will get executed in the order they are listed.""" return { "0.10.0": [_migrate_model_checkpoint_early_stopping], + "1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking], } @@ -67,3 +68,80 @@ def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKP checkpoint["callbacks"][callback_type][callback_key] = value 
del checkpoint[key] return checkpoint + + +def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Sets the `global_step` value for checkpoints before v1.6 without the progress tracking state. It will be + overwritten by the loop's state if it was also saved. + + Version: 1.6.0 + Commit: + """ + global_step = checkpoint["global_step"] + checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) + checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0) + # for automatic optimization + optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"] + optim_progress["optimizer"]["step"]["total"]["completed"] = global_step + # for manual optimization + optim_step_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"] + optim_step_progress["total"]["completed"] = global_step + return checkpoint + + +def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Sets the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. It will be + overwritten by the loop's state if it was also saved. + + Version: 1.6.0 + Commit: + """ + epoch = checkpoint["epoch"] + checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) + checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0) + checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch + return checkpoint + + +_FIT_LOOP_INITIAL_STATE_1_6_0 = { + "epoch_loop.batch_loop.manual_loop.optim_step_progress": { + "current": {"completed": 0, "ready": 0}, + "total": {"completed": 0, "ready": 0}, + }, + "epoch_loop.batch_loop.manual_loop.state_dict": {}, + "epoch_loop.batch_loop.optimizer_loop.optim_progress": { + "optimizer": { + "step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}}, + "zero_grad": { + "current": {"completed": 0, "ready": 0, "started": 0}, + "total": {"completed": 0, "ready": 0, "started": 0}, + }, + }, + "optimizer_position": 0, + }, + "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, + "epoch_loop.batch_loop.state_dict": {}, + "epoch_loop.batch_progress": { + "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + "is_last_batch": False, + "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + }, + "epoch_loop.scheduler_progress": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}}, + "epoch_loop.state_dict": {"_batches_that_stepped": 0}, + "epoch_loop.val_loop.dataloader_progress": { + "current": {"completed": 0, "ready": 0}, + "total": {"completed": 0, "ready": 0}, + }, + "epoch_loop.val_loop.epoch_loop.batch_progress": { + "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + "is_last_batch": False, + "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + }, + "epoch_loop.val_loop.epoch_loop.state_dict": {}, + "epoch_loop.val_loop.state_dict": {}, + "epoch_progress": { + "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + }, + "state_dict": {}, +} diff --git a/src/pytorch_lightning/utilities/migration/migrations.py b/src/pytorch_lightning/utilities/migration/migrations.py deleted file mode 100644 index 199337ec540be..0000000000000 --- a/src/pytorch_lightning/utilities/migration/migrations.py +++ /dev/null @@ -1,147 +0,0 @@ -# 
Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Contains migration functions to upgrade legacy checkpoints to the format of the current Lightning version. - -When Lightning loads a checkpoint, these migrations will be applied on the loaded checkpoint dictionary sequentially, -see :func:`~pytorch_lightning.utilities.migration.utils.migrate_checkpoint`. - -How to add a new migration? - -1. Create a new function with a descriptive name and docstring that explains the details of this migration. Include - version information as well as the specific commit or PR where the breaking change happened. -2. Add the function to the `migration_index()` below. The key in the index is the version of Lightning in which the - change happened. Any checkpoint with a version greater or equal to that version will apply the given function. - Multiple migrations per version get executed in the provided list order. -3. You can test the migration on a checkpoint (backup your files first) by running: - - cp model.ckpt model.ckpt.backup - python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt -""" - -from typing import Any, Callable, Dict, List - -from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint - -_CHECKPOINT = Dict[str, Any] - - -def migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: - """Migration functions returned here will get executed in the order they are listed.""" - return { - "0.10.0": [_migrate_model_checkpoint_early_stopping], - "1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking], - } - - -def _migrate_model_checkpoint_early_stopping(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """The checkpoint and early stopping keys were renamed. - - Version: 0.10.0 - Commit: - """ - keys_mapping = { - "checkpoint_callback_best_model_score": (ModelCheckpoint, "best_model_score"), - "checkpoint_callback_best_model_path": (ModelCheckpoint, "best_model_path"), - "checkpoint_callback_best": (ModelCheckpoint, "best_model_score"), - "early_stop_callback_wait": (EarlyStopping, "wait_count"), - "early_stop_callback_patience": (EarlyStopping, "patience"), - } - checkpoint["callbacks"] = checkpoint.get("callbacks") or {} - - for key, new_path in keys_mapping.items(): - if key in checkpoint: - value = checkpoint[key] - callback_type, callback_key = new_path - checkpoint["callbacks"][callback_type] = checkpoint["callbacks"].get(callback_type) or {} - checkpoint["callbacks"][callback_type][callback_key] = value - del checkpoint[key] - return checkpoint - - -def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Sets the `global_step` value for checkpoints before v1.6 without the progress tracking state. It will be - overwritten by the loop's state if it was also saved. 
- - Version: 1.6.0 - Commit: - """ - global_step = checkpoint["global_step"] - checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) - checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0) - # for automatic optimization - optim_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"] - optim_progress["optimizer"]["step"]["total"]["completed"] = global_step - # for manual optimization - optim_step_progress = checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"] - optim_step_progress["total"]["completed"] = global_step - return checkpoint - - -def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> _CHECKPOINT: - """Sets the `current_epoch` value for checkpoints before v1.6 without the progress tracking state. It will be - overwritten by the loop's state if it was also saved. - - Version: 1.6.0 - Commit: - """ - epoch = checkpoint["epoch"] - checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) - checkpoint["loops"].setdefault("fit_loop", _FIT_LOOP_INITIAL_STATE_1_6_0) - checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] = epoch - return checkpoint - - -_FIT_LOOP_INITIAL_STATE_1_6_0 = { - "epoch_loop.batch_loop.manual_loop.optim_step_progress": { - "current": {"completed": 0, "ready": 0}, - "total": {"completed": 0, "ready": 0}, - }, - "epoch_loop.batch_loop.manual_loop.state_dict": {}, - "epoch_loop.batch_loop.optimizer_loop.optim_progress": { - "optimizer": { - "step": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}}, - "zero_grad": { - "current": {"completed": 0, "ready": 0, "started": 0}, - "total": {"completed": 0, "ready": 0, "started": 0}, - }, - }, - "optimizer_position": 0, - }, - "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, - "epoch_loop.batch_loop.state_dict": {}, - "epoch_loop.batch_progress": { - "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - "is_last_batch": False, - "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - }, - "epoch_loop.scheduler_progress": {"current": {"completed": 0, "ready": 0}, "total": {"completed": 0, "ready": 0}}, - "epoch_loop.state_dict": {"_batches_that_stepped": 0}, - "epoch_loop.val_loop.dataloader_progress": { - "current": {"completed": 0, "ready": 0}, - "total": {"completed": 0, "ready": 0}, - }, - "epoch_loop.val_loop.epoch_loop.batch_progress": { - "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - "is_last_batch": False, - "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - }, - "epoch_loop.val_loop.epoch_loop.state_dict": {}, - "epoch_loop.val_loop.state_dict": {}, - "epoch_progress": { - "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - "total": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, - }, - "state_dict": {}, -} diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/migration/test_migration.py similarity index 98% rename from tests/tests_pytorch/utilities/test_upgrade_checkpoint.py rename to tests/tests_pytorch/utilities/migration/test_migration.py index 5b704c6da6be3..6696f24faffa5 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/migration/test_migration.py @@ -18,7 +18,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from 
pytorch_lightning.utilities.migration import migrate_checkpoint -from pytorch_lightning.utilities.migration.utils import _set_version, _get_version +from pytorch_lightning.utilities.migration.utils import _set_version, _get_version, _set_legacy_version @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/utilities/test_migration.py b/tests/tests_pytorch/utilities/test_migration.py deleted file mode 100644 index ee94ee690e798..0000000000000 --- a/tests/tests_pytorch/utilities/test_migration.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import sys - -import pytorch_lightning -from pytorch_lightning.utilities.migration import pl_legacy_patch - - -def test_patch_legacy_argparse_utils(): - with pl_legacy_patch(): - from pytorch_lightning.utilities import argparse_utils - - assert callable(argparse_utils._gpus_arg_default) - assert "pytorch_lightning.utilities.argparse_utils" in sys.modules - - assert "pytorch_lightning.utilities.argparse_utils" not in sys.modules - - -def test_patch_legacy_gpus_arg_default(): - with pl_legacy_patch(): - from pytorch_lightning.utilities.argparse import _gpus_arg_default - - assert callable(_gpus_arg_default) - assert not hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") - assert not hasattr(pytorch_lightning.utilities.argparse, "_gpus_arg_default") From 94cbba99b431155026c338688d47a055de6ec9e5 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 3 Nov 2022 11:32:26 +0100 Subject: [PATCH 11/24] x --- .../tests_pytorch/utilities/migration/test_migration.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/utilities/migration/test_migration.py b/tests/tests_pytorch/utilities/migration/test_migration.py index 6696f24faffa5..620fedc22a158 100644 --- a/tests/tests_pytorch/utilities/migration/test_migration.py +++ b/tests/tests_pytorch/utilities/migration/test_migration.py @@ -18,7 +18,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.utilities.migration import migrate_checkpoint -from pytorch_lightning.utilities.migration.utils import _set_version, _get_version, _set_legacy_version +from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version @pytest.mark.parametrize( @@ -38,7 +38,12 @@ ), ( {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4}, - {"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}, "loops": ANY}, + { + "epoch": 1, + "global_step": 23, + "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}, + "loops": ANY, + }, ), ], ) From dd51bb06dbd757c952475f11ecade23934f73784 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 3 Nov 2022 14:13:32 +0100 Subject: [PATCH 12/24] update --- src/pytorch_lightning/utilities/migration/migration.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/src/pytorch_lightning/utilities/migration/migration.py b/src/pytorch_lightning/utilities/migration/migration.py index caccd21d55d19..3c785fa95cc88 100644 --- a/src/pytorch_lightning/utilities/migration/migration.py +++ b/src/pytorch_lightning/utilities/migration/migration.py @@ -75,7 +75,7 @@ def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _ overwritten by the loop's state if it was also saved. Version: 1.6.0 - Commit: + Commit: aea96e4 """ global_step = checkpoint["global_step"] checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) @@ -94,7 +94,7 @@ def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> overwritten by the loop's state if it was also saved. Version: 1.6.0 - Commit: + Commit: aea96e4 """ epoch = checkpoint["epoch"] checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) From a2d7972057d811c5e11c8455e2cf2d79f489982d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 3 Nov 2022 14:36:01 +0100 Subject: [PATCH 13/24] tests --- .../loops/epoch/training_epoch_loop.py | 3 +-- .../utilities/migration/migration.py | 17 +++++++++++- tests/tests_pytorch/models/test_restore.py | 23 ---------------- .../utilities/migration/test_migration.py | 27 +++++++++++++++++++ 4 files changed, 44 insertions(+), 26 deletions(-) diff --git a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py index 226a69c869311..777fa01b04847 100644 --- a/src/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -299,8 +299,7 @@ def on_save_checkpoint(self) -> Dict: def on_load_checkpoint(self, state_dict: Dict) -> None: # cache the dataloader state dict until the dataloader objects are available self._dataloader_state_dict = state_dict.get("dataloader_state_dict", {}) - # restore global step instead to make sure logging works correctly if checkpoints None: # reload dataloaders diff --git a/src/pytorch_lightning/utilities/migration/migration.py b/src/pytorch_lightning/utilities/migration/migration.py index 3c785fa95cc88..f7341ba43758c 100644 --- a/src/pytorch_lightning/utilities/migration/migration.py +++ b/src/pytorch_lightning/utilities/migration/migration.py @@ -42,6 +42,7 @@ def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: return { "0.10.0": [_migrate_model_checkpoint_early_stopping], "1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking], + "1.6.5": [_migrate_loop_batches_that_stepped] } @@ -75,7 +76,8 @@ def _migrate_loop_global_step_to_progress_tracking(checkpoint: _CHECKPOINT) -> _ overwritten by the loop's state if it was also saved. Version: 1.6.0 - Commit: aea96e4 + Commit: c67b075 + PR: #13645, #11805 """ global_step = checkpoint["global_step"] checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) @@ -95,6 +97,7 @@ def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> Version: 1.6.0 Commit: aea96e4 + PR: #11805 """ epoch = checkpoint["epoch"] checkpoint.setdefault("loops", {"fit_loop": _FIT_LOOP_INITIAL_STATE_1_6_0}) @@ -103,6 +106,18 @@ def _migrate_loop_current_epoch_to_progress_tracking(checkpoint: _CHECKPOINT) -> return checkpoint +def _migrate_loop_batches_that_stepped(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """Sets the `_batches_that_stepped` default value for checkpoints before v1.6.5 which don't have this key. 
+ + Version: 1.6.5 + Commit: c67b075 + PR: #13645 + """ + global_step = checkpoint["global_step"] + checkpoint["loops"]["fit_loop"]["epoch_loop.state_dict"].setdefault("_batches_that_stepped", global_step) + return checkpoint + + _FIT_LOOP_INITIAL_STATE_1_6_0 = { "epoch_loop.batch_loop.manual_loop.optim_step_progress": { "current": {"completed": 0, "ready": 0}, diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index 8648d9ba1a6bf..fa6c8aa8d2468 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -254,29 +254,6 @@ def on_train_start(self) -> None: assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches -@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel]) -def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir, model_class): - trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir) - model = model_class() - trainer.fit(model) - ckpt_path = trainer.checkpoint_callback.best_model_path - ckpt = torch.load(ckpt_path) - # the key "_batches_that_stepped" doesn't exist in checkpoints generated with None: - assert self.trainer.global_step == 1 - assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1 - - trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir) - model = TestModel() - trainer.fit(model, ckpt_path=ckpt_path) - new_loop = trainer.fit_loop.epoch_loop - assert new_loop.global_step == new_loop._batches_that_stepped == 2 - - def test_fit_twice(tmpdir): epochs = [] diff --git a/tests/tests_pytorch/utilities/migration/test_migration.py b/tests/tests_pytorch/utilities/migration/test_migration.py index 620fedc22a158..cfcafdf6be6c1 100644 --- a/tests/tests_pytorch/utilities/migration/test_migration.py +++ b/tests/tests_pytorch/utilities/migration/test_migration.py @@ -14,9 +14,12 @@ from unittest.mock import ANY import pytest +import torch import pytorch_lightning as pl +from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel from pytorch_lightning.utilities.migration import migrate_checkpoint from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version @@ -54,3 +57,27 @@ def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): updated_checkpoint, _ = migrate_checkpoint(old_checkpoint) assert updated_checkpoint == old_checkpoint == new_checkpoint assert _get_version(updated_checkpoint) == pl.__version__ + + +@pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel]) +def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir, model_class): + trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir) + model = model_class() + trainer.fit(model) + ckpt_path = trainer.checkpoint_callback.best_model_path + ckpt = torch.load(ckpt_path) + # the key "_batches_that_stepped" doesn't exist in checkpoints generated with None: + assert self.trainer.global_step == 1 + assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == 1 + + trainer = Trainer(max_steps=2, limit_val_batches=0, default_root_dir=tmpdir) + model = TestModel() + trainer.fit(model, ckpt_path=ckpt_path) + new_loop = trainer.fit_loop.epoch_loop + assert new_loop.global_step == new_loop._batches_that_stepped == 2 From 698fc708a445afdbe5e714b9a002cdcc7a353875 Mon Sep 17 00:00:00 2001 From: 
"pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Nov 2022 13:38:14 +0000 Subject: [PATCH 14/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/utilities/migration/migration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/migration/migration.py b/src/pytorch_lightning/utilities/migration/migration.py index f7341ba43758c..ba1165288b949 100644 --- a/src/pytorch_lightning/utilities/migration/migration.py +++ b/src/pytorch_lightning/utilities/migration/migration.py @@ -42,7 +42,7 @@ def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: return { "0.10.0": [_migrate_model_checkpoint_early_stopping], "1.6.0": [_migrate_loop_global_step_to_progress_tracking, _migrate_loop_current_epoch_to_progress_tracking], - "1.6.5": [_migrate_loop_batches_that_stepped] + "1.6.5": [_migrate_loop_batches_that_stepped], } From cca1c2fe6401510882af680b82b510e09f24dafd Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 3 Nov 2022 14:43:52 +0100 Subject: [PATCH 15/24] update test --- tests/tests_pytorch/utilities/migration/test_migration.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/tests_pytorch/utilities/migration/test_migration.py b/tests/tests_pytorch/utilities/migration/test_migration.py index cfcafdf6be6c1..c53420cb0371f 100644 --- a/tests/tests_pytorch/utilities/migration/test_migration.py +++ b/tests/tests_pytorch/utilities/migration/test_migration.py @@ -50,7 +50,7 @@ ), ], ) -def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): +def test_migrate_model_checkpoint_early_stopping(tmpdir, old_checkpoint, new_checkpoint): _set_version(old_checkpoint, "0.9.0") _set_legacy_version(new_checkpoint, "0.9.0") _set_version(new_checkpoint, pl.__version__) @@ -60,13 +60,14 @@ def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): @pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel]) -def test_logging_step_loaded_correctly_pre_1_6_5(tmpdir, model_class): +def test_migrate_loop_batches_that_stepped(tmpdir, model_class): trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir) model = model_class() trainer.fit(model) ckpt_path = trainer.checkpoint_callback.best_model_path + + # pretend we have a checkpoint produced in < v1.6.5; the key "_batches_that_stepped" didn't exist back then ckpt = torch.load(ckpt_path) - # the key "_batches_that_stepped" doesn't exist in checkpoints generated with Date: Thu, 3 Nov 2022 14:54:12 +0100 Subject: [PATCH 16/24] format --- .../utilities/migration/test_migration.py | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/tests_pytorch/utilities/migration/test_migration.py b/tests/tests_pytorch/utilities/migration/test_migration.py index c53420cb0371f..d6a94c76720f3 100644 --- a/tests/tests_pytorch/utilities/migration/test_migration.py +++ b/tests/tests_pytorch/utilities/migration/test_migration.py @@ -59,6 +59,33 @@ def test_migrate_model_checkpoint_early_stopping(tmpdir, old_checkpoint, new_che assert _get_version(updated_checkpoint) == pl.__version__ +def test_migrate_loop_global_step_to_progress_tracking(): + old_checkpoint = {"global_step": 15, "epoch": 2} + _set_version(old_checkpoint, "1.5.9") # pretend a checkpoint prior to 1.6.0 + updated_checkpoint, _ = migrate_checkpoint(old_checkpoint) + # automatic optimization + assert ( + 
updated_checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.optimizer_loop.optim_progress"]["optimizer"][ + "step" + ]["total"]["completed"] + == 15 + ) + # for manual optimization + assert ( + updated_checkpoint["loops"]["fit_loop"]["epoch_loop.batch_loop.manual_loop.optim_step_progress"]["total"][ + "completed" + ] + == 15 + ) + + +def test_migrate_loop_current_epoch_to_progress_tracking(): + old_checkpoint = {"global_step": 15, "epoch": 2} + _set_version(old_checkpoint, "1.5.9") # pretend a checkpoint prior to 1.6.0 + updated_checkpoint, _ = migrate_checkpoint(old_checkpoint) + assert updated_checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] == 2 + + @pytest.mark.parametrize("model_class", [BoringModel, ManualOptimBoringModel]) def test_migrate_loop_batches_that_stepped(tmpdir, model_class): trainer = Trainer(max_steps=1, limit_val_batches=0, default_root_dir=tmpdir) From 4e6013c5c7ab9264bbed88cccf40de1c67058e38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 3 Nov 2022 14:54:24 +0100 Subject: [PATCH 17/24] notebook --- _notebooks | 1 - docs/source-app/examples/file_server/app.py | 19 +++------------ .../examples/github_repo_runner/app.py | 24 ++++++++----------- .../build_command_line_interface/app.py | 5 +--- .../commands/notebook/run.py | 2 +- .../post_example.py | 2 +- .../workflows/build_rest_api/post_example.py | 2 +- .../build_rest_api/post_example_pydantic.py | 2 +- 8 files changed, 18 insertions(+), 39 deletions(-) delete mode 160000 _notebooks diff --git a/_notebooks b/_notebooks deleted file mode 160000 index 6d5634b794218..0000000000000 --- a/_notebooks +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6d5634b7942180e6ba4a30bfbd74926d1c22f1eb diff --git a/docs/source-app/examples/file_server/app.py b/docs/source-app/examples/file_server/app.py index e36040f1b8662..fe31aae3b6f19 100644 --- a/docs/source-app/examples/file_server/app.py +++ b/docs/source-app/examples/file_server/app.py @@ -10,13 +10,7 @@ class FileServer(L.LightningWork): - def __init__( - self, - drive: Drive, - base_dir: str = "file_server", - chunk_size=10240, - **kwargs - ): + def __init__(self, drive: Drive, base_dir: str = "file_server", chunk_size=10240, **kwargs): """This component uploads, downloads files to your application. Arguments: @@ -54,9 +48,7 @@ def upload_file(self, file): filename = file.filename uploaded_file = self.get_random_filename() meta_file = uploaded_file + ".meta" - self.uploaded_files[filename] = { - "progress": (0, None), "done": False - } + self.uploaded_files[filename] = {"progress": (0, None), "done": False} # 2: Create a stream and write bytes of # the file to the disk under `uploaded_file` path. 
@@ -163,7 +155,6 @@ def alive(self): class TestFileServer(LightningWork): - def __init__(self, drive: Drive): super().__init__(cache_calls=True) self.drive = drive @@ -173,10 +164,7 @@ def run(self, file_server_url: str, first=True): with open("test.txt", "w") as f: f.write("Some text.") - response = requests.post( - file_server_url + "/upload_file/", - files={'file': open("test.txt", 'rb')} - ) + response = requests.post(file_server_url + "/upload_file/", files={"file": open("test.txt", "rb")}) assert response.status_code == 200 else: response = requests.get(file_server_url) @@ -188,7 +176,6 @@ def run(self, file_server_url: str, first=True): class Flow(LightningFlow): - def __init__(self): super().__init__() # 1: Create a drive to share data between works diff --git a/docs/source-app/examples/github_repo_runner/app.py b/docs/source-app/examples/github_repo_runner/app.py index 70e20ac380d31..d0c43d8b76b63 100644 --- a/docs/source-app/examples/github_repo_runner/app.py +++ b/docs/source-app/examples/github_repo_runner/app.py @@ -56,8 +56,7 @@ def run(self, *args, **kwargs): # 2: Use git command line to clone the repo. repo_name = self.github_repo.split("/")[-1].replace(".git", "") cwd = os.path.dirname(__file__) - subprocess.Popen( - f"git clone {self.github_repo}", cwd=cwd, shell=True).wait() + subprocess.Popen(f"git clone {self.github_repo}", cwd=cwd, shell=True).wait() # 3: Execute the parent run method of the TracerPythonScript class. os.chdir(os.path.join(cwd, repo_name)) @@ -73,7 +72,6 @@ def configure_layout(self): class PyTorchLightningGithubRepoRunner(GithubRepoRunner): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.best_model_path = None @@ -105,8 +103,7 @@ def trainer_pre_fn(self, *args, work=None, **kwargs): # 5. Patch the `__init__` method of the Trainer # to inject our callback with a reference to the work. 
- tracer.add_traced( - Trainer, "__init__", pre_fn=partial(trainer_pre_fn, work=self)) + tracer.add_traced(Trainer, "__init__", pre_fn=partial(trainer_pre_fn, work=self)) return tracer def on_after_run(self, end_script_globals): @@ -213,9 +210,7 @@ def page_1__create_new_run(state): script_path = st.text_input("Enter your script to run", value="train_script.py") script_args = st.text_input("Enter your base script arguments", value=default_script_args) requirements = st.text_input("Enter your requirements", value=default_requirements) - ml_framework = st.radio( - "Select your ML Training Frameworks", options=["PyTorch Lightning", "Keras", "Tensorflow"] - ) + ml_framework = st.radio("Select your ML Training Frameworks", options=["PyTorch Lightning", "Keras", "Tensorflow"]) if ml_framework not in ("PyTorch Lightning"): st.write(f"{ml_framework} isn't supported yet.") @@ -279,8 +274,7 @@ def render_fn(state: AppState): "View your Runs": partial(page_2__view_run_lists, state=state), "View the App state": partial(page_3__view_app_state, state=state), } - selected_page = st.sidebar.selectbox( - "Select a page", page_names_to_funcs.keys()) + selected_page = st.sidebar.selectbox("Select a page", page_names_to_funcs.keys()) page_names_to_funcs[selected_page]() @@ -296,10 +290,12 @@ def run(self): def configure_layout(self): # 1: Add the main StreamLit UI - selection_tab = [{ - "name": "Run your Github Repo", - "content": self.flow, - }] + selection_tab = [ + { + "name": "Run your Github Repo", + "content": self.flow, + } + ] # 2: Add a new tab whenever a new work is dynamically created run_tabs = [e.configure_layout() for e in self.flow.ws.values()] # 3: Returns the list of tabs. diff --git a/docs/source-app/workflows/build_command_line_interface/app.py b/docs/source-app/workflows/build_command_line_interface/app.py index f6a398096b96c..7ac0231d32f01 100644 --- a/docs/source-app/workflows/build_command_line_interface/app.py +++ b/docs/source-app/workflows/build_command_line_interface/app.py @@ -6,7 +6,6 @@ class Flow(L.LightningFlow): - def __init__(self): super().__init__() self.notebooks = Dict() @@ -17,9 +16,7 @@ def run_notebook(self, config: RunNotebookConfig): return f"The Notebook {config.name} already exists." else: # 2. Dynamically creates the Notebook if it doesn't exist and runs it. - self.notebooks[config.name] = JupyterLab( - cloud_compute=L.CloudCompute(config.cloud_compute) - ) + self.notebooks[config.name] = JupyterLab(cloud_compute=L.CloudCompute(config.cloud_compute)) self.notebooks[config.name].run() return f"The Notebook {config.name} was created." diff --git a/docs/source-app/workflows/build_command_line_interface/commands/notebook/run.py b/docs/source-app/workflows/build_command_line_interface/commands/notebook/run.py index a44e6bfa4f9c8..c36252dd714b3 100644 --- a/docs/source-app/workflows/build_command_line_interface/commands/notebook/run.py +++ b/docs/source-app/workflows/build_command_line_interface/commands/notebook/run.py @@ -17,7 +17,7 @@ class RunNotebook(ClientCommand): def run(self): # 1. Define your own argument parser. You can use argparse, click, etc... 
- parser = ArgumentParser(description='Run Notebook Parser') + parser = ArgumentParser(description="Run Notebook Parser") parser.add_argument("--name", type=str, default=None) parser.add_argument("--cloud_compute", type=str, default="cpu") hparams = parser.parse_args() diff --git a/docs/source-app/workflows/build_command_line_interface/post_example.py b/docs/source-app/workflows/build_command_line_interface/post_example.py index c7f87f1cffdf7..43dcd92408e38 100644 --- a/docs/source-app/workflows/build_command_line_interface/post_example.py +++ b/docs/source-app/workflows/build_command_line_interface/post_example.py @@ -16,7 +16,7 @@ def run(self): # 3. Method executed when a request is received. def handle_post(self, name: str): self.names.append(name) - return f'The name {name} was registered' + return f"The name {name} was registered" # 4. Defines this Component's Restful API. You can have several routes. def configure_api(self): diff --git a/docs/source-app/workflows/build_rest_api/post_example.py b/docs/source-app/workflows/build_rest_api/post_example.py index 4a306f176e4b0..a900ff51fbbcb 100644 --- a/docs/source-app/workflows/build_rest_api/post_example.py +++ b/docs/source-app/workflows/build_rest_api/post_example.py @@ -16,7 +16,7 @@ def run(self): # 3. Method executed when a request is received. def handle_post(self, name: str): self.names.append(name) - return f'The name {name} was registered' + return f"The name {name} was registered" # 4. Defines this Component's Restful API. You can have several routes. def configure_api(self): diff --git a/docs/source-app/workflows/build_rest_api/post_example_pydantic.py b/docs/source-app/workflows/build_rest_api/post_example_pydantic.py index e3c16ca35de48..4e2306cf5214b 100644 --- a/docs/source-app/workflows/build_rest_api/post_example_pydantic.py +++ b/docs/source-app/workflows/build_rest_api/post_example_pydantic.py @@ -18,7 +18,7 @@ def run(self): # 3. Annotate your input with your custom pydantic model. def handle_post(self, config: NamePostConfig): self.names.append(config.name) - return f'The name {config} was registered' + return f"The name {config} was registered" # 4. Defines this Component's Restful API. You can have several routes. 
def configure_api(self): From 093676da7e96f5ce2520804977869636be8b3199 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 3 Nov 2022 14:54:33 +0100 Subject: [PATCH 18/24] notebook --- _notebooks | 1 + 1 file changed, 1 insertion(+) create mode 160000 _notebooks diff --git a/_notebooks b/_notebooks new file mode 160000 index 0000000000000..0ad097a6fec2b --- /dev/null +++ b/_notebooks @@ -0,0 +1 @@ +Subproject commit 0ad097a6fec2b2c3f8ddd5d2263e178c41d614f5 From d899bf1d1c4910a06b5566a404c633a46061eed2 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 3 Nov 2022 14:58:03 +0100 Subject: [PATCH 19/24] reset --- docs/source-app/conf.py | 8 +++---- docs/source-app/examples/file_server/app.py | 19 ++++++++++++--- .../examples/github_repo_runner/app.py | 24 +++++++++++-------- .../build_command_line_interface/app.py | 5 +++- .../commands/notebook/run.py | 2 +- .../post_example.py | 2 +- .../workflows/build_rest_api/post_example.py | 2 +- .../build_rest_api/post_example_pydantic.py | 2 +- 8 files changed, 42 insertions(+), 22 deletions(-) diff --git a/docs/source-app/conf.py b/docs/source-app/conf.py index 62bd18a900417..0cb428ca33656 100644 --- a/docs/source-app/conf.py +++ b/docs/source-app/conf.py @@ -16,7 +16,7 @@ import shutil import sys -import pt_lightning_sphinx_theme +import lai_sphinx_theme import lightning_app @@ -93,7 +93,7 @@ "sphinx_paramlinks", "sphinx_togglebutton", "sphinx.ext.githubpages", - "pt_lightning_sphinx_theme.extensions.lightning", + "lai_sphinx_theme.extensions.lightning", ] # Add any paths that contain templates here, relative to this directory. @@ -149,8 +149,8 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = "pt_lightning_sphinx_theme" -html_theme_path = [pt_lightning_sphinx_theme.get_html_theme_path()] +html_theme = "lai_sphinx_theme" +html_theme_path = [lai_sphinx_theme.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the diff --git a/docs/source-app/examples/file_server/app.py b/docs/source-app/examples/file_server/app.py index fe31aae3b6f19..e36040f1b8662 100644 --- a/docs/source-app/examples/file_server/app.py +++ b/docs/source-app/examples/file_server/app.py @@ -10,7 +10,13 @@ class FileServer(L.LightningWork): - def __init__(self, drive: Drive, base_dir: str = "file_server", chunk_size=10240, **kwargs): + def __init__( + self, + drive: Drive, + base_dir: str = "file_server", + chunk_size=10240, + **kwargs + ): """This component uploads, downloads files to your application. Arguments: @@ -48,7 +54,9 @@ def upload_file(self, file): filename = file.filename uploaded_file = self.get_random_filename() meta_file = uploaded_file + ".meta" - self.uploaded_files[filename] = {"progress": (0, None), "done": False} + self.uploaded_files[filename] = { + "progress": (0, None), "done": False + } # 2: Create a stream and write bytes of # the file to the disk under `uploaded_file` path. 
@@ -155,6 +163,7 @@ def alive(self): class TestFileServer(LightningWork): + def __init__(self, drive: Drive): super().__init__(cache_calls=True) self.drive = drive @@ -164,7 +173,10 @@ def run(self, file_server_url: str, first=True): with open("test.txt", "w") as f: f.write("Some text.") - response = requests.post(file_server_url + "/upload_file/", files={"file": open("test.txt", "rb")}) + response = requests.post( + file_server_url + "/upload_file/", + files={'file': open("test.txt", 'rb')} + ) assert response.status_code == 200 else: response = requests.get(file_server_url) @@ -176,6 +188,7 @@ def run(self, file_server_url: str, first=True): class Flow(LightningFlow): + def __init__(self): super().__init__() # 1: Create a drive to share data between works diff --git a/docs/source-app/examples/github_repo_runner/app.py b/docs/source-app/examples/github_repo_runner/app.py index d0c43d8b76b63..70e20ac380d31 100644 --- a/docs/source-app/examples/github_repo_runner/app.py +++ b/docs/source-app/examples/github_repo_runner/app.py @@ -56,7 +56,8 @@ def run(self, *args, **kwargs): # 2: Use git command line to clone the repo. repo_name = self.github_repo.split("/")[-1].replace(".git", "") cwd = os.path.dirname(__file__) - subprocess.Popen(f"git clone {self.github_repo}", cwd=cwd, shell=True).wait() + subprocess.Popen( + f"git clone {self.github_repo}", cwd=cwd, shell=True).wait() # 3: Execute the parent run method of the TracerPythonScript class. os.chdir(os.path.join(cwd, repo_name)) @@ -72,6 +73,7 @@ def configure_layout(self): class PyTorchLightningGithubRepoRunner(GithubRepoRunner): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.best_model_path = None @@ -103,7 +105,8 @@ def trainer_pre_fn(self, *args, work=None, **kwargs): # 5. Patch the `__init__` method of the Trainer # to inject our callback with a reference to the work. 
- tracer.add_traced(Trainer, "__init__", pre_fn=partial(trainer_pre_fn, work=self)) + tracer.add_traced( + Trainer, "__init__", pre_fn=partial(trainer_pre_fn, work=self)) return tracer def on_after_run(self, end_script_globals): @@ -210,7 +213,9 @@ def page_1__create_new_run(state): script_path = st.text_input("Enter your script to run", value="train_script.py") script_args = st.text_input("Enter your base script arguments", value=default_script_args) requirements = st.text_input("Enter your requirements", value=default_requirements) - ml_framework = st.radio("Select your ML Training Frameworks", options=["PyTorch Lightning", "Keras", "Tensorflow"]) + ml_framework = st.radio( + "Select your ML Training Frameworks", options=["PyTorch Lightning", "Keras", "Tensorflow"] + ) if ml_framework not in ("PyTorch Lightning"): st.write(f"{ml_framework} isn't supported yet.") @@ -274,7 +279,8 @@ def render_fn(state: AppState): "View your Runs": partial(page_2__view_run_lists, state=state), "View the App state": partial(page_3__view_app_state, state=state), } - selected_page = st.sidebar.selectbox("Select a page", page_names_to_funcs.keys()) + selected_page = st.sidebar.selectbox( + "Select a page", page_names_to_funcs.keys()) page_names_to_funcs[selected_page]() @@ -290,12 +296,10 @@ def run(self): def configure_layout(self): # 1: Add the main StreamLit UI - selection_tab = [ - { - "name": "Run your Github Repo", - "content": self.flow, - } - ] + selection_tab = [{ + "name": "Run your Github Repo", + "content": self.flow, + }] # 2: Add a new tab whenever a new work is dynamically created run_tabs = [e.configure_layout() for e in self.flow.ws.values()] # 3: Returns the list of tabs. diff --git a/docs/source-app/workflows/build_command_line_interface/app.py b/docs/source-app/workflows/build_command_line_interface/app.py index 7ac0231d32f01..f6a398096b96c 100644 --- a/docs/source-app/workflows/build_command_line_interface/app.py +++ b/docs/source-app/workflows/build_command_line_interface/app.py @@ -6,6 +6,7 @@ class Flow(L.LightningFlow): + def __init__(self): super().__init__() self.notebooks = Dict() @@ -16,7 +17,9 @@ def run_notebook(self, config: RunNotebookConfig): return f"The Notebook {config.name} already exists." else: # 2. Dynamically creates the Notebook if it doesn't exist and runs it. - self.notebooks[config.name] = JupyterLab(cloud_compute=L.CloudCompute(config.cloud_compute)) + self.notebooks[config.name] = JupyterLab( + cloud_compute=L.CloudCompute(config.cloud_compute) + ) self.notebooks[config.name].run() return f"The Notebook {config.name} was created." diff --git a/docs/source-app/workflows/build_command_line_interface/commands/notebook/run.py b/docs/source-app/workflows/build_command_line_interface/commands/notebook/run.py index c36252dd714b3..a44e6bfa4f9c8 100644 --- a/docs/source-app/workflows/build_command_line_interface/commands/notebook/run.py +++ b/docs/source-app/workflows/build_command_line_interface/commands/notebook/run.py @@ -17,7 +17,7 @@ class RunNotebook(ClientCommand): def run(self): # 1. Define your own argument parser. You can use argparse, click, etc... 
- parser = ArgumentParser(description="Run Notebook Parser") + parser = ArgumentParser(description='Run Notebook Parser') parser.add_argument("--name", type=str, default=None) parser.add_argument("--cloud_compute", type=str, default="cpu") hparams = parser.parse_args() diff --git a/docs/source-app/workflows/build_command_line_interface/post_example.py b/docs/source-app/workflows/build_command_line_interface/post_example.py index 43dcd92408e38..c7f87f1cffdf7 100644 --- a/docs/source-app/workflows/build_command_line_interface/post_example.py +++ b/docs/source-app/workflows/build_command_line_interface/post_example.py @@ -16,7 +16,7 @@ def run(self): # 3. Method executed when a request is received. def handle_post(self, name: str): self.names.append(name) - return f"The name {name} was registered" + return f'The name {name} was registered' # 4. Defines this Component's Restful API. You can have several routes. def configure_api(self): diff --git a/docs/source-app/workflows/build_rest_api/post_example.py b/docs/source-app/workflows/build_rest_api/post_example.py index a900ff51fbbcb..4a306f176e4b0 100644 --- a/docs/source-app/workflows/build_rest_api/post_example.py +++ b/docs/source-app/workflows/build_rest_api/post_example.py @@ -16,7 +16,7 @@ def run(self): # 3. Method executed when a request is received. def handle_post(self, name: str): self.names.append(name) - return f"The name {name} was registered" + return f'The name {name} was registered' # 4. Defines this Component's Restful API. You can have several routes. def configure_api(self): diff --git a/docs/source-app/workflows/build_rest_api/post_example_pydantic.py b/docs/source-app/workflows/build_rest_api/post_example_pydantic.py index 4e2306cf5214b..e3c16ca35de48 100644 --- a/docs/source-app/workflows/build_rest_api/post_example_pydantic.py +++ b/docs/source-app/workflows/build_rest_api/post_example_pydantic.py @@ -18,7 +18,7 @@ def run(self): # 3. Annotate your input with your custom pydantic model. def handle_post(self, config: NamePostConfig): self.names.append(config.name) - return f"The name {config} was registered" + return f'The name {config} was registered' # 4. Defines this Component's Restful API. You can have several routes. 
def configure_api(self): From aa6eca6ad371bacbb2be8177297119bbed6955d0 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 3 Nov 2022 15:04:00 +0100 Subject: [PATCH 20/24] remove unused import --- tests/tests_pytorch/models/test_restore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py index fa6c8aa8d2468..6c368935a9f9e 100644 --- a/tests/tests_pytorch/models/test_restore.py +++ b/tests/tests_pytorch/models/test_restore.py @@ -29,7 +29,7 @@ from lightning_lite import seed_everything from pytorch_lightning import Callback, Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel +from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.trainer.states import TrainerFn from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf From 1fe623bbaa5cac5ed7168dd9bf56b483487a74b5 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 3 Nov 2022 17:30:14 +0100 Subject: [PATCH 21/24] add missing keys --- .../tests_pytorch/utilities/migration/test_utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py index d662cf5e89833..80b337f71a9d0 100644 --- a/tests/tests_pytorch/utilities/migration/test_utils.py +++ b/tests/tests_pytorch/utilities/migration/test_utils.py @@ -13,6 +13,7 @@ # limitations under the License. import logging import sys +from unittest.mock import ANY import pytest @@ -109,26 +110,28 @@ def test_migrate_checkpoint_for_pl(caplog): """Test that the automatic migration in Lightning informs the user about how to make the upgrade permanent.""" # simulate a very recent checkpoint, no migrations needed - loaded_checkpoint = {"pytorch-lightning_version": pl.__version__, "content": 123} + loaded_checkpoint = {"pytorch-lightning_version": pl.__version__, "global_step": 2, "epoch": 0} new_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, "path/to/ckpt") - assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "content": 123} + assert new_checkpoint == {"pytorch-lightning_version": pl.__version__, "global_step": 2, "epoch": 0} # simulate an old checkpoint that needed an upgrade - loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123} + loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "global_step": 2, "epoch": 0} with caplog.at_level(logging.INFO, logger="pytorch_lightning.utilities.migration.utils"): new_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, "path/to/ckpt") assert new_checkpoint == { "legacy_pytorch-lightning_version": "0.0.1", "pytorch-lightning_version": pl.__version__, "callbacks": {}, - "content": 123, + "global_step": 2, + "epoch": 0, + "loops": ANY, } assert f"Lightning automatically upgraded your loaded checkpoint from v0.0.1 to v{pl.__version__}" in caplog.text def test_migrate_checkpoint_legacy_version(monkeypatch): """Test that the legacy version gets set and does not change if migration is applied multiple times.""" - loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "content": 123} + loaded_checkpoint = {"pytorch-lightning_version": "0.0.1", "global_step": 2, "epoch": 0} # pretend the current pl version is 2.0 monkeypatch.setattr(pl, "__version__", "2.0.0") From 
3792f350f32a30299c3adc788c9fce9735c97d49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 4 Nov 2022 18:10:37 +0100 Subject: [PATCH 22/24] notebook --- _notebooks | 1 - 1 file changed, 1 deletion(-) delete mode 160000 _notebooks diff --git a/_notebooks b/_notebooks deleted file mode 160000 index 0ad097a6fec2b..0000000000000 --- a/_notebooks +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 0ad097a6fec2b2c3f8ddd5d2263e178c41d614f5 From 1244c8b5ace8d443913acd04574c9b5ba1c10f3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 4 Nov 2022 18:10:49 +0100 Subject: [PATCH 23/24] notebook --- _notebooks | 1 + 1 file changed, 1 insertion(+) create mode 160000 _notebooks diff --git a/_notebooks b/_notebooks new file mode 160000 index 0000000000000..6d5634b794218 --- /dev/null +++ b/_notebooks @@ -0,0 +1 @@ +Subproject commit 6d5634b7942180e6ba4a30bfbd74926d1c22f1eb From a16ada6ab6ce89a98bbef6c6038b68606a3396e5 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sat, 5 Nov 2022 12:26:23 +0100 Subject: [PATCH 24/24] fix merge --- .../utilities/test_upgrade_checkpoint.py | 34 ------------------- 1 file changed, 34 deletions(-) diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py index 2a53448f5189c..1777849e09ca1 100644 --- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py +++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py @@ -18,43 +18,9 @@ import pytest -import pytorch_lightning as pl -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.utilities.migration import migrate_checkpoint -from pytorch_lightning.utilities.migration.utils import _get_version, _set_legacy_version, _set_version from pytorch_lightning.utilities.upgrade_checkpoint import main as upgrade_main -@pytest.mark.parametrize( - "old_checkpoint, new_checkpoint", - [ - ( - {"epoch": 1, "global_step": 23, "checkpoint_callback_best": 0.34}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.34}}}, - ), - ( - {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_score": 0.99}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_score": 0.99}}}, - ), - ( - {"epoch": 1, "global_step": 23, "checkpoint_callback_best_model_path": "path"}, - {"epoch": 1, "global_step": 23, "callbacks": {ModelCheckpoint: {"best_model_path": "path"}}}, - ), - ( - {"epoch": 1, "global_step": 23, "early_stop_callback_wait": 2, "early_stop_callback_patience": 4}, - {"epoch": 1, "global_step": 23, "callbacks": {EarlyStopping: {"wait_count": 2, "patience": 4}}}, - ), - ], -) -def test_upgrade_checkpoint(tmpdir, old_checkpoint, new_checkpoint): - _set_version(old_checkpoint, "0.9.0") - _set_legacy_version(new_checkpoint, "0.9.0") - _set_version(new_checkpoint, pl.__version__) - updated_checkpoint, _ = migrate_checkpoint(old_checkpoint) - assert updated_checkpoint == old_checkpoint == new_checkpoint - assert _get_version(updated_checkpoint) == pl.__version__ - - def test_upgrade_checkpoint_file_missing(tmp_path, caplog): # path to single file (missing) file = tmp_path / "checkpoint.ckpt"
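---

Editor's note (not part of the patch series): taken together, these patches route all legacy-checkpoint handling through migrate_checkpoint instead of ad-hoc key checks in the checkpoint connector and loops. As a quick illustration of how the pieces fit, below is a minimal sketch of upgrading an old checkpoint file by hand with the utilities introduced here. It is a sketch under stated assumptions, not code from the patches: the file paths are hypothetical, and the tuple return of migrate_checkpoint (checkpoint plus the applied migrations) follows the unpacking used in tests/tests_pytorch/utilities/migration/test_migration.py above.

    # Sketch: manually upgrading a legacy checkpoint on disk.
    # Assumptions: a Lightning build that ships pytorch_lightning.utilities.migration
    # with pl_legacy_patch and migrate_checkpoint; "old.ckpt" is a hypothetical path to
    # a checkpoint written by an older Lightning release.
    import torch

    from pytorch_lightning.utilities.migration import migrate_checkpoint, pl_legacy_patch

    ckpt_path = "old.ckpt"

    # pl_legacy_patch temporarily re-registers removed legacy names (e.g. _gpus_arg_default)
    # so that torch.load can unpickle checkpoints produced by old versions.
    with pl_legacy_patch():
        checkpoint = torch.load(ckpt_path, map_location="cpu")

    # Applies every migration newer than the version recorded in the checkpoint, e.g. moving
    # "global_step"/"epoch" into the loop progress-tracking state and renaming the old
    # ModelCheckpoint/EarlyStopping keys, per the migration index in migration.py above.
    migrated_checkpoint, applied_migrations = migrate_checkpoint(checkpoint)

    torch.save(migrated_checkpoint, "old-upgraded.ckpt")

Loading through the Trainer or LightningModule.load_from_checkpoint applies these migrations automatically, so the manual step above is only needed to make the upgrade permanent on disk; the upgrade_checkpoint command-line entry point referenced in the migration docstring serves the same purpose.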