From d520c81558cd4d85de574f4e9328225b386453da Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 23 Feb 2021 23:08:19 +0100 Subject: [PATCH 01/11] prune prefix --- CHANGELOG.md | 3 +++ .../callbacks/model_checkpoint.py | 19 +------------------ tests/deprecated_api/test_remove_1-2.py | 14 -------------- tests/deprecated_api/test_remove_1-3.py | 7 ------- 4 files changed, 4 insertions(+), 39 deletions(-) delete mode 100644 tests/deprecated_api/test_remove_1-2.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 78d1621c27b77..57ff6ae5a0daa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce` +- Removed deprecated `ModelCheckpoint` arguments `prefix` + + ### Fixed - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 461c211baab12..067b9ee08c63d 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -102,10 +102,6 @@ class ModelCheckpoint(Callback): saved (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``). period: Interval (number of epochs) between checkpoints. - prefix: A string to put at the beginning of checkpoint filename. - - .. warning:: - This argument has been deprecated in v1.1 and will be removed in v1.3 Note: For extra customization, ModelCheckpoint includes the following attributes: @@ -168,7 +164,6 @@ def __init__( save_weights_only: bool = False, mode: str = "auto", period: int = 1, - prefix: str = "", ): super().__init__() self.monitor = monitor @@ -178,7 +173,6 @@ def __init__( self.save_weights_only = save_weights_only self.period = period self._last_global_step_saved = -1 - self.prefix = prefix self.current_score = None self.best_k_models = {} self.kth_best_model_path = "" @@ -188,12 +182,6 @@ def __init__( self.save_function = None self.warned_result_obj = False - if prefix: - rank_zero_warn( - 'Argument `prefix` is deprecated in v1.1 and will be removed in v1.3.' - ' Please prepend your prefix in `filename` instead.', DeprecationWarning - ) - self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(dirpath, filename, save_top_k) self.__validate_init_configuration() @@ -365,7 +353,6 @@ def _format_checkpoint_name( epoch: int, step: int, metrics: Dict[str, Any], - prefix: str = "", ) -> str: if not filename: # filename is not set, use default name @@ -382,9 +369,6 @@ def _format_checkpoint_name( metrics[name] = 0 filename = filename.format(**metrics) - if prefix: - filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename]) - return filename def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None) -> str: @@ -410,7 +394,7 @@ def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any], 'step=0.ckpt' """ - filename = self._format_checkpoint_name(self.filename, epoch, step, metrics, prefix=self.prefix) + filename = self._format_checkpoint_name(self.filename, epoch, step, metrics) if ver is not None: filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) @@ -523,7 +507,6 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics): trainer.current_epoch, trainer.global_step, ckpt_name_metrics, - prefix=self.prefix ) last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}") else: diff --git a/tests/deprecated_api/test_remove_1-2.py b/tests/deprecated_api/test_remove_1-2.py deleted file mode 100644 index 54df59ce0530e..0000000000000 --- a/tests/deprecated_api/test_remove_1-2.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Test deprecated functionality which will be removed in vX.Y.Z""" diff --git a/tests/deprecated_api/test_remove_1-3.py b/tests/deprecated_api/test_remove_1-3.py index 2cf0dd990c8d7..5bdb7173f6a56 100644 --- a/tests/deprecated_api/test_remove_1-3.py +++ b/tests/deprecated_api/test_remove_1-3.py @@ -20,13 +20,6 @@ def test_v1_3_0_deprecated_arguments(tmpdir): - with pytest.deprecated_call(match='will no longer be supported in v1.3'): - callback = ModelCheckpoint() - Trainer(checkpoint_callback=callback, callbacks=[], default_root_dir=tmpdir) - - # Deprecate prefix - with pytest.deprecated_call(match='will be removed in v1.3'): - ModelCheckpoint(prefix='temp') # Deprecate auto mode with pytest.deprecated_call(match='will be removed in v1.3'): From bb4d8a199fa8a29357060a471cacf67247e54931 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 23 Feb 2021 23:13:36 +0100 Subject: [PATCH 02/11] prune mode=auto --- CHANGELOG.md | 2 +- .../callbacks/model_checkpoint.py | 34 +++++-------------- tests/deprecated_api/test_remove_1-3.py | 7 ---- 3 files changed, 9 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 57ff6ae5a0daa..977c2a3ec0aba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce` -- Removed deprecated `ModelCheckpoint` arguments `prefix` +- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ### Fixed diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 067b9ee08c63d..e6176be864e9c 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -86,18 +86,10 @@ class ModelCheckpoint(Callback): if ``save_top_k >= 2`` and the callback is called multiple times inside an epoch, the name of the saved file will be appended with a version count starting with ``v1``. - mode: one of {auto, min, max}. - If ``save_top_k != 0``, the decision - to overwrite the current save file is made - based on either the maximization or the - minimization of the monitored quantity. For `val_acc`, - this should be `max`, for `val_loss` this should - be `min`, etc. In `auto` mode, the direction is - automatically inferred from the name of the monitored quantity. - - .. warning:: - Setting ``mode='auto'`` has been deprecated in v1.1 and will be removed in v1.3. - + mode: one of {min, max}. + If ``save_top_k != 0``, the decision to overwrite the current save file is made + based on either the maximization or the minimization of the monitored quantity. + For `val_acc`, this should be `max`, for `val_loss` this should be `min`, etc. save_weights_only: if ``True``, then only the model's weights will be saved (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``). @@ -118,7 +110,7 @@ class ModelCheckpoint(Callback): MisconfigurationException: If ``save_top_k`` is neither ``None`` nor more than or equal to ``-1``, if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or - if ``mode`` is none of ``"min"``, ``"max"``, and ``"auto"``. + if ``mode`` is none of ``"min"``, ``"max"``. ValueError: If ``trainer.save_checkpoint`` is ``None``. @@ -162,7 +154,7 @@ def __init__( save_last: Optional[bool] = None, save_top_k: Optional[int] = None, save_weights_only: bool = False, - mode: str = "auto", + mode: str = "min", period: int = 1, ): super().__init__() @@ -288,18 +280,8 @@ def __init_monitor_mode(self, monitor, mode): "max": (-torch_inf, "max"), } - if mode not in mode_dict and mode != 'auto': - raise MisconfigurationException(f"`mode` can be auto, {', '.join(mode_dict.keys())}, got {mode}") - - # TODO: Update with MisconfigurationException when auto mode is removed in v1.3 - if mode == 'auto': - rank_zero_warn( - "mode='auto' is deprecated in v1.1 and will be removed in v1.3." - " Default value for mode with be 'min' in v1.3.", DeprecationWarning - ) - - _condition = monitor is not None and ("acc" in monitor or monitor.startswith("fmeasure")) - mode_dict['auto'] = ((-torch_inf, "max") if _condition else (torch_inf, "min")) + if mode not in mode_dict: + raise MisconfigurationException(f"`mode` can be {', '.join(mode_dict.keys())}, got {mode}") self.kth_value, self.mode = mode_dict[mode] diff --git a/tests/deprecated_api/test_remove_1-3.py b/tests/deprecated_api/test_remove_1-3.py index 5bdb7173f6a56..4b8710ddcc424 100644 --- a/tests/deprecated_api/test_remove_1-3.py +++ b/tests/deprecated_api/test_remove_1-3.py @@ -21,13 +21,6 @@ def test_v1_3_0_deprecated_arguments(tmpdir): - # Deprecate auto mode - with pytest.deprecated_call(match='will be removed in v1.3'): - ModelCheckpoint(mode='auto') - - with pytest.deprecated_call(match='will be removed in v1.3'): - EarlyStopping(mode='auto') - with pytest.deprecated_call(match="The setter for self.hparams in LightningModule is deprecated"): class DeprecatedHparamsModel(LightningModule): From 96c70e66752f62d6350addd16d045f99c86a6a95 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 23 Feb 2021 23:15:03 +0100 Subject: [PATCH 03/11] chlog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 977c2a3ec0aba..a8eac33132634 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,7 +31,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce` -- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` +- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162)) ### Fixed From 72e90725167ea4ba97f142b197423995f815d3a2 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 23 Feb 2021 23:20:55 +0100 Subject: [PATCH 04/11] import --- tests/deprecated_api/test_remove_1-3.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/deprecated_api/test_remove_1-3.py b/tests/deprecated_api/test_remove_1-3.py index 4b8710ddcc424..0e02d6ac70007 100644 --- a/tests/deprecated_api/test_remove_1-3.py +++ b/tests/deprecated_api/test_remove_1-3.py @@ -15,8 +15,7 @@ import pytest -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning import Trainer def test_v1_3_0_deprecated_arguments(tmpdir): From 038b425b6075fd7d98f268c88c9590daef67d945 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 24 Feb 2021 00:06:56 +0100 Subject: [PATCH 05/11] Use default --- pytorch_lightning/trainer/connectors/callback_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index ee768c05cc8a2..4d8fe9b7b28b9 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -79,7 +79,7 @@ def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpo ) if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True: - self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None, mode='min')) + self.trainer.callbacks.append(ModelCheckpoint()) def _configure_swa_callbacks(self): if not self.trainer._stochastic_weight_avg: From 8e63eb30b09588dae9323fe213ab7c0ecb63873f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Wed, 24 Feb 2021 00:45:01 +0100 Subject: [PATCH 06/11] Remove hparams code --- docs/source/common/hyperparameters.rst | 3 --- pytorch_lightning/core/lightning.py | 15 --------------- 2 files changed, 18 deletions(-) diff --git a/docs/source/common/hyperparameters.rst b/docs/source/common/hyperparameters.rst index 4f8ca71af5a12..5240a4690e388 100644 --- a/docs/source/common/hyperparameters.rst +++ b/docs/source/common/hyperparameters.rst @@ -167,9 +167,6 @@ improve readability and reproducibility. def train_dataloader(self): return DataLoader(mnist_train, batch_size=self.hparams.batch_size) - .. warning:: Deprecated since v1.1.0. This method of assigning hyperparameters to the LightningModule - will no longer be supported from v1.3.0. Use the ``self.save_hyperparameters()`` method from above instead. - 4. You can also save full objects such as `dict` or `Namespace` to the checkpoint. diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c4d63cff4637b..73fa47a1bac84 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1806,21 +1806,6 @@ def hparams_initial(self) -> AttributeDict: # prevent any change return copy.deepcopy(self._hparams_initial) - @hparams.setter - def hparams(self, hp: Union[dict, Namespace, Any]): - # TODO: remove this method in v1.3.0. - rank_zero_warn( - "The setter for self.hparams in LightningModule is deprecated since v1.1.0 and will be" - " removed in v1.3.0. Replace the assignment `self.hparams = hparams` with " - " `self.save_hyperparameters()`.", DeprecationWarning - ) - hparams_assignment_name = self.__get_hparams_assignment_variable() - self._hparams_name = hparams_assignment_name - self._set_hparams(hp) - # this resolves case when user does not uses `save_hyperparameters` and do hard assignement in init - if not hasattr(self, "_hparams_initial"): - self._hparams_initial = copy.deepcopy(self._hparams) - def __get_hparams_assignment_variable(self): """ looks at the code of the class to figure out what the user named self.hparams From 486ec5a3499143ac252692a8a929353e3402387a Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 24 Feb 2021 00:50:57 +0100 Subject: [PATCH 07/11] rev --- pytorch_lightning/core/lightning.py | 15 +++++++++++++++ tests/deprecated_api/test_remove_1-3.py | 15 ++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 73fa47a1bac84..c4d63cff4637b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1806,6 +1806,21 @@ def hparams_initial(self) -> AttributeDict: # prevent any change return copy.deepcopy(self._hparams_initial) + @hparams.setter + def hparams(self, hp: Union[dict, Namespace, Any]): + # TODO: remove this method in v1.3.0. + rank_zero_warn( + "The setter for self.hparams in LightningModule is deprecated since v1.1.0 and will be" + " removed in v1.3.0. Replace the assignment `self.hparams = hparams` with " + " `self.save_hyperparameters()`.", DeprecationWarning + ) + hparams_assignment_name = self.__get_hparams_assignment_variable() + self._hparams_name = hparams_assignment_name + self._set_hparams(hp) + # this resolves case when user does not uses `save_hyperparameters` and do hard assignement in init + if not hasattr(self, "_hparams_initial"): + self._hparams_initial = copy.deepcopy(self._hparams) + def __get_hparams_assignment_variable(self): """ looks at the code of the class to figure out what the user named self.hparams diff --git a/tests/deprecated_api/test_remove_1-3.py b/tests/deprecated_api/test_remove_1-3.py index 0e02d6ac70007..86e03b88ef64a 100644 --- a/tests/deprecated_api/test_remove_1-3.py +++ b/tests/deprecated_api/test_remove_1-3.py @@ -15,7 +15,20 @@ import pytest -from pytorch_lightning import Trainer +from pytorch_lightning import LightningModule, Trainer + + +def test_v1_3_0_deprecated_arguments(tmpdir): + + with pytest.deprecated_call(match="The setter for self.hparams in LightningModule is deprecated"): + + class DeprecatedHparamsModel(LightningModule): + + def __init__(self, hparams): + super().__init__() + self.hparams = hparams + + DeprecatedHparamsModel({}) def test_v1_3_0_deprecated_arguments(tmpdir): From db24567d2c15345dac649b58297c063de9bc88e6 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 24 Feb 2021 01:49:49 +0100 Subject: [PATCH 08/11] tests --- .../callbacks/model_checkpoint.py | 4 +++ tests/checkpointing/test_model_checkpoint.py | 4 +-- tests/trainer/test_trainer.py | 34 +++++-------------- 3 files changed, 14 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index e6176be864e9c..1cdca5f2a9ad7 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -335,6 +335,7 @@ def _format_checkpoint_name( epoch: int, step: int, metrics: Dict[str, Any], + prefix: str = "", ) -> str: if not filename: # filename is not set, use default name @@ -351,6 +352,9 @@ def _format_checkpoint_name( metrics[name] = 0 filename = filename.format(**metrics) + if prefix: + filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename]) + return filename def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None) -> str: diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index bd4a02536c5c3..6e9bb071f106f 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -294,9 +294,9 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt') # with version - ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename='name', prefix='test') + ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename='name') ckpt_name = ckpt.format_checkpoint_name(3, 2, {}, ver=3) - assert ckpt_name == tmpdir / 'test-name-v3.ckpt' + assert ckpt_name == tmpdir / 'name-v3.ckpt' # using slashes ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}') diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 34305e434575a..2931d8dda8e89 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -421,34 +421,17 @@ def test_dp_output_reduce(): @pytest.mark.parametrize( - ["save_top_k", "save_last", "file_prefix", "expected_files"], + ["save_top_k", "save_last", "expected_files"], [ - pytest.param( - -1, - False, - "", - {"epoch=4.ckpt", "epoch=3.ckpt", "epoch=2.ckpt", "epoch=1.ckpt", "epoch=0.ckpt"}, - id="CASE K=-1 (all)", - ), - pytest.param(1, False, "test_prefix", {"test_prefix-epoch=4.ckpt"}, id="CASE K=1 (2.5, epoch 4)"), - pytest.param(2, False, "", {"epoch=4.ckpt", "epoch=2.ckpt"}, id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"), - pytest.param( - 4, - False, - "", - {"epoch=1.ckpt", "epoch=4.ckpt", "epoch=3.ckpt", "epoch=2.ckpt"}, - id="CASE K=4 (save all 4 base)", - ), - pytest.param( - 3, - False, - "", {"epoch=2.ckpt", "epoch=3.ckpt", "epoch=4.ckpt"}, - id="CASE K=3 (save the 2nd, 3rd, 4th model)" - ), - pytest.param(1, True, "", {"epoch=4.ckpt", "last.ckpt"}, id="CASE K=1 (save the 4th model and the last model)"), + pytest.param(-1, False, [f"epoch={i}.ckpt" for i in range(5)], id="CASE K=-1 (all)"), + pytest.param(1, False, {"epoch=4.ckpt"}, id="CASE K=1 (2.5, epoch 4)"), + pytest.param(2, False, [f"epoch={i}.ckpt" for i in (2, 4)], id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"), + pytest.param(4, False, [f"epoch={i}.ckpt" for i in range(1, 5)], id="CASE K=4 (save all 4 base)"), + pytest.param(3, False, [f"epoch={i}.ckpt" for i in range(2, 5)], id="CASE K=3 (save the 2nd, 3rd, 4th model)"), + pytest.param(1, True, {"epoch=4.ckpt", "last.ckpt"}, id="CASE K=1 (save the 4th model and the last model)"), ], ) -def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, expected_files): +def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files): """Test ModelCheckpoint options.""" def mock_save_function(filepath, *args): @@ -463,7 +446,6 @@ def mock_save_function(filepath, *args): monitor='checkpoint_on', save_top_k=save_top_k, save_last=save_last, - prefix=file_prefix, verbose=1 ) checkpoint_callback.save_function = mock_save_function From 77a4f402c24b3cb13050a6e0a459738b5270f67d Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 24 Feb 2021 09:39:23 +0100 Subject: [PATCH 09/11] Apply suggestions from code review Co-authored-by: Rohit Gupta --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- tests/trainer/test_trainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 1cdca5f2a9ad7..0dc1773e279b7 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -89,7 +89,7 @@ class ModelCheckpoint(Callback): mode: one of {min, max}. If ``save_top_k != 0``, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. - For `val_acc`, this should be `max`, for `val_loss` this should be `min`, etc. + For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc. save_weights_only: if ``True``, then only the model's weights will be saved (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``). @@ -110,7 +110,7 @@ class ModelCheckpoint(Callback): MisconfigurationException: If ``save_top_k`` is neither ``None`` nor more than or equal to ``-1``, if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or - if ``mode`` is none of ``"min"``, ``"max"``. + if ``mode`` is none of ``"min"`` or ``"max"``. ValueError: If ``trainer.save_checkpoint`` is ``None``. diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 2931d8dda8e89..f14ce984ffb67 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -421,7 +421,7 @@ def test_dp_output_reduce(): @pytest.mark.parametrize( - ["save_top_k", "save_last", "expected_files"], + "save_top_k,save_last,expected_files", [ pytest.param(-1, False, [f"epoch={i}.ckpt" for i in range(5)], id="CASE K=-1 (all)"), pytest.param(1, False, {"epoch=4.ckpt"}, id="CASE K=1 (2.5, epoch 4)"), From db8dd2dd111091f819ea85627443fa03a4b6edd3 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 24 Feb 2021 11:50:42 +0100 Subject: [PATCH 10/11] flake8 --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 2 +- tests/deprecated_api/test_remove_1-3.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 0dc1773e279b7..8f2ad2a45a3a2 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -281,7 +281,7 @@ def __init_monitor_mode(self, monitor, mode): } if mode not in mode_dict: - raise MisconfigurationException(f"`mode` can be {', '.join(mode_dict.keys())}, got {mode}") + raise MisconfigurationException(f"`mode` can be {', '.join(mode_dict.keys())} but got {mode}") self.kth_value, self.mode = mode_dict[mode] diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 6e9bb071f106f..50b2ffb83e6c1 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -975,5 +975,5 @@ def test_ckpt_version_after_rerun_same_trainer(tmpdir): def test_model_checkpoint_mode_options(): - with pytest.raises(MisconfigurationException, match="`mode` can be auto, .* got unknown_option"): + with pytest.raises(MisconfigurationException, match="`mode` can be .* but got unknown_option"): ModelCheckpoint(mode="unknown_option") diff --git a/tests/deprecated_api/test_remove_1-3.py b/tests/deprecated_api/test_remove_1-3.py index 86e03b88ef64a..17e767355a784 100644 --- a/tests/deprecated_api/test_remove_1-3.py +++ b/tests/deprecated_api/test_remove_1-3.py @@ -15,7 +15,7 @@ import pytest -from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning import LightningModule def test_v1_3_0_deprecated_arguments(tmpdir): From 4a17f0021d3d2ea02902dd69d23562b63a5a454c Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 24 Feb 2021 12:16:46 +0100 Subject: [PATCH 11/11] ... --- tests/deprecated_api/test_remove_1-3.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/deprecated_api/test_remove_1-3.py b/tests/deprecated_api/test_remove_1-3.py index 17e767355a784..1710bb8777e31 100644 --- a/tests/deprecated_api/test_remove_1-3.py +++ b/tests/deprecated_api/test_remove_1-3.py @@ -18,19 +18,6 @@ from pytorch_lightning import LightningModule -def test_v1_3_0_deprecated_arguments(tmpdir): - - with pytest.deprecated_call(match="The setter for self.hparams in LightningModule is deprecated"): - - class DeprecatedHparamsModel(LightningModule): - - def __init__(self, hparams): - super().__init__() - self.hparams = hparams - - DeprecatedHparamsModel({}) - - def test_v1_3_0_deprecated_arguments(tmpdir): with pytest.deprecated_call(match="The setter for self.hparams in LightningModule is deprecated"):