From 992460a26538bf9ab45468e0209a6b037abda4e4 Mon Sep 17 00:00:00 2001 From: Yuchen Jin <32186723+cainmagi@users.noreply.github.com> Date: Sat, 27 Feb 2021 01:21:15 -0600 Subject: [PATCH 1/6] Fix the multi-output problem. 1. Fix the bug of parameter number calculation when there are more than one output variables, including both sequence case and dict case. 2. Make multuple output variables split into multiple lines. 3. Remove the last line break of summary_string() 4. Enable argument "device" to accept both str and torch.device. 5. Fix a bug when the model requires "batch_size" to be a specific number. 6. Fix a bug caused by multiple input case when "dtypes=None". 7. Add text auto wrap when the layer name is too long. 8. Add docstring. --- torchsummary/torchsummary.py | 136 +++++++++++++++++++++++++---------- 1 file changed, 100 insertions(+), 36 deletions(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 1ed065f..829af3b 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -1,12 +1,29 @@ +import collections +import numpy as np + import torch import torch.nn as nn -from torch.autograd import Variable - -from collections import OrderedDict -import numpy as np -def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None): +def summary(model, input_size, batch_size=-1, device='cuda:0', dtypes=None): + '''Keras-style torch summary + Iterate the whole pytorch model and summarize the infomation as a Keras-style + text report. The output would be store in a str. + Arguments: + model: an instance of nn.Module + input_size: a sequence (list/tuple) or a sequence of sequnces, indicating + the size of the each model input variable. + batch_size: a int. The batch size used for testing and displaying the + results. + device: a str or torch.device. Should be set according to the deployed + device of the argument "model". + dtype: a list or torch data type for each input variable. + Returns: + 1. tensor, total parameter numbers. + 2. tensor, trainable parameter numbers. + ''' + if isinstance(device, str): + device = torch.device(device) result, params_info = summary_string( model, input_size, batch_size, device, dtypes) print(result) @@ -14,9 +31,25 @@ def summary(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dty return params_info -def summary_string(model, input_size, batch_size=-1, device=torch.device('cuda:0'), dtypes=None): - if dtypes == None: - dtypes = [torch.FloatTensor]*len(input_size) +def summary_string(model, input_size, batch_size=-1, device='cuda:0', dtypes=None): + '''Keras-style torch summary (string output) + Iterate the whole pytorch model and summarize the infomation as a Keras-style + text report. The output would be store in a str. + Arguments: + model: an instance of nn.Module + input_size: a sequence (list/tuple) or a sequence of sequnces, indicating + the size of the each model input variable. + batch_size: a int. The batch size used for testing and displaying the + results. + device: a str or torch.device. Should be set according to the deployed + device of the argument "model". + dtype: a list or torch data type for each input variable. + Returns: + 1. str, the summary text report. + 2. tuple, (total parameter numbers, trainable parameter numbers) + ''' + if isinstance(device, str): + device = torch.device(device) summary_str = '' @@ -26,10 +59,14 @@ def hook(module, input, output): module_idx = len(summary) m_key = "%s-%i" % (class_name, module_idx + 1) - summary[m_key] = OrderedDict() + summary[m_key] = collections.OrderedDict() summary[m_key]["input_shape"] = list(input[0].size()) summary[m_key]["input_shape"][0] = batch_size - if isinstance(output, (list, tuple)): + if isinstance(output, dict): + summary[m_key]["output_shape"] = [ + [-1] + list(o.size())[1:] for o in output.values() + ] + elif isinstance(output, (list, tuple)): summary[m_key]["output_shape"] = [ [-1] + list(o.size())[1:] for o in output ] @@ -45,22 +82,31 @@ def hook(module, input, output): params += torch.prod(torch.LongTensor(list(module.bias.size()))) summary[m_key]["nb_params"] = params - if ( - not isinstance(module, nn.Sequential) - and not isinstance(module, nn.ModuleList) - ): + if (not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList)): hooks.append(module.register_forward_hook(hook)) # multiple inputs to the network - if isinstance(input_size, tuple): - input_size = [input_size] + if isinstance(input_size, (list, tuple)) and len(input_size) > 0: + if not isinstance(input_size[0], (list, tuple)): + input_size = (input_size, ) + else: + raise ValueError('The argument "input_size" is not a tuple of a sequence of tuple. Given "{0}".'.format(input_size)) + + if dtypes is None: + dtypes = [torch.FloatTensor] * len(input_size) + if len(dtypes) != len(input_size): + raise ValueError('The lengths of the arguments "input_size" and "dtypes" does not correspond to each other.') # batch_size of 2 for batchnorm - x = [torch.rand(2, *in_size).type(dtype).to(device=device) + if batch_size == -1: + batch_size_ = 2 + else: + batch_size_ = batch_size + x = [torch.rand(batch_size_, *in_size).type(dtype).to(device=device) for in_size, dtype in zip(input_size, dtypes)] # create properties - summary = OrderedDict() + summary = collections.OrderedDict() hooks = [] # register hook @@ -84,37 +130,55 @@ def hook(module, input, output): trainable_params = 0 for layer in summary: # input_shape, output_shape, trainable, nb_params - line_new = "{:>20} {:>25} {:>15}".format( - layer, - str(summary[layer]["output_shape"]), - "{0:,}".format(summary[layer]["nb_params"]), - ) - total_params += summary[layer]["nb_params"] - - total_output += np.prod(summary[layer]["output_shape"]) - if "trainable" in summary[layer]: - if summary[layer]["trainable"] == True: - trainable_params += summary[layer]["nb_params"] + sum_layer = summary[layer] + if len(layer) > 20: + layer_disp = '{lhead}...{ltail}'.format(lhead=layer[:8], ltail=layer[-9:]) # 20 = 9 + 8 + 3 + else: + layer_disp = layer + if len(sum_layer["output_shape"]) > 0 and isinstance(sum_layer["output_shape"][0], (list, tuple)): # Add multiple output support + line_new = ["{:>20} {:>25} {:>15}".format( + layer_disp, + str(sum_layer["output_shape"][0]), + "{0:,}".format(sum_layer["nb_params"]), + )] + for oshape in sum_layer["output_shape"][1:]: + line_new.append("{:>20} {:>25} {:>15}".format( + '', str(oshape), '' + )) + line_new = '\n'.join(line_new) + else: + line_new = "{:>20} {:>25} {:>15}".format( + layer_disp, + str(sum_layer["output_shape"]), + "{0:,}".format(sum_layer["nb_params"]), + ) + total_params += sum_layer["nb_params"] + + output_shape = sum_layer["output_shape"] + if isinstance(output_shape[0], (list, tuple)): + total_output += np.sum(list(map(np.prod, output_shape)), dtype=np.int) + else: + total_output += np.prod(output_shape, dtype=np.int) + if "trainable" in sum_layer: + if sum_layer["trainable"] is True: + trainable_params += sum_layer["nb_params"] summary_str += line_new + "\n" # assume 4 bytes/number (float on cuda). - total_input_size = abs(np.prod(sum(input_size, ())) - * batch_size * 4. / (1024 ** 2.)) - total_output_size = abs(2. * total_output * 4. / - (1024 ** 2.)) # x2 for gradients + total_input_size = abs(np.sum(list(map(np.prod, input_size))) * batch_size * 4. / (1024 ** 2.)) + total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients total_params_size = abs(total_params * 4. / (1024 ** 2.)) total_size = total_params_size + total_output_size + total_input_size summary_str += "================================================================" + "\n" summary_str += "Total params: {0:,}".format(total_params) + "\n" summary_str += "Trainable params: {0:,}".format(trainable_params) + "\n" - summary_str += "Non-trainable params: {0:,}".format(total_params - - trainable_params) + "\n" + summary_str += "Non-trainable params: {0:,}".format(total_params - trainable_params) + "\n" summary_str += "----------------------------------------------------------------" + "\n" summary_str += "Input size (MB): %0.2f" % total_input_size + "\n" summary_str += "Forward/backward pass size (MB): %0.2f" % total_output_size + "\n" summary_str += "Params size (MB): %0.2f" % total_params_size + "\n" summary_str += "Estimated Total Size (MB): %0.2f" % total_size + "\n" - summary_str += "----------------------------------------------------------------" + "\n" + summary_str += "----------------------------------------------------------------" # return summary return summary_str, (total_params, trainable_params) From 18bf2109eed6217b0f032f0e3b72539a203be63c Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Sat, 27 Feb 2021 21:24:03 -0600 Subject: [PATCH 2/6] Fix parameter counting problem. Support counting all parameters instead of `weight` and `bias`. --- torchsummary/torchsummary.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 829af3b..7b05233 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -58,29 +58,31 @@ def hook(module, input, output): class_name = str(module.__class__).split(".")[-1].split("'")[0] module_idx = len(summary) - m_key = "%s-%i" % (class_name, module_idx + 1) - summary[m_key] = collections.OrderedDict() - summary[m_key]["input_shape"] = list(input[0].size()) - summary[m_key]["input_shape"][0] = batch_size + m_key = '{name:s}-{idx:d}'.format(name=class_name, idx=module_idx + 1) + sum_layer = collections.OrderedDict() + summary[m_key] = sum_layer + sum_layer["input_shape"] = list(input[0].size()) + sum_layer["input_shape"][0] = batch_size if isinstance(output, dict): - summary[m_key]["output_shape"] = [ + sum_layer["output_shape"] = [ [-1] + list(o.size())[1:] for o in output.values() ] elif isinstance(output, (list, tuple)): - summary[m_key]["output_shape"] = [ + sum_layer["output_shape"] = [ [-1] + list(o.size())[1:] for o in output ] else: - summary[m_key]["output_shape"] = list(output.size()) - summary[m_key]["output_shape"][0] = batch_size + sum_layer["output_shape"] = list(output.size()) + sum_layer["output_shape"][0] = batch_size params = 0 - if hasattr(module, "weight") and hasattr(module.weight, "size"): - params += torch.prod(torch.LongTensor(list(module.weight.size()))) - summary[m_key]["trainable"] = module.weight.requires_grad - if hasattr(module, "bias") and hasattr(module.bias, "size"): - params += torch.prod(torch.LongTensor(list(module.bias.size()))) - summary[m_key]["nb_params"] = params + params_trainable = 0 + for param in module.parameters(recurse=False): + nb_param = torch.prod(torch.LongTensor(list(param.size()))) + params += nb_param + params_trainable += nb_param if param.requires_grad else 0 + sum_layer["nb_params"] = params + sum_layer["nb_params_trainable"] = params_trainable if (not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList)): hooks.append(module.register_forward_hook(hook)) @@ -159,9 +161,7 @@ def hook(module, input, output): total_output += np.sum(list(map(np.prod, output_shape)), dtype=np.int) else: total_output += np.prod(output_shape, dtype=np.int) - if "trainable" in sum_layer: - if sum_layer["trainable"] is True: - trainable_params += sum_layer["nb_params"] + trainable_params += sum_layer["nb_params_trainable"] summary_str += line_new + "\n" # assume 4 bytes/number (float on cuda). From c8836d5e9f0437af87d0f6b1126eeab5a44f023d Mon Sep 17 00:00:00 2001 From: Yuchen Jin <32186723+cainmagi@users.noreply.github.com> Date: Sat, 27 Feb 2021 23:20:44 -0600 Subject: [PATCH 3/6] Fix the long int overflow problem. Using numpy sum/prod to calculate the total size may cause overflow problem. This modification would drop the numpy and use the python built-in method to calculate the size. --- torchsummary/torchsummary.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 7b05233..7b65d7f 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -1,10 +1,22 @@ +import functools import collections -import numpy as np import torch import torch.nn as nn +def long_sum(v): + if not all(map(lambda x: isinstance(x, int), v)): + raise ValueError('The long_sum only supports the sequence with all int elements.') + return functools.reduce(lambda x, y: x + y, v) + + +def long_prod(v): + if not all(map(lambda x: isinstance(x, int), v)): + raise ValueError('The long_sum only supports the sequence with all int elements.') + return functools.reduce(lambda x, y: x * y, v) + + def summary(model, input_size, batch_size=-1, device='cuda:0', dtypes=None): '''Keras-style torch summary Iterate the whole pytorch model and summarize the infomation as a Keras-style @@ -158,14 +170,14 @@ def hook(module, input, output): output_shape = sum_layer["output_shape"] if isinstance(output_shape[0], (list, tuple)): - total_output += np.sum(list(map(np.prod, output_shape)), dtype=np.int) + total_output += long_sum(list(map(long_prod, output_shape))) else: - total_output += np.prod(output_shape, dtype=np.int) + total_output += long_prod(output_shape) trainable_params += sum_layer["nb_params_trainable"] summary_str += line_new + "\n" # assume 4 bytes/number (float on cuda). - total_input_size = abs(np.sum(list(map(np.prod, input_size))) * batch_size * 4. / (1024 ** 2.)) + total_input_size = abs(long_sum(list(map(long_prod, input_size))) * batch_size * 4. / (1024 ** 2.)) total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients total_params_size = abs(total_params * 4. / (1024 ** 2.)) total_size = total_params_size + total_output_size + total_input_size From 37f8e5aa3ccbdfaf87e69d4149bd5fb870e49103 Mon Sep 17 00:00:00 2001 From: Yuchen Jin <32186723+cainmagi@users.noreply.github.com> Date: Sat, 27 Feb 2021 23:43:56 -0600 Subject: [PATCH 4/6] Fix dict input problem. Fix the bug caused by layers with dict input values. --- torchsummary/torchsummary.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 7b65d7f..b20f82e 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -73,7 +73,10 @@ def hook(module, input, output): m_key = '{name:s}-{idx:d}'.format(name=class_name, idx=module_idx + 1) sum_layer = collections.OrderedDict() summary[m_key] = sum_layer - sum_layer["input_shape"] = list(input[0].size()) + if isinstance(input[0], dict): + sum_layer["input_shape"] = list(next(iter(input[0].values())).size()) + else: + sum_layer["input_shape"] = list(input[0].size()) sum_layer["input_shape"][0] = batch_size if isinstance(output, dict): sum_layer["output_shape"] = [ From 37ab4addde601ec6a9d03ab4db1dc9fd917bf63a Mon Sep 17 00:00:00 2001 From: Yuchen Jin <32186723+cainmagi@users.noreply.github.com> Date: Sat, 27 Feb 2021 23:56:48 -0600 Subject: [PATCH 5/6] Fix the output params_info type. Fix the data type of the output params_info from torch.tensor to int. --- torchsummary/torchsummary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index b20f82e..ca14008 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -93,7 +93,7 @@ def hook(module, input, output): params = 0 params_trainable = 0 for param in module.parameters(recurse=False): - nb_param = torch.prod(torch.LongTensor(list(param.size()))) + nb_param = torch.prod(torch.LongTensor(list(param.size()))).item() params += nb_param params_trainable += nb_param if param.requires_grad else 0 sum_layer["nb_params"] = params From 2179d8e1eaf0686dd04027280e42f7a8c6fa9b6a Mon Sep 17 00:00:00 2001 From: GCS-ZHN Date: Wed, 26 Oct 2022 13:54:26 +0800 Subject: [PATCH 6/6] Fix NoneType Error for multi head attention module --- .gitignore | 5 ++++- setup.py | 2 +- torchsummary/torchsummary.py | 36 +++++++++++++++++++++++++++++++++--- 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 1e7faf7..269c56c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ __pycache__ *.pyc -.vscode/ \ No newline at end of file +.vscode/ +build/ +dist/ +*.egg-info/ \ No newline at end of file diff --git a/setup.py b/setup.py index 35f13cd..a6b3ef5 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="torchsummary", - version="1.5.1", + version="1.5.1rc1", description="Model summary in PyTorch similar to `model.summary()` in Keras", url="https://github.com/sksq96/pytorch-summary", author="Shubham Chandel @sksq96", diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index ca14008..e0217eb 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -1,5 +1,6 @@ import functools import collections +import warnings import torch import torch.nn as nn @@ -43,6 +44,26 @@ def summary(model, input_size, batch_size=-1, device='cuda:0', dtypes=None): return params_info +def raise_no_tensor_error(module): + raise ValueError('The module {} does not return at list one tensor.'.format(module)) + + +def filter_not_array_like(input_data, module): + not_array_like = [] + def _filter(x): + if hasattr(x, "size"): + return True + else: + not_array_like.append(x) + return False + f = list(filter(_filter, input_data)) + if len(not_array_like) > 0: + warnings.warn( + 'Output of module {} contains some elements which not like arrays or tensors.'.format( + module.__class__)) + return f + + def summary_string(model, input_size, batch_size=-1, device='cuda:0', dtypes=None): '''Keras-style torch summary (string output) Iterate the whole pytorch model and summarize the infomation as a Keras-style @@ -80,15 +101,24 @@ def hook(module, input, output): sum_layer["input_shape"][0] = batch_size if isinstance(output, dict): sum_layer["output_shape"] = [ - [-1] + list(o.size())[1:] for o in output.values() + [batch_size] + list(o.size())[1:] for o in filter_not_array_like(output.values(), module) ] + if len(sum_layer["output_shape"]) == 0: + raise_no_tensor_error(module) elif isinstance(output, (list, tuple)): sum_layer["output_shape"] = [ - [-1] + list(o.size())[1:] for o in output + [batch_size] + list(o.size())[1:] for o in filter_not_array_like(output, module) ] - else: + if len(sum_layer["output_shape"]) == 0: + raise_no_tensor_error(module) + elif output is None: + raise_no_tensor_error(module) + elif isinstance(output, torch.Tensor): sum_layer["output_shape"] = list(output.size()) sum_layer["output_shape"][0] = batch_size + else: + raise ValueError('The output type {} of the module {} is not supported yet.'.format( + type(output), module)) params = 0 params_trainable = 0