Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated `utilities.distributed.rank_zero_{warn/deprecation}` ([#10451](https://github.com/PyTorchLightning/pytorch-lightning/pull/10451))


- Removed deprecated `mode` argument from `ModelSummary` class ([#10449](https://github.com/PyTorchLightning/pytorch-lightning/pull/10449))


- Removed deprecated `Trainer.train_loop` property in favor of `Trainer.fit_loop` ([#10482](https://github.com/PyTorchLightning/pytorch-lightning/pull/10482))


Expand Down
9 changes: 2 additions & 7 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,19 +1689,14 @@ def tbptt_split_batch(self, batch, split_size):

return splits

def summarize(self, mode: Optional[str] = "top", max_depth: Optional[int] = None) -> Optional[ModelSummary]:
def summarize(self, max_depth: int = 1) -> ModelSummary:
"""Summarize this LightningModule.

.. deprecated:: v1.5
This method was deprecated in v1.5 in favor of `pytorch_lightning.utilities.model_summary.summarize`
and will be removed in v1.7.

Args:
mode: Can be either ``'top'`` (summarize only direct submodules) or ``'full'`` (summarize all layers).

.. deprecated:: v1.4
This parameter was deprecated in v1.4 in favor of `max_depth` and will be removed in v1.6.

max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the
layer summary off. Default: 1.

Expand All @@ -1714,7 +1709,7 @@ def summarize(self, mode: Optional[str] = "top", max_depth: Optional[int] = None
stacklevel=6,
)

return summarize(self, mode, max_depth)
return summarize(self, max_depth)

def freeze(self) -> None:
r"""
Expand Down
53 changes: 4 additions & 49 deletions pytorch_lightning/utilities/model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
from torch.utils.hooks import RemovableHandle

import pytorch_lightning as pl
from pytorch_lightning.utilities import AMPType, DeviceType, ModelSummaryMode, rank_zero_deprecation
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import AMPType, DeviceType
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.warnings import WarningCache

Expand Down Expand Up @@ -130,13 +129,6 @@ class ModelSummary:

Args:
model: The model to summarize (also referred to as the root module).
mode: Can be one of

- `top` (default): only the top-level modules will be recorded (the children of the root module)
- `full`: summarizes all layers and their submodules in the root module

.. deprecated:: v1.4
This parameter was deprecated in v1.4 in favor of `max_depth` and will be removed in v1.6.

max_depth: Maximum depth of modules to show. Use -1 to show all modules or 0 to show no
summary. Defaults to 1.
Expand Down Expand Up @@ -186,22 +178,9 @@ class ModelSummary:
0.530 Total estimated model params size (MB)
"""

def __init__(self, model: "pl.LightningModule", mode: Optional[str] = None, max_depth: Optional[int] = 1) -> None:
def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None:
self._model = model

# temporary mapping from mode to max_depth
if max_depth is None or mode is not None:
if mode in ModelSummaryMode.supported_types():
max_depth = ModelSummaryMode.get_max_depth(mode)
rank_zero_deprecation(
"Argument `mode` in `ModelSummary` is deprecated in v1.4"
f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behaviour."
)
else:
raise MisconfigurationException(
f"`mode` can be {', '.join(ModelSummaryMode.supported_types())}, got {mode}."
)

if not isinstance(max_depth, int) or max_depth < -1:
raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")

Expand Down Expand Up @@ -436,40 +415,16 @@ def _is_lazy_weight_tensor(p: Tensor) -> bool:
return False


def summarize(
lightning_module: "pl.LightningModule", mode: Optional[str] = None, max_depth: Optional[int] = None
) -> ModelSummary:
def summarize(lightning_module: "pl.LightningModule", max_depth: int = 1) -> ModelSummary:
"""Summarize the LightningModule specified by `lightning_module`.

Args:
lightning_module: `LightningModule` to summarize.
mode: Can be either ``'top'`` (summarize only direct submodules) or ``'full'`` (summarize all layers).

.. deprecated:: v1.4
This parameter was deprecated in v1.4 in favor of `max_depth` and will be removed in v1.6.

max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the
layer summary off. Default: 1.

Return:
The model summary object
"""

# temporary mapping from mode to max_depth
if max_depth is None:
if mode is None:
model_summary = ModelSummary(lightning_module, max_depth=1)
elif mode in ModelSummaryMode.supported_types():
max_depth = ModelSummaryMode.get_max_depth(mode)
rank_zero_deprecation(
"Argument `mode` in `LightningModule.summarize` is deprecated in v1.4"
f" and will be removed in v1.6. Use `max_depth={max_depth}` to replicate `mode={mode}` behavior."
)
model_summary = ModelSummary(lightning_module, max_depth=max_depth)
else:
raise MisconfigurationException(
f"`mode` can be None, {', '.join(ModelSummaryMode.supported_types())}, got {mode}"
)
else:
model_summary = ModelSummary(lightning_module, max_depth=max_depth)
return model_summary
return ModelSummary(lightning_module, max_depth=max_depth)
10 changes: 0 additions & 10 deletions tests/deprecated_api/test_remove_1-6.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.model_summary import ModelSummary
from tests.helpers import BoringModel


Expand Down Expand Up @@ -69,15 +68,6 @@ def test_v1_6_0_is_overridden_model():
assert not is_overridden("foo", model=model)


def test_v1_6_0_deprecated_model_summary_mode(tmpdir):
model = BoringModel()
with pytest.deprecated_call(match="Argument `mode` in `ModelSummary` is deprecated in v1.4"):
ModelSummary(model, mode="top")

with pytest.deprecated_call(match="Argument `mode` in `LightningModule.summarize` is deprecated in v1.4"):
model.summarize(mode="top")


def test_v1_6_0_deprecated_disable_validation():
trainer = Trainer()
with pytest.deprecated_call(match="disable_validation` is deprecated in v1.4"):
Expand Down
21 changes: 0 additions & 21 deletions tests/utilities/test_model_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,17 +143,11 @@ def test_invalid_weights_summmary():
"""Test that invalid value for weights_summary raises an error."""
model = LightningModule()

with pytest.raises(MisconfigurationException, match="`mode` can be None, .* got temp"):
summarize(model, mode="temp")

with pytest.raises(
MisconfigurationException, match="`weights_summary` can be None, .* got temp"
), pytest.deprecated_call(match="weights_summary=temp)` is deprecated"):
Trainer(weights_summary="temp")

with pytest.raises(MisconfigurationException, match="mode` can be .* got temp"):
ModelSummary(model, mode="temp")

with pytest.raises(ValueError, match="max_depth` can be .* got temp"):
ModelSummary(model, max_depth="temp")

Expand Down Expand Up @@ -334,21 +328,6 @@ def test_lazy_model_summary():
assert summary.trainable_parameters == 7


def test_max_depth_equals_mode_interface():
"""Test summarize(model, full/top) interface mapping matches max_depth."""
model = DeepNestedModel()

with pytest.deprecated_call(match="mode` in `LightningModule.summarize` is deprecated"):
summary_top = summarize(model, mode="top")
summary_0 = summarize(model, max_depth=1)
assert str(summary_top) == str(summary_0)

with pytest.deprecated_call(match="mode` in `LightningModule.summarize` is deprecated"):
summary_full = summarize(model, mode="full")
summary_minus1 = summarize(model, max_depth=-1)
assert str(summary_full) == str(summary_minus1)


@pytest.mark.parametrize("max_depth", [-1, 0, 1, 3, 999])
def test_max_depth_param(max_depth):
"""Test that only the modules up to the desired depth are shown."""
Expand Down