 
 
 def wrap_qat_forward_context(
-    quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
+    quant_cb: "QuantizationAwareTraining",
+    model: "pl.LightningModule",
+    func: Callable,
+    trigger_condition: Optional[Union[Callable, int]] = None,
 ) -> Callable:
     """Decorator to wrap the forward path, as it is needed to quantize inputs and dequantize outputs for in/out
     compatibility. Moreover, this version makes the (de)quantization conditional, as it may not be needed
     throughout training."""
     # todo: consider registering a hook before/after forward
     @functools.wraps(func)
-    def wrapper(data) -> Any:
-        _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer)
+    def wrapper(data: Any) -> Any:
+        _is_func_true = callable(trigger_condition) and trigger_condition(model.trainer)
         _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition
         _quant_run = trigger_condition is None or _is_func_true or _is_count_true
         # apply custom trigger
         if _quant_run:
             quant_cb._forward_calls += 1
-            data = model.quant(data)
+            data = model.quant(data)  # type: ignore[operator]
         data = func(data)
         # apply custom trigger
         if _quant_run:
-            data = model.dequant(data)
+            data = model.dequant(data)  # type: ignore[operator]
         return data
 
     return wrapper
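For readers skimming the diff, the trigger logic inside `wrapper` boils down to three cases: no condition (always quantize/dequantize), a callable queried with the trainer on every forward call, or an integer cap on the number of quantized forward calls. Below is a standalone, purely illustrative restatement of that decision; the helper name `should_quantize` and its bare arguments are assumptions for the example, not part of the patch.

```python
from typing import Any, Callable, Optional, Union


def should_quantize(
    trigger_condition: Optional[Union[Callable, int]], trainer: Any, forward_calls: int
) -> bool:
    """Mirror of the `_quant_run` decision in the wrapper above (illustrative only)."""
    if trigger_condition is None:  # no condition -> always (de)quantize
        return True
    if callable(trigger_condition):  # callable -> ask it on every forward call
        return bool(trigger_condition(trainer))
    return forward_calls < trigger_condition  # int -> only the first N forward calls


assert should_quantize(None, trainer=None, forward_calls=100)
assert should_quantize(3, trainer=None, forward_calls=2)
assert not should_quantize(3, trainer=None, forward_calls=3)
```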
@@ -70,10 +73,10 @@ def wrap_quantize_forward_context(model: "pl.LightningModule", func: Callable) -
     compatibility."""
     # todo: consider registering a hook before/after forward
     @functools.wraps(func)
-    def wrapper(data) -> Any:
-        data = model.quant(data)
+    def wrapper(data: Any) -> Any:
+        data = model.quant(data)  # type: ignore[operator]
         data = func(data)
-        data = model.dequant(data)
+        data = model.dequant(data)  # type: ignore[operator]
         return data
 
     return wrapper
@@ -181,7 +184,9 @@ def __init__(
         )
         self._observer_type = observer_type
 
-        if collect_quantization is not None and not isinstance(collect_quantization, (int, Callable)):
+        if collect_quantization is not None and not (
+            isinstance(collect_quantization, int) or callable(collect_quantization)
+        ):
             raise MisconfigurationException(
                 f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.'
             )
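The rewritten check above accepts `None`, an `int`, or any callable for `collect_quantization`; anything else raises `MisconfigurationException`. A couple of hedged construction examples follow; the lambda body is an illustration, not taken from the patch.

```python
from pytorch_lightning.callbacks import QuantizationAwareTraining

# int: (de)quantize only for the first 10 wrapped forward calls
QuantizationAwareTraining(collect_quantization=10)

# callable: evaluated per forward call against the attached Trainer
QuantizationAwareTraining(collect_quantization=lambda trainer: trainer.current_epoch >= 2)

# anything else fails validation, e.g.
# QuantizationAwareTraining(collect_quantization="every_epoch")  # MisconfigurationException
```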
@@ -200,8 +205,8 @@ def __init__(
         self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages
 
         self._forward_calls = 0
-        self._fake_quant_to_initial_state_dict = {}
-        self._last_fake_quant_to_observer_enabled = {}
+        self._fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, Dict[str, Any]] = {}
+        self._last_fake_quant_to_observer_enabled: Dict[FakeQuantizeBase, Tensor] = {}
         self._module_prepared = False
 
     def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
@@ -227,7 +232,7 @@ def _restore_last_observer_enabled(self) -> None:
         for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items():
             fake_quant.observer_enabled.copy_(observer_enabled)
 
-    def _prepare_model(self, model: torch.nn.Module) -> None:
+    def _prepare_model(self, model: "pl.LightningModule") -> None:
         if self._module_prepared:
             return
         # QuantStub converts tensors from floating point to quantized
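The comment above refers to the eager-mode stubs that the wrapped forward relies on (`model.quant` / `model.dequant`). A minimal sketch of that pattern on a toy module is shown below; it is an assumption-laden illustration of the stub idiom, not the callback's exact code.

```python
import torch
from torch.quantization import DeQuantStub, QuantStub


class ToyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.quant = QuantStub()      # float -> quantized at the input boundary
        self.body = torch.nn.Linear(4, 4)
        self.dequant = DeQuantStub()  # quantized -> float at the output boundary

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dequant(self.body(self.quant(x)))
```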
@@ -237,7 +242,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
         # manually specify where tensors will be converted from quantized
         # to floating point in the quantized model
         self.__module_forward = model.forward
-        model.forward = wrap_qat_forward_context(
+        model.forward = wrap_qat_forward_context(  # type: ignore[assignment]
             quant_cb=self, model=model, func=model.forward, trigger_condition=self._collect_quantization
         )
 
@@ -247,7 +252,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
         if self._observer_type == "histogram":
             model.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
         elif self._observer_type == "average":
-            extra_kwargs = {}
+            extra_kwargs: Dict[str, Optional[int]] = {}
             if _TORCH_GREATER_EQUAL_1_12:
                 extra_kwargs["version"] = 0
             # version=None corresponds to using FakeQuantize rather than
@@ -258,7 +263,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
             model.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)
 
         elif isinstance(self._qconfig, QConfig):
-            model.qconfig = self._qconfig  # type: ignore[assignment]
+            model.qconfig = self._qconfig  # type: ignore[assignment]
 
         if self._check_feasible_fuse(model):
             fuse_modules(model, self._modules_to_fuse, inplace=True)
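As context for the qconfig branches above: with the string observer types the callback asks torch for a backend default (passing `version=0` when torch >= 1.12, per the diff), while a ready-made `QConfig` instance is assigned to `model.qconfig` as-is before `fuse_modules` runs. The sketch below shows what a user might pass in; the backend name and the fusible module names are assumptions for the example.

```python
from torch.quantization import QConfig, get_default_qat_qconfig

# backend default QAT qconfig; `version` is accepted on recent torch releases,
# and the diff pins version=0 when torch >= 1.12
qconfig = get_default_qat_qconfig("fbgemm", version=0)
assert isinstance(qconfig, QConfig)

# fusible layer groups are referenced by attribute name, e.g.
modules_to_fuse = [["conv", "bn", "relu"]]
```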
@@ -273,12 +278,12 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
         }
         self._module_prepared = True
 
-    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._prepare_model(pl_module)
 
     def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if not self._convert_on_fit_end:
-            pl_module.forward = self.__module_forward
+            pl_module.forward = self.__module_forward  # type: ignore[assignment]
             return
         pl_module.eval()
         # Convert the observed model to a quantized model. This does several things:
@@ -288,9 +293,12 @@ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") ->
         torch.quantization.convert(pl_module, inplace=True)
         # check we shall preserve wrapper
         if self._input_compatible:
-            pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward)
+            pl_module.forward = wrap_quantize_forward_context(  # type: ignore[assignment]
+                model=pl_module,
+                func=self.__module_forward,
+            )
         else:
-            pl_module.forward = self.__module_forward
+            pl_module.forward = self.__module_forward  # type: ignore[assignment]
 
     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if "train" in self._observer_disabled_stages:
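Finally, a hedged end-to-end sketch of how the hooks above fit together: the module is prepared on `on_fit_start` and converted (or unwrapped) on `on_fit_end`. The toy model, data, and Trainer arguments are illustrative and version-dependent, not taken from this repository.

```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import QuantizationAwareTraining
from torch.utils.data import DataLoader, TensorDataset


class ToyRegressor(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


loader = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=16)
trainer = pl.Trainer(max_epochs=1, callbacks=[QuantizationAwareTraining(input_compatible=True)])
trainer.fit(ToyRegressor(), train_dataloaders=loader)

# With input_compatible=True the converted module still accepts float tensors,
# because forward stays wrapped by wrap_quantize_forward_context after fit.
```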
@@ -336,7 +344,7 @@ def state_dict(self) -> Dict[str, Any]:
         keys = {"_qconfig", "_observer_type", "_collect_quantization", "_modules_to_fuse", "_input_compatible"}
         return {n: getattr(self, n) for n in keys}
 
-    def _load_before_model(self, model: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
+    def _load_before_model(self, model: "pl.LightningModule", state_dict: Dict[str, Any]) -> None:
         """Special hook that gets called by the CheckpointConnector *before* the model gets loaded.
 
         This hook replaces the :meth:`on_load_checkpoint` and :meth:`load_state_dict` callback methods which get called