Skip to content

Commit a479cfa

Browse files
committed
handle max_depth validation at init
update test
1 parent 14d4d33 commit a479cfa

File tree

2 files changed

+6
-8
lines changed

2 files changed

+6
-8
lines changed

pytorch_lightning/core/memory.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,9 @@ def __init__(self, model, mode: Optional[str] = None, max_depth: Optional[int] =
206206
from pytorch_lightning.utilities.exceptions import MisconfigurationException
207207
raise MisconfigurationException(f"`mode` can be {', '.join(ModelSummary.MODES)}, got {mode}.")
208208

209+
if not isinstance(max_depth, int) or max_depth < -1:
210+
raise ValueError(f"`max_depth` can be -1, 0 or > 0, got {max_depth}.")
211+
209212
self._max_depth = max_depth
210213
self._layer_summary = self.summarize()
211214
# 1 byte -> 8 bits
@@ -220,14 +223,9 @@ def named_modules(self) -> List[Tuple[str, nn.Module]]:
220223
elif self._max_depth == 1:
221224
# the children are the top-level modules
222225
mods = self._model.named_children()
223-
elif self._max_depth == -1 or self._max_depth > 1:
226+
else:
224227
mods = self._model.named_modules()
225228
mods = list(mods)[1:] # do not include root module (LightningModule)
226-
else:
227-
raise ValueError(
228-
f"Invalid value for max_depth encountered. "
229-
f"Expected -1, 0 or >0, but got {self._max_depth}."
230-
)
231229
return list(mods)
232230

233231
@property

tests/core/test_memory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def test_max_depth_param(max_depth):
384384
assert lname.count(".") < max_depth
385385

386386

387-
@pytest.mark.parametrize('max_depth', [-99, -2])
387+
@pytest.mark.parametrize('max_depth', [-99, -2, "invalid"])
388388
def test_raise_invalid_max_depth_value(max_depth):
389-
with pytest.raises(ValueError, match="Invalid value for max_depth encountered"):
389+
with pytest.raises(ValueError, match=f"`max_depth` can be -1, 0 or > 0, got {max_depth}"):
390390
DeepNestedModel().summarize(max_depth=max_depth)

0 commit comments

Comments
 (0)