2525
2626
2727def patch_module (mod , qconfig , mod_dict , patched_mod = None ):
28+ """Replaces the module with a patched module according to mod_dict.
29+
30+ Args:
31+ mod (nn.Module): The module that will be replaced with a patched module that quantizes the inputs/outputs.
32+ qconfig (ModuleExtraConfig): The quantization config object describing how to quantize the inputs/outputs.
33+ mod_dict (dict): Dictionary mapping a module name to its patched module.
34+
35+ Returns:
36+ nn.Module: The new patched module after patching.
37+ """
2838 parent = parent_child_mod_dict [mod ].parent
2939 name = parent_child_mod_dict [mod ].name
3040 if patched_mod is None :
@@ -33,6 +43,8 @@ def patch_module(mod, qconfig, mod_dict, patched_mod=None):
3343
3444
3545def apply_hf_hook (module ):
46+ """Applies the hf_hook of a given module so that its weights are loaded from disk to CPU and can then be quantized.
47+ """
3648 if hasattr (module , "_hf_hook" ):
3749 module ._hf_hook .pre_forward (module )
3850 module ._hf_hook .detach_hook (module )
@@ -43,6 +55,12 @@ def apply_hf_hook(module):
4355
4456
4557def quantize_params (mod , mod_extra_config ):
58+ """Quantizes the weights of the given module according to the quantization info from mod_extra_config.
59+
60+ Args:
61+ mod (nn.Module): The module whose weights will be quantized.
62+ mod_extra_config (ModuleExtraConfig): The quantization config object whose params entry maps each parameter name to the quantizer used for it.
63+ """
4664 for param_name in mod_extra_config .params :
4765 quantizer = mod_extra_config .params [param_name ]
4866 param = getattr (mod , param_name )
@@ -55,6 +73,15 @@ def quantize_params(mod, mod_extra_config):
5573
5674
5775def prepare_model (model , qconfig , mod_list , hp_dtype = torch .float ):
76+ """Replaces the model submodules according to the mod_list with patched quantization modules.
77+ Configures patched modules with the quantization/dequantization methods to apply on their input and output tensors.
78+ Quantizes the model parameters as they are static.
79+
80+ Args:
81+ model (nn.Module): The model to quantize.
82+ qconfig (dict): Dict that maps between patched module and its quantization info.
83+ mod_list (list): The specific submodules that will be quantized in the model.
84+ """
5885 config = get_hqt_config (model )
5986 patched_modules = []
6087 patched_module_types = set ()
@@ -82,6 +109,12 @@ def prepare_model(model, qconfig, mod_list, hp_dtype=torch.float):
82109
83110
84111def quantize (model , mod_list ):
112+ """Builds a quantization config object that contains, for each submodule, its quantization functions, in preparation for quantization.
113+
114+ Args:
115+ model (nn.Module): The model that will be quantized.
116+ mod_list (list): The specific modules that will be quantized in the model.
117+ """
85118 config = get_hqt_config (model )
86119 generate_model_info (model )
87120 hp_dtype = config .cfg ["hp_dtype" ]
0 commit comments