From 72309c2ea6b781f68eebdc293d804e1fbdccb788 Mon Sep 17 00:00:00 2001 From: shenoynikhil Date: Sat, 14 May 2022 11:35:39 +0530 Subject: [PATCH 01/12] removed weights summary from trainer and callback connector --- .../trainer/connectors/callback_connector.py | 34 ++++--------------- pytorch_lightning/trainer/trainer.py | 23 ------------- 2 files changed, 7 insertions(+), 50 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 7514e5c85eef7..f3d618342d857 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -50,7 +50,6 @@ def on_trainer_init( default_root_dir: Optional[str], weights_save_path: Optional[str], enable_model_summary: bool, - weights_summary: Optional[str], max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None, accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None, ): @@ -88,7 +87,7 @@ def on_trainer_init( self._configure_progress_bar(process_position, enable_progress_bar) # configure the ModelSummary callback - self._configure_model_summary_callback(enable_model_summary, weights_summary) + self._configure_model_summary_callback(enable_model_summary) # accumulated grads self._configure_accumulated_gradients(accumulate_grad_batches) @@ -151,15 +150,7 @@ def _configure_checkpoint_callbacks(self, checkpoint_callback: Optional[bool], e elif enable_checkpointing: self.trainer.callbacks.append(ModelCheckpoint()) - def _configure_model_summary_callback( - self, enable_model_summary: bool, weights_summary: Optional[str] = None - ) -> None: - if weights_summary is None: - rank_zero_deprecation( - "Setting `Trainer(weights_summary=None)` is deprecated in v1.5 and will be removed" - " in v1.7. Please set `Trainer(enable_model_summary=False)` instead." - ) - return + def _configure_model_summary_callback(self, enable_model_summary: bool) -> None: if not enable_model_summary: return @@ -171,21 +162,11 @@ def _configure_model_summary_callback( ) return - if weights_summary == "top": - # special case the default value for weights_summary to preserve backward compatibility - max_depth = 1 - else: - rank_zero_deprecation( - f"Setting `Trainer(weights_summary={weights_summary})` is deprecated in v1.5 and will be removed" - " in v1.7. Please pass `pytorch_lightning.callbacks.model_summary.ModelSummary` with" - " `max_depth` directly to the Trainer's `callbacks` argument instead." 
-            )
-        if weights_summary not in ModelSummaryMode.supported_types():
-            raise MisconfigurationException(
-                f"`weights_summary` can be None, {', '.join(ModelSummaryMode.supported_types())}",
-                f" but got {weights_summary}",
-            )
-        max_depth = ModelSummaryMode.get_max_depth(weights_summary)
+        # If the user wants to configure a model summary callback without explicitly passing it
+        # to the callbacks list, we will create it with a max depth of 1.
+        # This corresponds to weights_summary == "top".
+        weights_summary = "top"
+        max_depth = ModelSummaryMode.get_max_depth(weights_summary)

         progress_bar_callback = self.trainer.progress_bar_callback
         is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar)
@@ -195,7 +176,6 @@ def _configure_model_summary_callback(
         else:
             model_summary = ModelSummary(max_depth=max_depth)
         self.trainer.callbacks.append(model_summary)
-        self.trainer._weights_summary = weights_summary

     def _configure_progress_bar(self, process_position: int = 0, enable_progress_bar: bool = True) -> None:
         progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)]
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index f72c2a8d08df2..78fa53d9cd1a4 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -169,7 +169,6 @@ def __init__(
         sync_batchnorm: bool = False,
         precision: Union[int, str] = 32,
         enable_model_summary: bool = True,
-        weights_summary: Optional[str] = "top",
         weights_save_path: Optional[str] = None,  # TODO: Remove in 1.8
         num_sanity_val_steps: int = 2,
         resume_from_checkpoint: Optional[Union[Path, str]] = None,
@@ -417,14 +416,6 @@ def __init__(
         enable_model_summary: Whether to enable model summarization by default. Default: ``True``.

-        weights_summary: Prints a summary of the weights when training begins.
-
-            .. deprecated:: v1.5
-                ``weights_summary`` has been deprecated in v1.5 and will be removed in v1.7.
-                To disable the summary, pass ``enable_model_summary = False`` to the Trainer.
-                To customize the summary, pass :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
-                directly to the Trainer's ``callbacks`` argument.
-
         weights_save_path: Where to save weights if specified. Will override default_root_dir
             for checkpoints only. Use this if for whatever reason you need the checkpoints
             stored in a different place than the logs written in `default_root_dir`.
@@ -507,9 +498,6 @@ def __init__( self._tested_ckpt_path: Optional[str] = None # TODO: remove in v1.8 self._predicted_ckpt_path: Optional[str] = None # TODO: remove in v1.8 - # todo: remove in v1.7 - self._weights_summary: Optional[str] = None - # init callbacks # Declare attributes to be set in _callback_connector on_trainer_init self._callback_connector.on_trainer_init( @@ -521,7 +509,6 @@ def __init__( default_root_dir, weights_save_path, enable_model_summary, - weights_summary, max_time, accumulate_grad_batches, ) @@ -2740,16 +2727,6 @@ def _should_terminate_gracefully(self) -> bool: value = torch.tensor(int(self._terminate_gracefully), device=self.strategy.root_device) return self.strategy.reduce(value, reduce_op="sum") > 0 - @property - def weights_summary(self) -> Optional[str]: - rank_zero_deprecation("`Trainer.weights_summary` is deprecated in v1.5 and will be removed in v1.7.") - return self._weights_summary - - @weights_summary.setter - def weights_summary(self, val: Optional[str]) -> None: - rank_zero_deprecation("Setting `Trainer.weights_summary` is deprecated in v1.5 and will be removed in v1.7.") - self._weights_summary = val - """ Other """ From 712a829d5539d397957cbe4091301a01806b7e21 Mon Sep 17 00:00:00 2001 From: shenoynikhil Date: Sat, 14 May 2022 11:36:21 +0530 Subject: [PATCH 02/12] Updated trainer docs by removing weights_summary arguments --- docs/source/common/trainer.rst | 30 ------------------------------ 1 file changed, 30 deletions(-) diff --git a/docs/source/common/trainer.rst b/docs/source/common/trainer.rst index e63640e99f8ce..9e38b5bc78beb 100644 --- a/docs/source/common/trainer.rst +++ b/docs/source/common/trainer.rst @@ -1598,36 +1598,6 @@ Example:: weights_save_path='my/path' ) -weights_summary -^^^^^^^^^^^^^^^ - -.. warning:: `weights_summary` is deprecated in v1.5 and will be removed in v1.7. Please pass :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary` - directly to the Trainer's ``callbacks`` argument instead. To disable the model summary, - pass ``enable_model_summary = False`` to the Trainer. - - -.. raw:: html - - - -| - -Prints a summary of the weights when training begins. -Options: 'full', 'top', None. - -.. 
testcode:: - - # default used by the Trainer (ie: print summary of top level modules) - trainer = Trainer(weights_summary="top") - - # print full summary of all modules and submodules - trainer = Trainer(weights_summary="full") - - # don't print a summary - trainer = Trainer(weights_summary=None) - enable_model_summary ^^^^^^^^^^^^^^^^^^^^ From e8b7c849d6ada3ae9f6ddb0ea68b577f71f5db5f Mon Sep 17 00:00:00 2001 From: shenoynikhil Date: Sat, 14 May 2022 11:54:49 +0530 Subject: [PATCH 03/12] updated tests --- tests/callbacks/test_model_summary.py | 28 ++++--------------------- tests/deprecated_api/test_remove_1-7.py | 15 ------------- tests/utilities/test_model_summary.py | 13 ------------ 3 files changed, 4 insertions(+), 52 deletions(-) diff --git a/tests/callbacks/test_model_summary.py b/tests/callbacks/test_model_summary.py index f588d696c4e7e..3b750383da856 100644 --- a/tests/callbacks/test_model_summary.py +++ b/tests/callbacks/test_model_summary.py @@ -29,38 +29,18 @@ def test_model_summary_callback_present_trainer(): assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks) -def test_model_summary_callback_with_weights_summary_none(): - with pytest.deprecated_call(match=r"weights_summary=None\)` is deprecated"): - trainer = Trainer(weights_summary=None) - assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks) - +def test_model_summary_callback_with_enable_model_summary_false(): trainer = Trainer(enable_model_summary=False) assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks) - trainer = Trainer(enable_model_summary=False, weights_summary="full") - assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks) - - with pytest.deprecated_call(match=r"weights_summary=None\)` is deprecated"): - trainer = Trainer(enable_model_summary=True, weights_summary=None) - assert not any(isinstance(cb, ModelSummary) for cb in trainer.callbacks) +def test_model_summary_callback_with_enable_model_summary_true(): + trainer = Trainer(enable_model_summary=True) + assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks) -def test_model_summary_callback_with_weights_summary(): - trainer = Trainer(weights_summary="top") model_summary_callback = list(filter(lambda cb: isinstance(cb, ModelSummary), trainer.callbacks))[0] assert model_summary_callback._max_depth == 1 - with pytest.deprecated_call(match=r"weights_summary=full\)` is deprecated"): - trainer = Trainer(weights_summary="full") - model_summary_callback = list(filter(lambda cb: isinstance(cb, ModelSummary), trainer.callbacks))[0] - assert model_summary_callback._max_depth == -1 - - -def test_model_summary_callback_override_weights_summary_flag(): - with pytest.deprecated_call(match=r"weights_summary=None\)` is deprecated"): - trainer = Trainer(callbacks=ModelSummary(), weights_summary=None) - assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks) - def test_custom_model_summary_callback_summarize(tmpdir): class CustomModelSummary(ModelSummary): diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 270cd7ecd9769..271cbfd55d35a 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -183,21 +183,6 @@ def test_v1_7_0_deprecate_parameter_validation(): from pytorch_lightning.core.decorators import parameter_validation # noqa: F401 -def test_v1_7_0_weights_summary_trainer(tmpdir): - with pytest.deprecated_call(match=r"Setting `Trainer\(weights_summary=full\)` is deprecated in v1.5"): 
- t = Trainer(weights_summary="full") - - with pytest.deprecated_call(match=r"Setting `Trainer\(weights_summary=None\)` is deprecated in v1.5"): - t = Trainer(weights_summary=None) - - t = Trainer(weights_summary="top") - with pytest.deprecated_call(match=r"`Trainer.weights_summary` is deprecated in v1.5"): - _ = t.weights_summary - - with pytest.deprecated_call(match=r"Setting `Trainer.weights_summary` is deprecated in v1.5"): - t.weights_summary = "blah" - - def test_v1_7_0_deprecated_slurm_job_id(): trainer = Trainer() with pytest.deprecated_call(match="Method `slurm_job_id` is deprecated in v1.6.0 and will be removed in v1.7.0."): diff --git a/tests/utilities/test_model_summary.py b/tests/utilities/test_model_summary.py index b143242fa4dcf..0196430f83888 100644 --- a/tests/utilities/test_model_summary.py +++ b/tests/utilities/test_model_summary.py @@ -139,19 +139,6 @@ def forward(self, inp): return self.head(self.branch1(inp), self.branch2(inp)) -def test_invalid_weights_summary(): - """Test that invalid value for weights_summary raises an error.""" - model = LightningModule() - - with pytest.raises( - MisconfigurationException, match="`weights_summary` can be None, .* got temp" - ), pytest.deprecated_call(match="weights_summary=temp)` is deprecated"): - Trainer(weights_summary="temp") - - with pytest.raises(ValueError, match="max_depth` can be .* got temp"): - ModelSummary(model, max_depth="temp") - - @pytest.mark.parametrize("max_depth", [-1, 1]) def test_empty_model_summary_shapes(max_depth): """Test that the summary works for models that have no submodules.""" From 65f2313b04be4106868512b7ca57d631d1149ed9 Mon Sep 17 00:00:00 2001 From: shenoynikhil Date: Sat, 14 May 2022 12:01:49 +0530 Subject: [PATCH 04/12] Made changes to adhere to flake8 tests --- tests/callbacks/test_model_summary.py | 2 -- tests/utilities/test_model_summary.py | 1 - 2 files changed, 3 deletions(-) diff --git a/tests/callbacks/test_model_summary.py b/tests/callbacks/test_model_summary.py index 3b750383da856..d9c531da6e371 100644 --- a/tests/callbacks/test_model_summary.py +++ b/tests/callbacks/test_model_summary.py @@ -13,8 +13,6 @@ # limitations under the License. from typing import List, Union -import pytest - from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelSummary from tests.helpers.boring_model import BoringModel diff --git a/tests/utilities/test_model_summary.py b/tests/utilities/test_model_summary.py index 0196430f83888..3af152cbda2e6 100644 --- a/tests/utilities/test_model_summary.py +++ b/tests/utilities/test_model_summary.py @@ -19,7 +19,6 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_9 -from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_summary import ModelSummary, summarize, UNKNOWN_SIZE from tests.helpers import BoringModel from tests.helpers.advanced_models import ParityModuleRNN From 1b96c471418eae2f0a7a84b77160a622d98b6703 Mon Sep 17 00:00:00 2001 From: shenoynikhil Date: Sat, 14 May 2022 12:17:31 +0530 Subject: [PATCH 05/12] Updated CHANGELOG.md and tests/callbacks/test_model_summary with comment --- CHANGELOG.md | 3 +++ tests/callbacks/test_model_summary.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c803239d5fccb..ad535cec57e3c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -112,6 +112,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
 ### Removed

+- Removed the deprecated `weights_summary` argument from the `Trainer` constructor ([#13070](https://github.com/PyTorchLightning/pytorch-lightning/pull/13070))
+
+
 - Removed the deprecated `TestTubeLogger` ([#12859](https://github.com/PyTorchLightning/pytorch-lightning/pull/12859))


diff --git a/tests/callbacks/test_model_summary.py b/tests/callbacks/test_model_summary.py
index d9c531da6e371..765911b8be9a1 100644
--- a/tests/callbacks/test_model_summary.py
+++ b/tests/callbacks/test_model_summary.py
@@ -36,6 +36,8 @@ def test_model_summary_callback_with_enable_model_summary_true():
     trainer = Trainer(enable_model_summary=True)
     assert any(isinstance(cb, ModelSummary) for cb in trainer.callbacks)

+    # The default value of max_depth is 1 when enable_model_summary is True
+    # and ModelSummary is not passed in the callbacks list
     model_summary_callback = list(filter(lambda cb: isinstance(cb, ModelSummary), trainer.callbacks))[0]
     assert model_summary_callback._max_depth == 1


From 0b656315878cd31584b45bb627c4bd9f53ec6ab9 Mon Sep 17 00:00:00 2001
From: shenoynikhil
Date: Sat, 14 May 2022 12:34:45 +0530
Subject: [PATCH 06/12] Updated CHANGELOG

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 88b7281092420..e4e8efcca5e14 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -112,6 +112,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Removed

+- Removed the deprecated `weights_summary` argument from the `Trainer` constructor ([#13070](https://github.com/PyTorchLightning/pytorch-lightning/pull/13070))
+
 - Removed the deprecated `checkpoint_callback` argument from the `Trainer` constructor ([#13027](https://github.com/PyTorchLightning/pytorch-lightning/pull/13027))


From 15cbcb86e5f1e00175142d6325b812adf1687c37 Mon Sep 17 00:00:00 2001
From: shenoynikhil
Date: Sat, 14 May 2022 12:39:06 +0530
Subject: [PATCH 07/12] Removed use of ModelSummaryMode which uses
 weights_summary

---
 pytorch_lightning/trainer/connectors/callback_connector.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py
index 83cc47fb51556..d91de5e417bb7 100644
--- a/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -155,9 +155,7 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
             )
             return

         # If the user wants to configure a model summary callback without explicitly passing it
         # to the callbacks list, we will create it with a max depth of 1.
-        # This corresponds to weights_summary == "top".
-        weights_summary = "top"
-        max_depth = ModelSummaryMode.get_max_depth(weights_summary)
+        max_depth = 1

         progress_bar_callback = self.trainer.progress_bar_callback
         is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar)

From 7d8bec928c99d112fa7dc4714ed078651bca1edd Mon Sep 17 00:00:00 2001
From: shenoynikhil
Date: Sat, 14 May 2022 12:41:31 +0530
Subject: [PATCH 08/12] removed unnecessary import

---
 pytorch_lightning/trainer/connectors/callback_connector.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py
index d91de5e417bb7..71d4fc862330e 100644
--- a/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -28,7 +28,6 @@
 )
 from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
 from pytorch_lightning.callbacks.timer import Timer
-from pytorch_lightning.utilities.enums import ModelSummaryMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info

From 390ad7ca7ebc8cda439e219602c21276097ff629 Mon Sep 17 00:00:00 2001
From: shenoynikhil
Date: Sat, 14 May 2022 23:46:33 +0530
Subject: [PATCH 09/12] Updated model summary with invalid max depth test

---
 tests/utilities/test_model_summary.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/tests/utilities/test_model_summary.py b/tests/utilities/test_model_summary.py
index 3af152cbda2e6..ebdc85b8b970e 100644
--- a/tests/utilities/test_model_summary.py
+++ b/tests/utilities/test_model_summary.py
@@ -138,6 +138,14 @@ def forward(self, inp):
         return self.head(self.branch1(inp), self.branch2(inp))


+def test_invalid_weights_summary():
+    """Test that invalid value for weights_summary raises an error."""
+    model = LightningModule()
+
+    with pytest.raises(ValueError, match="max_depth` can be .* got temp"):
+        ModelSummary(model, max_depth="temp")
+
+
 @pytest.mark.parametrize("max_depth", [-1, 1])
 def test_empty_model_summary_shapes(max_depth):
     """Test that the summary works for models that have no submodules."""

From ec79de67db4eb9cbeca93847126cf9adb86831fc Mon Sep 17 00:00:00 2001
From: shenoynikhil
Date: Sat, 14 May 2022 23:46:54 +0530
Subject: [PATCH 10/12] updated callback connector for weights summary

---
 .../trainer/connectors/callback_connector.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py
index 71d4fc862330e..ceb8595d6e195 100644
--- a/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -152,17 +152,13 @@ def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
             )
             return

-        # If the user wants to configure a model summary callback without explicitly passing it
-        # to the callbacks list, we will create it with a max depth of 1.
-        max_depth = 1
-
         progress_bar_callback = self.trainer.progress_bar_callback
         is_progress_bar_rich = isinstance(progress_bar_callback, RichProgressBar)

         if progress_bar_callback is not None and is_progress_bar_rich:
-            model_summary = RichModelSummary(max_depth=max_depth)
+            model_summary = RichModelSummary()
         else:
-            model_summary = ModelSummary(max_depth=max_depth)
+            model_summary = ModelSummary()
         self.trainer.callbacks.append(model_summary)

     def _configure_progress_bar(self, process_position: int = 0, enable_progress_bar: bool = True) -> None:

From ad9c267029509993f9ff62e8c75156b67fd8fb24 Mon Sep 17 00:00:00 2001
From: rohitgr7
Date: Sun, 15 May 2022 11:33:49 +0530
Subject: [PATCH 11/12] update test

---
 tests/utilities/test_model_summary.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/utilities/test_model_summary.py b/tests/utilities/test_model_summary.py
index ebdc85b8b970e..11b6c62387547 100644
--- a/tests/utilities/test_model_summary.py
+++ b/tests/utilities/test_model_summary.py
@@ -138,8 +138,8 @@ def forward(self, inp):
         return self.head(self.branch1(inp), self.branch2(inp))


-def test_invalid_weights_summary():
-    """Test that invalid value for weights_summary
raises an error.""" +def test_invalid_max_depth(): + """Test that invalid value for max_depth raises an error.""" model = LightningModule() with pytest.raises(ValueError, match="max_depth` can be .* got temp"): From 39a9d54a7ef49a8db65777915d1bc16553136a75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 18 May 2022 17:08:42 +0200 Subject: [PATCH 12/12] Remove ModelSummaryMode too --- pytorch_lightning/utilities/__init__.py | 1 - pytorch_lightning/utilities/enums.py | 32 ------------------------- tests/utilities/test_enums.py | 14 +---------- 3 files changed, 1 insertion(+), 46 deletions(-) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 79194d16f918e..c1d64b8ae7808 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -24,7 +24,6 @@ DistributedType, GradClipAlgorithmType, LightningEnum, - ModelSummaryMode, ) from pytorch_lightning.utilities.grads import grad_norm # noqa: F401 from pytorch_lightning.utilities.imports import ( # noqa: F401 diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index de2a0af661c2a..f4b0f29d8be41 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -199,38 +199,6 @@ class AutoRestartBatchKeys(LightningEnum): PL_RESTART_META = "__pl_restart_meta" -class ModelSummaryMode(LightningEnum): - # TODO: remove in v1.6 (as `mode` would be deprecated for `max_depth`) - """Define the Model Summary mode to be used. - - Can be one of - - `top`: only the top-level modules will be recorded (the children of the root module) - - `full`: summarizes all layers and their submodules in the root module - - >>> # you can match the type with string - >>> ModelSummaryMode.TOP == 'TOP' - True - >>> # which is case invariant - >>> ModelSummaryMode.TOP in ('top', 'FULL') - True - """ - - TOP = "top" - FULL = "full" - - @staticmethod - def get_max_depth(mode: str) -> int: - if mode == ModelSummaryMode.TOP: - return 1 - if mode == ModelSummaryMode.FULL: - return -1 - raise ValueError(f"`mode` can be {', '.join(list(ModelSummaryMode))}, got {mode}.") - - @staticmethod - def supported_types() -> list[str]: - return [x.value for x in ModelSummaryMode] - - class _StrategyType(LightningEnum): """Define type of training strategy. diff --git a/tests/utilities/test_enums.py b/tests/utilities/test_enums.py index 99158e2e83c79..dcd4410952308 100644 --- a/tests/utilities/test_enums.py +++ b/tests/utilities/test_enums.py @@ -11,9 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest - -from pytorch_lightning.utilities.enums import _AcceleratorType, GradClipAlgorithmType, ModelSummaryMode, PrecisionType +from pytorch_lightning.utilities.enums import _AcceleratorType, GradClipAlgorithmType, PrecisionType def test_consistency(): @@ -34,16 +32,6 @@ def test_precision_supported_types(): assert not PrecisionType.supported_type("invalid") -def test_model_summary_mode(): - assert ModelSummaryMode.supported_types() == ["top", "full"] - assert ModelSummaryMode.TOP in ("top", "full") - assert ModelSummaryMode.get_max_depth("top") == 1 - assert ModelSummaryMode.get_max_depth("full") == -1 - - with pytest.raises(ValueError, match=f"`mode` can be {', '.join(list(ModelSummaryMode))}, got invalid."): - ModelSummaryMode.get_max_depth("invalid") - - def test_gradient_clip_algorithms(): assert GradClipAlgorithmType.supported_types() == ["value", "norm"] assert GradClipAlgorithmType.supported_type("norm")
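
--
Migration note: a minimal sketch of the replacement API, following the deprecation
messages removed in this series (the `max_depth` values mirror the old
`ModelSummaryMode` mapping, "top" -> 1 and "full" -> -1); the snippet below is
illustrative and not part of the applied diffs:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelSummary

    # previously Trainer(weights_summary="top"): the default summary, max_depth=1
    trainer = Trainer(enable_model_summary=True)

    # previously Trainer(weights_summary="full"): summarize all submodules
    trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])

    # previously Trainer(weights_summary=None): disable the summary
    trainer = Trainer(enable_model_summary=False)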