2525import pytorch_lightning as pl
2626from pytorch_lightning .utilities import AMPType , DeviceType , rank_zero_deprecation
2727from pytorch_lightning .utilities .exceptions import MisconfigurationException
28- from pytorch_lightning .utilities .imports import _TORCH_GREATER_EQUAL_1_8
28+ from pytorch_lightning .utilities .imports import _RICH_AVAILABLE , _TORCH_GREATER_EQUAL_1_8
2929from pytorch_lightning .utilities .warnings import WarningCache
3030
31+ if _RICH_AVAILABLE :
32+ from rich .console import Console
33+ from rich .table import Table
34+
3135log = logging .getLogger (__name__ )
3236warning_cache = WarningCache ()
3337
@@ -299,12 +303,7 @@ def _forward_example_input(self) -> None:
299303 model (input_ )
300304 model .train (mode ) # restore mode of module
301305
302- def __str__ (self ):
303- """
304- Makes a summary listing with:
305-
306- Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
307- """
306+ def _get_summary_data (self ):
308307 arrays = [
309308 [" " , list (map (str , range (len (self ._layer_summary ))))],
310309 ["Name" , self .layer_names ],
@@ -314,6 +313,62 @@ def __str__(self):
314313 if self ._model .example_input_array is not None :
315314 arrays .append (["In sizes" , self .in_sizes ])
316315 arrays .append (["Out sizes" , self .out_sizes ])
316+
317+ return arrays
318+
319+ def print_rich_summary (self ):
320+
321+ if not _RICH_AVAILABLE :
322+ raise MisconfigurationException (
323+ "`print_rich_summary` requires `rich` to be installed." " Install it by running `pip install rich`."
324+ )
325+
326+ arrays = self ._get_summary_data ()
327+ total_parameters = self .total_parameters
328+ trainable_parameters = self .trainable_parameters
329+ model_size = self .model_size
330+
331+ console = Console ()
332+
333+ table = Table (title = "Model Summary" )
334+
335+ table .add_column (" " )
336+ table .add_column ("Name" , arrays [1 ][1 ], justify = "left" , style = "cyan" , no_wrap = True )
337+ table .add_column ("Type" , arrays [2 ][1 ], style = "magenta" )
338+ table .add_column ("Params" , arrays [3 ][1 ], justify = "right" , style = "green" )
339+
340+ rows = list (zip (* (arr [1 ] for arr in arrays )))
341+ for row in rows :
342+ table .add_row (* row )
343+
344+ console .print (table )
345+
346+ # Formatting
347+ s = "{:<{}}"
348+
349+ parameters = []
350+ for param in [trainable_parameters , total_parameters - trainable_parameters , total_parameters , model_size ]:
351+ parameters .append (s .format (get_human_readable_count (param ), 10 ))
352+
353+ grid = Table .grid (expand = True )
354+ grid .add_column ()
355+ grid .add_column ()
356+
357+ grid .add_row (f"[bold]Trainable params[/]: { parameters [0 ]} " )
358+ grid .add_row (f"[bold]Non-trainable params[/]: { parameters [1 ]} " )
359+ grid .add_row (f"[bold]Total params[/]: { parameters [2 ]} " )
360+ grid .add_row (f"[bold]Total estimated model params size (MB)[/]: { parameters [3 ]} " )
361+
362+ console .print (grid )
363+
364+ def __str__ (self ):
365+ """
366+ Makes a summary listing with:
367+
368+ Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
369+ """
370+ arrays = self ._get_summary_data ()
371+
317372 total_parameters = self .total_parameters
318373 trainable_parameters = self .trainable_parameters
319374 model_size = self .model_size
@@ -435,7 +490,10 @@ def _is_lazy_weight_tensor(p: Tensor) -> bool:
435490
436491
437492def summarize (
438- lightning_module : "pl.LightningModule" , mode : Optional [str ] = "top" , max_depth : Optional [int ] = None
493+ lightning_module : "pl.LightningModule" ,
494+ mode : Optional [str ] = "top" ,
495+ max_depth : Optional [int ] = None ,
496+ use_rich : bool = False ,
439497) -> Optional [ModelSummary ]:
440498 """
441499 Summarize the LightningModule specified by `lightning_module`.
@@ -467,5 +525,8 @@ def summarize(
467525 raise MisconfigurationException (f"`mode` can be None, { ', ' .join (ModelSummary .MODES )} , got { mode } " )
468526 else :
469527 model_summary = ModelSummary (lightning_module , max_depth = max_depth )
470- log .info ("\n " + str (model_summary ))
528+ if use_rich :
529+ model_summary .print_rich_summary ()
530+ else :
531+ log .info ("\n " + str (model_summary ))
471532 return model_summary
0 commit comments