Skip to content

Commit 603ac0f

Browse files
committed
Add rich for Model Summary
1 parent f080a31 commit 603ac0f

File tree

3 files changed

+76
-12
lines changed

3 files changed

+76
-12
lines changed

pytorch_lightning/callbacks/progress/rich_progress.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ def render(self, task) -> Text:
8080
class RichProgressBar(ProgressBarBase):
8181
def __init__(self, refresh_rate: int = 1):
8282
if not _RICH_AVAILABLE:
83-
raise MisconfigurationException("Rich progress bar is not available")
83+
raise MisconfigurationException(
84+
"`RichProgressBar` requires `rich` to be installed. Install it by running `pip install rich`."
85+
)
8486
super().__init__()
8587
self._refresh_rate = refresh_rate
8688
self._enabled = True

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import pytorch_lightning as pl
2727
from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
28-
from pytorch_lightning.callbacks import Callback
28+
from pytorch_lightning.callbacks import Callback, RichProgressBar
2929
from pytorch_lightning.core.datamodule import LightningDataModule
3030
from pytorch_lightning.loggers import LightningLoggerBase
3131
from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop
@@ -1029,8 +1029,9 @@ def _pre_training_routine(self):
10291029

10301030
# print model summary
10311031
if self.is_global_zero and self.weights_summary is not None and not self.testing:
1032+
use_rich = isinstance(self.progress_bar_callback, RichProgressBar)
10321033
max_depth = ModelSummary.MODES[self.weights_summary]
1033-
summarize(ref_model, max_depth=max_depth)
1034+
summarize(ref_model, max_depth=max_depth, use_rich=use_rich)
10341035

10351036
# on pretrain routine end
10361037
self.on_pretrain_routine_end()

pytorch_lightning/utilities/model_summary.py

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,13 @@
2525
import pytorch_lightning as pl
2626
from pytorch_lightning.utilities import AMPType, DeviceType, rank_zero_deprecation
2727
from pytorch_lightning.utilities.exceptions import MisconfigurationException
28-
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
28+
from pytorch_lightning.utilities.imports import _RICH_AVAILABLE, _TORCH_GREATER_EQUAL_1_8
2929
from pytorch_lightning.utilities.warnings import WarningCache
3030

31+
if _RICH_AVAILABLE:
32+
from rich.console import Console
33+
from rich.table import Table
34+
3135
log = logging.getLogger(__name__)
3236
warning_cache = WarningCache()
3337

@@ -299,12 +303,7 @@ def _forward_example_input(self) -> None:
299303
model(input_)
300304
model.train(mode) # restore mode of module
301305

302-
def __str__(self):
303-
"""
304-
Makes a summary listing with:
305-
306-
Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
307-
"""
306+
def _get_summary_data(self):
308307
arrays = [
309308
[" ", list(map(str, range(len(self._layer_summary))))],
310309
["Name", self.layer_names],
@@ -314,6 +313,62 @@ def __str__(self):
314313
if self._model.example_input_array is not None:
315314
arrays.append(["In sizes", self.in_sizes])
316315
arrays.append(["Out sizes", self.out_sizes])
316+
317+
return arrays
318+
319+
def print_rich_summary(self):
320+
321+
if not _RICH_AVAILABLE:
322+
raise MisconfigurationException(
323+
"`print_rich_summary` requires `rich` to be installed." " Install it by running `pip install rich`."
324+
)
325+
326+
arrays = self._get_summary_data()
327+
total_parameters = self.total_parameters
328+
trainable_parameters = self.trainable_parameters
329+
model_size = self.model_size
330+
331+
console = Console()
332+
333+
table = Table(title="Model Summary")
334+
335+
table.add_column(" ")
336+
table.add_column("Name", arrays[1][1], justify="left", style="cyan", no_wrap=True)
337+
table.add_column("Type", arrays[2][1], style="magenta")
338+
table.add_column("Params", arrays[3][1], justify="right", style="green")
339+
340+
rows = list(zip(*(arr[1] for arr in arrays)))
341+
for row in rows:
342+
table.add_row(*row)
343+
344+
console.print(table)
345+
346+
# Formatting
347+
s = "{:<{}}"
348+
349+
parameters = []
350+
for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]:
351+
parameters.append(s.format(get_human_readable_count(param), 10))
352+
353+
grid = Table.grid(expand=True)
354+
grid.add_column()
355+
grid.add_column()
356+
357+
grid.add_row(f"[bold]Trainable params[/]: {parameters[0]}")
358+
grid.add_row(f"[bold]Non-trainable params[/]: {parameters[1]}")
359+
grid.add_row(f"[bold]Total params[/]: {parameters[2]}")
360+
grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {parameters[3]}")
361+
362+
console.print(grid)
363+
364+
def __str__(self):
365+
"""
366+
Makes a summary listing with:
367+
368+
Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
369+
"""
370+
arrays = self._get_summary_data()
371+
317372
total_parameters = self.total_parameters
318373
trainable_parameters = self.trainable_parameters
319374
model_size = self.model_size
@@ -435,7 +490,10 @@ def _is_lazy_weight_tensor(p: Tensor) -> bool:
435490

436491

437492
def summarize(
438-
lightning_module: "pl.LightningModule", mode: Optional[str] = "top", max_depth: Optional[int] = None
493+
lightning_module: "pl.LightningModule",
494+
mode: Optional[str] = "top",
495+
max_depth: Optional[int] = None,
496+
use_rich: bool = False,
439497
) -> Optional[ModelSummary]:
440498
"""
441499
Summarize the LightningModule specified by `lightning_module`.
@@ -467,5 +525,8 @@ def summarize(
467525
raise MisconfigurationException(f"`mode` can be None, {', '.join(ModelSummary.MODES)}, got {mode}")
468526
else:
469527
model_summary = ModelSummary(lightning_module, max_depth=max_depth)
470-
log.info("\n" + str(model_summary))
528+
if use_rich:
529+
model_summary.print_rich_summary()
530+
else:
531+
log.info("\n" + str(model_summary))
471532
return model_summary

0 commit comments

Comments
 (0)