Description
🚀 Feature
It would be helpful to control the depth of modules displayed in `ModelSummary` when working with deeply nested architectures, beyond the currently available "full" and "top" modes.
Motivation
When working with very deep architectures with nested modules, it's hard to get an overview of the model, because the current implementation prints a very long output with repeated blocks of layers when using `mode="full"`, and almost no information when using `mode="top"`.
Pitch
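I propose adding an optional `max_depth` parameter to `ModelSummary` that filters out summary entries whose depth is greater than `max_depth`. Here an entry's depth is the number of dots in its layer name, so top-level modules such as `encoder` have depth 0. The default value, `None`, would keep the current behavior.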
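To make the depth rule concrete, here is a minimal standalone sketch of the same filter applied to a plain dict. The function name `filter_by_depth` and the sample layer names are hypothetical and only stand in for the keys of a "full" summary. It keeps an entry when its dot count is at most `max_depth`, which is the same as dropping entries with depth > `max_depth`.

```python
from typing import Dict, Optional


def filter_by_depth(summary: Dict[str, str], max_depth: Optional[int]) -> Dict[str, str]:
    """Hypothetical helper: keep entries whose name has at most max_depth dots.

    Depth is the number of "." separators in the layer name, so top-level
    children such as "encoder" have depth 0.
    """
    if max_depth is None:
        # None keeps every entry, i.e. the current behavior
        return dict(summary)
    return {name: kind for name, kind in summary.items() if name.count(".") <= max_depth}


# Hypothetical layer names, shaped like the keys of a "full" summary
layers = {
    "encoder": "Sequential",
    "encoder.block1": "Sequential",
    "encoder.block1.conv": "Conv2d",
    "encoder.block1.bn": "BatchNorm2d",
    "decoder": "Linear",
}

print(list(filter_by_depth(layers, max_depth=1)))  # ['encoder', 'encoder.block1', 'decoder']
```

With `max_depth=1`, the two layers nested inside `encoder.block1` are hidden while the block itself is still listed, which gives the intermediate level of detail between "top" and "full".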
```python
from typing import Optional


# MODE_DEFAULT and summarize() come from the existing ModelSummary module
class ModelSummary(object):
    """*docs*"""

    def __init__(self, model, mode: str = MODE_DEFAULT, max_depth: Optional[int] = None):
        self._model = model
        self._mode = mode
        self._layer_summary = self.summarize()
        # (proposed max_depth feature):
        if max_depth is not None:
            # remove summary entries with depth > max_depth,
            # where depth is the number of "." in the layer name;
            # the keys are collected first so the dict is not mutated while iterating
            for k in [k for k in self._layer_summary.keys() if k.count(".") > max_depth]:
                del self._layer_summary[k]
```

The parameter would also be added to the `LightningModule.summarize()` method, exposing the functionality to all PL modules:
```python
# max_depth usage example
model.summarize(mode="full", max_depth=2)
```

Alternatives
- Should `max_depth` be applied globally to `self._layer_summary` (as pitched), or only inside the `__str__` method (thus preserving the full representation if needed)?
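For comparison, the second option can be sketched as follows, assuming a simplified summary that maps layer names to layer types. The class `DepthLimitedSummary` and its `layer_names` property are hypothetical and not part of Lightning. The full set of entries is stored unchanged, and the depth filter is applied only when the summary is rendered as a string.

```python
from typing import Dict, List, Optional


class DepthLimitedSummary:
    """Hypothetical summary that keeps every entry and filters only in __str__."""

    def __init__(self, layer_summary: Dict[str, str], max_depth: Optional[int] = None):
        # The full representation is stored unchanged
        self._layer_summary = dict(layer_summary)
        self._max_depth = max_depth

    @property
    def layer_names(self) -> List[str]:
        # Programmatic access still sees every layer, regardless of max_depth
        return list(self._layer_summary)

    def __str__(self) -> str:
        # Only the printed table respects max_depth
        rows = [
            f"{name} | {kind}"
            for name, kind in self._layer_summary.items()
            if self._max_depth is None or name.count(".") <= self._max_depth
        ]
        return "\n".join(rows)


summary = DepthLimitedSummary(
    {"encoder": "Sequential", "encoder.conv": "Conv2d", "head": "Linear"},
    max_depth=0,
)
print(summary)              # encoder | Sequential / head | Linear
print(summary.layer_names)  # ['encoder', 'encoder.conv', 'head']
```

The trade-off is that code reading the summary programmatically keeps access to every layer, but the printed table and the stored data no longer describe the same set of entries.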