@@ -41,14 +41,14 @@
 
 
 def wrap_qat_forward_context(
-    quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
+    quant_cb: Any, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
 ) -> Callable:
     """Decorator to wrap the forward path, as it is needed to quantize inputs and dequantize outputs for in/out
     compatibility. Moreover, this version makes the (de)quantization conditional, as it may not be needed for the
     entire training."""
     # todo: consider registering a hook before/after forward
     @functools.wraps(func)
-    def wrapper(data) -> Any:
+    def wrapper(data: Any) -> Any:
         _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer)
         _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition
         _quant_run = trigger_condition is None or _is_func_true or _is_count_true
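
As a reading aid, here is a minimal, self-contained sketch of the trigger logic the wrapper above implements. The helper name `should_quantize` is ours and not part of this diff; the real wrapper additionally quantizes the input and dequantizes the output around `func`.

from typing import Any, Callable, Optional, Union


def should_quantize(trigger_condition: Optional[Union[Callable, int]], trainer: Any, forward_calls: int) -> bool:
    # None -> always run (de)quantization around the forward call.
    if trigger_condition is None:
        return True
    # A callable is asked with the trainer (e.g. "are we past epoch N?").
    if callable(trigger_condition):
        return bool(trigger_condition(trainer))
    # An int enables (de)quantization only for the first `trigger_condition` forward calls.
    return forward_calls < trigger_condition
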
@@ -200,8 +200,8 @@ def __init__(
         self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages
 
         self._forward_calls = 0
-        self._fake_quant_to_initial_state_dict = {}
-        self._last_fake_quant_to_observer_enabled = {}
+        self._fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, Dict[str, Any]] = {}
+        self._last_fake_quant_to_observer_enabled: Dict[FakeQuantizeBase, Tensor] = {}
         self._module_prepared = False
 
     def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
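
A note on the value types (an inference from how `FakeQuantizeBase` modules work, not something this hunk shows): each fake-quantize module carries an `observer_enabled` buffer, which is a `Tensor`, while its `state_dict()` is a name-to-tensor mapping, so the initial-state attribute holds dicts rather than tensors. A hedged sketch of how such mappings could be populated, with hypothetical helper names:

import copy
from typing import Any, Dict

import torch
from torch import Tensor
from torch.quantization import FakeQuantizeBase  # lives under torch.ao.quantization in newer PyTorch


def snapshot_fake_quant_state(module: torch.nn.Module) -> Dict[FakeQuantizeBase, Dict[str, Any]]:
    # Deep-copy every fake-quantize submodule's initial state so it can be restored later.
    return {m: copy.deepcopy(m.state_dict()) for m in module.modules() if isinstance(m, FakeQuantizeBase)}


def snapshot_observer_enabled(module: torch.nn.Module) -> Dict[FakeQuantizeBase, Tensor]:
    # Remember each fake-quantize submodule's `observer_enabled` flag (a uint8 buffer).
    return {m: m.observer_enabled.clone() for m in module.modules() if isinstance(m, FakeQuantizeBase)}
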
@@ -273,7 +273,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
         }
         self._module_prepared = True
 
-    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._prepare_model(pl_module)
 
     def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
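
For orientation on where the newly typed hooks fire: `on_fit_start`/`on_fit_end` are standard Lightning `Callback` hooks, so the callback is simply handed to the `Trainer`. A hedged usage sketch, assuming this file is Lightning's `QuantizationAwareTraining` callback (model and data names are placeholders):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import QuantizationAwareTraining

# on_fit_start prepares the LightningModule for QAT;
# on_fit_end typically converts the observed model to a quantized one (default settings).
qat_callback = QuantizationAwareTraining()
trainer = pl.Trainer(callbacks=[qat_callback], max_epochs=3)
# trainer.fit(MyLightningModule(), datamodule=my_datamodule)  # placeholders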