diff --git a/CHANGELOG.md b/CHANGELOG.md
index f029aeadb6d62..4e8347caa7895 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Add support for summarized model total params size in megabytes ([#5590](https://github.com/PyTorchLightning/pytorch-lightning/pull/5590))
+
 - Add Support for multiple train loaders ([#1959](https://github.com/PyTorchLightning/pytorch-lightning/pull/1959))
 
 
diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py
index 44c06dfe0f58d..4c1710cd36de0 100644
--- a/pytorch_lightning/core/memory.py
+++ b/pytorch_lightning/core/memory.py
@@ -159,6 +159,7 @@ class ModelSummary(object):
         132 K     Trainable params
         0         Non-trainable params
         132 K     Total params
+        0.530     Total estimated model params size (MB)
         >>> ModelSummary(model, mode='full')  # doctest: +NORMALIZE_WHITESPACE
           | Name  | Type        | Params | In sizes  | Out sizes
         --------------------------------------------------------------
@@ -169,6 +170,7 @@ class ModelSummary(object):
         132 K     Trainable params
         0         Non-trainable params
         132 K     Total params
+        0.530     Total estimated model params size (MB)
     """
 
     MODE_TOP = "top"
@@ -180,6 +182,7 @@ def __init__(self, model, mode: str = MODE_DEFAULT):
         self._model = model
         self._mode = mode
         self._layer_summary = self.summarize()
+        self._precision_megabytes = (self._model.precision / 8.0) * 1e-6  # 1 byte -> 8 bits
 
     @property
     def named_modules(self) -> List[Tuple[str, nn.Module]]:
@@ -213,6 +216,18 @@ def out_sizes(self) -> List:
     def param_nums(self) -> List[int]:
         return [layer.num_parameters for layer in self._layer_summary.values()]
 
+    @property
+    def total_parameters(self) -> int:
+        return sum(p.numel() for p in self._model.parameters())
+
+    @property
+    def trainable_parameters(self) -> int:
+        return sum(p.numel() for p in self._model.parameters() if p.requires_grad)
+
+    @property
+    def model_size(self) -> float:
+        return self.total_parameters * self._precision_megabytes
+
     def summarize(self) -> Dict[str, LayerSummary]:
         summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
         if self._model.example_input_array is not None:
@@ -248,7 +263,7 @@ def __str__(self):
         """
         Makes a summary listing with:
 
-        Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes
+        Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
         """
         arrays = [
             [" ", list(map(str, range(len(self._layer_summary))))],
@@ -259,11 +274,11 @@ def __str__(self):
         if self._model.example_input_array is not None:
             arrays.append(["In sizes", self.in_sizes])
             arrays.append(["Out sizes", self.out_sizes])
+        total_parameters = self.total_parameters
+        trainable_parameters = self.trainable_parameters
+        model_size = self.model_size
 
-        trainable_parameters = sum(p.numel() for p in self._model.parameters() if p.requires_grad)
-        total_parameters = sum(p.numel() for p in self._model.parameters())
-
-        return _format_summary_table(total_parameters, trainable_parameters, *arrays)
+        return _format_summary_table(total_parameters, trainable_parameters, model_size, *arrays)
 
     def __repr__(self):
         return str(self)
@@ -280,7 +295,7 @@ def parse_batch_shape(batch: Any) -> Union[str, List]:
     return UNKNOWN_SIZE
 
 
-def _format_summary_table(total_parameters: int, trainable_parameters: int, *cols) -> str:
+def _format_summary_table(total_parameters: int, trainable_parameters: int, model_size: float, *cols) -> str:
     """
     Takes in a number of arrays, each specifying a column in the summary
     table, and combines them all into one big
@@ -316,6 +331,8 @@ def _format_summary_table(total_parameters: int, trainable_parameters: int, *col
     summary += "Non-trainable params"
     summary += "\n" + s.format(get_human_readable_count(total_parameters), 10)
     summary += "Total params"
+    summary += "\n" + s.format(get_formatted_model_size(model_size), 10)
+    summary += "Total estimated model params size (MB)"
 
     return summary
 
@@ -372,6 +389,8 @@ def get_gpu_memory_map() -> Dict[str, int]:
     }
     return gpu_memory_map
 
+def get_formatted_model_size(total_model_size: float) -> str:
+    return f"{total_model_size:,.3f}"
 
 def get_human_readable_count(number: int) -> str:
     """
diff --git a/tests/core/test_memory.py b/tests/core/test_memory.py
index f5f22c7a47bc2..699b248013020 100644
--- a/tests/core/test_memory.py
+++ b/tests/core/test_memory.py
@@ -17,7 +17,9 @@
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.core.memory import ModelSummary, UNKNOWN_SIZE
+from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.base import BoringModel
 from tests.base.models import ParityModuleRNN
 
 
@@ -33,6 +35,21 @@ def forward(self, *args, **kwargs):
         return {'loss': self.parameter.sum()}
 
 
+class PreCalculatedModel(BoringModel):
+    """ A model with precalculated total params size in MB for FP16 and FP32. """
+
+    def __init__(self, precision: int = 32):
+        super().__init__()
+        self.layer = nn.Linear(32, 1000, bias=False)  # 32K params
+        self.layer1 = nn.Linear(1000, 218, bias=False)  # 218K params
+
+        # calculate model size based on precision.
+        self.pre_calculated_model_size = 1.0 / (32 / precision)
+
+    def forward(self, x):
+        x = self.layer(x)
+        return self.layer1(x)
+
 class UnorderedModel(LightningModule):
     """ A model in which the layers not defined in order of execution """
 
@@ -247,3 +264,45 @@ def forward(self, *args, **kwargs):
     model.example_input_array = example_input
     summary = model.summarize(mode=mode)
     assert summary.in_sizes == [expected_size]
+
+
+@pytest.mark.parametrize(['mode'], [
+    pytest.param(ModelSummary.MODE_FULL),
+    pytest.param(ModelSummary.MODE_TOP),
+])
+def test_model_size(mode):
+    """ Test model size is calculated correctly. """
+    model = PreCalculatedModel()
+    summary = model.summarize(mode=mode)
+    assert model.pre_calculated_model_size == summary.model_size
+
+
+@pytest.mark.parametrize(['mode'], [
+    pytest.param(ModelSummary.MODE_FULL),
+    pytest.param(ModelSummary.MODE_TOP),
+])
+def test_empty_model_size(mode):
+    """ Test empty model size is zero. """
+    model = EmptyModule()
+    summary = model.summarize(mode=mode)
+    assert 0.0 == summary.model_size
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
+@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
+@pytest.mark.parametrize('precision', [16, 32])
+def test_model_size_precision(monkeypatch, tmpdir, precision):
+    """ Test model size for half and full precision. """
+    model = PreCalculatedModel(precision)
+
+    # fit model
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        gpus=1,
+        max_steps=1,
+        max_epochs=1,
+        precision=precision,
+    )
+    trainer.fit(model)
+    summary = model.summarize()
+    assert model.pre_calculated_model_size == summary.model_size
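
Note (reviewer sketch, not part of the diff): the size estimate reported by the new ModelSummary.model_size property is just parameter count times the per-parameter byte size, using the (precision / 8.0) * 1e-6 conversion added in ModelSummary.__init__. The short Python sketch below reproduces that arithmetic by hand for the PreCalculatedModel fixture defined in the test file above (FP32 case); the variable names are illustrative only.

# Illustrative sketch: recompute ModelSummary.model_size for PreCalculatedModel
# from tests/core/test_memory.py, assuming precision=32 (FP32).
total_params = 32 * 1000 + 1000 * 218      # 250,000 params across the two bias-free Linear layers
precision_megabytes = (32 / 8.0) * 1e-6    # 32 bits -> 4 bytes per param -> megabytes
estimated_size_mb = total_params * precision_megabytes
assert estimated_size_mb == 1.0            # matches pre_calculated_model_size for precision=32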