@@ -159,6 +159,7 @@ class ModelSummary(object):
         132 K     Trainable params
         0         Non-trainable params
         132 K     Total params
+        0.530     Total estimated model params size (MB)
         >>> ModelSummary(model, mode='full')  # doctest: +NORMALIZE_WHITESPACE
           | Name | Type | Params | In sizes  | Out sizes
         --------------------------------------------------------------
@@ -169,6 +170,7 @@ class ModelSummary(object):
         132 K     Trainable params
         0         Non-trainable params
         132 K     Total params
+        0.530     Total estimated model params size (MB)
     """

     MODE_TOP = "top"
@@ -180,6 +182,7 @@ def __init__(self, model, mode: str = MODE_DEFAULT):
         self._model = model
         self._mode = mode
         self._layer_summary = self.summarize()
+        self._precision_megabytes = (self._model.precision / 8.0) * 1e-6  # 1 byte -> 8 bits

     @property
     def named_modules(self) -> List[Tuple[str, nn.Module]]:
@@ -213,6 +216,18 @@ def out_sizes(self) -> List:
     def param_nums(self) -> List[int]:
         return [layer.num_parameters for layer in self._layer_summary.values()]

+    @property
+    def total_parameters(self) -> int:
+        return sum(p.numel() for p in self._model.parameters())
+
+    @property
+    def trainable_parameters(self) -> int:
+        return sum(p.numel() for p in self._model.parameters() if p.requires_grad)
+
+    @property
+    def model_size(self) -> float:
+        return self.total_parameters * self._precision_megabytes
+
     def summarize(self) -> Dict[str, LayerSummary]:
         summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
         if self._model.example_input_array is not None:
@@ -248,7 +263,7 @@ def __str__(self):
         """
         Makes a summary listing with:

-        Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes
+        Layer Name, Layer Type, Number of Parameters, Input Sizes, Output Sizes, Model Size
         """
         arrays = [
             [" ", list(map(str, range(len(self._layer_summary))))],
@@ -259,11 +274,11 @@ def __str__(self):
         if self._model.example_input_array is not None:
             arrays.append(["In sizes", self.in_sizes])
             arrays.append(["Out sizes", self.out_sizes])
+        total_parameters = self.total_parameters
+        trainable_parameters = self.trainable_parameters
+        model_size = self.model_size

-        trainable_parameters = sum(p.numel() for p in self._model.parameters() if p.requires_grad)
-        total_parameters = sum(p.numel() for p in self._model.parameters())
-
-        return _format_summary_table(total_parameters, trainable_parameters, *arrays)
+        return _format_summary_table(total_parameters, trainable_parameters, model_size, *arrays)

     def __repr__(self):
         return str(self)
@@ -280,7 +295,7 @@ def parse_batch_shape(batch: Any) -> Union[str, List]:
     return UNKNOWN_SIZE


-def _format_summary_table(total_parameters: int, trainable_parameters: int, *cols) -> str:
+def _format_summary_table(total_parameters: int, trainable_parameters: int, model_size: float, *cols) -> str:
     """
     Takes in a number of arrays, each specifying a column in
     the summary table, and combines them all into one big
@@ -316,6 +331,8 @@ def _format_summary_table(total_parameters: int, trainable_parameters: int, *cols) -> str:
     summary += "Non-trainable params"
     summary += "\n" + s.format(get_human_readable_count(total_parameters), 10)
     summary += "Total params"
+    summary += "\n" + s.format(get_formatted_model_size(model_size), 10)
+    summary += "Total estimated model params size (MB)"

     return summary

@@ -372,6 +389,8 @@ def get_gpu_memory_map() -> Dict[str, int]:
     }
     return gpu_memory_map

+def get_formatted_model_size(total_model_size: float) -> str:
+    return f"{total_model_size:,.3f}"

 def get_human_readable_count(number: int) -> str:
     """
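The estimate added by this diff reduces to: total parameter count × (precision in bits / 8) bytes × 1e-6 MB per byte. Below is a minimal standalone sketch of that arithmetic, independent of Lightning; the estimated_params_size_mb helper and the example nn.Sequential module are illustrative assumptions, chosen so the result reproduces the "132 K" / "0.530" figures shown in the docstring above.

import torch.nn as nn


def estimated_params_size_mb(model: nn.Module, precision: int = 32) -> float:
    # Same arithmetic as the new ModelSummary.model_size property:
    # numel * (precision bits -> bytes) * 1e-6 (bytes -> megabytes).
    total_parameters = sum(p.numel() for p in model.parameters())
    precision_megabytes = (precision / 8.0) * 1e-6
    return total_parameters * precision_megabytes


# Hypothetical example module: 131,584 (Linear) + 1,024 (BatchNorm1d) = 132,608 params ("132 K").
net = nn.Sequential(nn.Linear(256, 512), nn.BatchNorm1d(512))
print(f"{estimated_params_size_mb(net):,.3f}")  # 0.530 (MB at 32-bit precision)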