Skip to content

Commit 062800a

Browse files
authored
Fix invalid value for weights_summary (#5296)
* Fix weights_summary * use mode * fix * optional * what was I thinking
1 parent 371daea commit 062800a

File tree

4 files changed

+43
-15
lines changed

4 files changed

+43
-15
lines changed

pytorch_lightning/core/lightning.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,15 @@
1414

1515
"""nn.Module with additional great features."""
1616

17-
from abc import ABC
18-
from argparse import Namespace
1917
import collections
2018
import copy
2119
import inspect
2220
import os
23-
from pathlib import Path
2421
import re
2522
import tempfile
23+
from abc import ABC
24+
from argparse import Namespace
25+
from pathlib import Path
2626
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
2727

2828
import torch
@@ -1327,9 +1327,17 @@ def tbptt_split_batch(self, batch, split_size):
13271327

13281328
return splits
13291329

1330-
def summarize(self, mode: str = ModelSummary.MODE_DEFAULT) -> ModelSummary:
1331-
model_summary = ModelSummary(self, mode=mode)
1332-
log.info("\n" + str(model_summary))
1330+
def summarize(self, mode: Optional[str] = ModelSummary.MODE_DEFAULT) -> Optional[ModelSummary]:
1331+
model_summary = None
1332+
1333+
if mode in ModelSummary.MODES:
1334+
model_summary = ModelSummary(self, mode=mode)
1335+
log.info("\n" + str(model_summary))
1336+
elif mode is not None:
1337+
raise MisconfigurationException(
1338+
f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}"
1339+
)
1340+
13331341
return model_summary
13341342

13351343
def freeze(self) -> None:

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,6 @@ def __init__(
311311
self.plugin_connector = PluginConnector(self)
312312

313313
# training state
314-
self.weights_summary = weights_summary
315314
self.model = None
316315
self.shown_warnings = set()
317316

@@ -374,7 +373,8 @@ def __init__(
374373
max_steps,
375374
min_steps,
376375
num_sanity_val_steps,
377-
automatic_optimization
376+
automatic_optimization,
377+
weights_summary,
378378
)
379379
self.evaluation_loop.on_trainer_init()
380380

pytorch_lightning/trainer/training_loop.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,14 @@ def __init__(self, trainer):
4949
self._cur_grad_norm_dict = None
5050

5151
def on_trainer_init(
52-
self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps, automatic_optimization
52+
self,
53+
max_epochs,
54+
min_epochs,
55+
max_steps,
56+
min_steps,
57+
num_sanity_val_steps,
58+
automatic_optimization,
59+
weights_summary,
5360
):
5461
self.trainer.global_step = 0
5562
self.trainer.current_epoch = 0
@@ -73,6 +80,12 @@ def on_trainer_init(
7380
else:
7481
self.trainer.num_sanity_val_steps = num_sanity_val_steps
7582

83+
self.trainer.weights_summary = weights_summary
84+
if weights_summary is not None and weights_summary not in ModelSummary.MODES:
85+
raise MisconfigurationException(
86+
f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, got {weights_summary}"
87+
)
88+
7689
@property
7790
def num_optimizers(self):
7891
num_optimizers = len(self.get_optimizers_iterable())
@@ -161,11 +174,8 @@ def setup_training(self, model: LightningModule):
161174
ref_model.on_pretrain_routine_start()
162175

163176
# print model summary
164-
if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing:
165-
if self.trainer.weights_summary in ModelSummary.MODES:
166-
ref_model.summarize(mode=self.trainer.weights_summary)
167-
else:
168-
raise MisconfigurationException("weights_summary can be None, " + ", ".join(ModelSummary.MODES))
177+
if self.trainer.is_global_zero and not self.trainer.testing:
178+
ref_model.summarize(mode=self.trainer.weights_summary)
169179

170180
# track model now.
171181
# if cluster resets state, the model will update with the saved weights

tests/core/test_memory.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
import torch
1616
import torch.nn as nn
1717

18-
from pytorch_lightning import LightningModule
18+
from pytorch_lightning import LightningModule, Trainer
1919
from pytorch_lightning.core.memory import UNKNOWN_SIZE, ModelSummary
20+
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2021
from tests.base.models import ParityModuleRNN
2122

2223

@@ -68,6 +69,15 @@ def forward(self, x):
6869
return self.reduce(self.embed(x))
6970

7071

72+
def test_invalid_weights_summmary():
73+
""" Test that invalid value for weights_summary raises an error. """
74+
with pytest.raises(MisconfigurationException, match='`mode` can be None, .* got temp'):
75+
UnorderedModel().summarize(mode='temp')
76+
77+
with pytest.raises(MisconfigurationException, match='`weights_summary` can be None, .* got temp'):
78+
Trainer(weights_summary='temp')
79+
80+
7181
@pytest.mark.parametrize(['mode'], [
7282
pytest.param(ModelSummary.MODE_FULL),
7383
pytest.param(ModelSummary.MODE_TOP),

0 commit comments

Comments
 (0)