From e534e9c6c467e892bd11b0bdca744f24b46d2124 Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Thu, 21 Jul 2022 06:26:39 -0400
Subject: [PATCH 1/8] init commit

---
 pyproject.toml                                  |  1 -
 src/pytorch_lightning/callbacks/quantization.py | 10 +++++-----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 989e63122f640..c8c885415acf5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,7 +49,6 @@ warn_no_return = "False"
 module = [
     "pytorch_lightning.callbacks.model_checkpoint",
     "pytorch_lightning.callbacks.progress.rich_progress",
-    "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.callbacks.stochastic_weight_avg",
     "pytorch_lightning.core.datamodule",
     "pytorch_lightning.core.decorators",
diff --git a/src/pytorch_lightning/callbacks/quantization.py b/src/pytorch_lightning/callbacks/quantization.py
index af983ef101b0b..ae603cbcfc502 100644
--- a/src/pytorch_lightning/callbacks/quantization.py
+++ b/src/pytorch_lightning/callbacks/quantization.py
@@ -41,14 +41,14 @@


 def wrap_qat_forward_context(
-    quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
+    quant_cb: Any, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
 ) -> Callable:
     """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out
     compatibility Moreover this version has the (de)quantization conditional as it may not be needed for the
     training all the time."""
     # todo: consider using registering hook before/after forward
     @functools.wraps(func)
-    def wrapper(data) -> Any:
+    def wrapper(data: Any) -> Any:
         _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer)
         _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition
         _quant_run = trigger_condition is None or _is_func_true or _is_count_true
@@ -200,8 +200,8 @@ def __init__(
         self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages

         self._forward_calls = 0
-        self._fake_quant_to_initial_state_dict = {}
-        self._last_fake_quant_to_observer_enabled = {}
+        self._fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, Tensor] = {}
+        self._last_fake_quant_to_observer_enabled: Dict[FakeQuantizeBase, Tensor] = {}
         self._module_prepared = False

     def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
@@ -273,7 +273,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
         }
         self._module_prepared = True

-    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._prepare_model(pl_module)

     def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:

From 11de4f0c2706fbdd29a0f970585246f023c81b2b Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Tue, 2 Aug 2022 04:34:21 -0400
Subject: [PATCH 2/8] fix type

---
 src/pytorch_lightning/callbacks/quantization.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/callbacks/quantization.py b/src/pytorch_lightning/callbacks/quantization.py
index ae603cbcfc502..156036bee05b2 100644
--- a/src/pytorch_lightning/callbacks/quantization.py
+++ b/src/pytorch_lightning/callbacks/quantization.py
@@ -41,7 +41,10 @@


 def wrap_qat_forward_context(
-    quant_cb: Any, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
+    quant_cb: Callback,
+    model: "pl.LightningModule",
+    func: Callable,
+    trigger_condition: Optional[Union[Callable, int]] = None,
 ) -> Callable:
     """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out
     compatibility Moreover this version has the (de)quantization conditional as it may not be needed for the

From b18c761dd6aa1831cf46ce7e9231f13570a171fc Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Tue, 9 Aug 2022 05:03:42 -0400
Subject: [PATCH 3/8] callable function check

---
 src/pytorch_lightning/callbacks/quantization.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/quantization.py b/src/pytorch_lightning/callbacks/quantization.py
index 156036bee05b2..f79190782b560 100644
--- a/src/pytorch_lightning/callbacks/quantization.py
+++ b/src/pytorch_lightning/callbacks/quantization.py
@@ -41,7 +41,7 @@


 def wrap_qat_forward_context(
-    quant_cb: Callback,
+    quant_cb,
     model: "pl.LightningModule",
     func: Callable,
     trigger_condition: Optional[Union[Callable, int]] = None,
@@ -52,7 +52,7 @@ def wrap_qat_forward_context(
     # todo: consider using registering hook before/after forward
     @functools.wraps(func)
     def wrapper(data: Any) -> Any:
-        _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer)
+        _is_func_true = callable(trigger_condition) and trigger_condition(model.trainer)
         _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition
         _quant_run = trigger_condition is None or _is_func_true or _is_count_true
         # apply custom trigger
@@ -184,7 +184,9 @@ def __init__(
         )
         self._observer_type = observer_type

-        if collect_quantization is not None and not isinstance(collect_quantization, (int, Callable)):
+        if collect_quantization is not None and not (
+            isinstance(collect_quantization, int) and callable(collect_quantization)
+        ):
             raise MisconfigurationException(
                 f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.'
             )

From afe0550fdedb39d792fbea65ad8ff108fd31875e Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Wed, 10 Aug 2022 03:55:40 -0400
Subject: [PATCH 4/8] ignore tensor not callable issue

---
 src/pytorch_lightning/callbacks/quantization.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/quantization.py b/src/pytorch_lightning/callbacks/quantization.py
index f79190782b560..ee7103c809f7f 100644
--- a/src/pytorch_lightning/callbacks/quantization.py
+++ b/src/pytorch_lightning/callbacks/quantization.py
@@ -58,11 +58,11 @@ def wrapper(data: Any) -> Any:
         # apply custom trigger
         if _quant_run:
             quant_cb._forward_calls += 1
-            data = model.quant(data)
+            data = model.quant(data)  # type: ignore[operator]
         data = func(data)
         # apply custom trigger
         if _quant_run:
-            data = model.dequant(data)
+            data = model.dequant(data)  # type: ignore[operator]
         return data

     return wrapper
@@ -74,9 +74,9 @@ def wrap_quantize_forward_context(model: "pl.LightningModule", func: Callable) -
     # todo: consider using registering hook before/after forward
     @functools.wraps(func)
     def wrapper(data) -> Any:
-        data = model.quant(data)
+        data = model.quant(data)  # type: ignore[operator]
         data = func(data)
-        data = model.dequant(data)
+        data = model.dequant(data)  # type: ignore[operator]
         return data

     return wrapper

From 5825e8ebaa5f74a9defa5c51a58319c6dd0158e2 Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Wed, 10 Aug 2022 04:42:18 -0400
Subject: [PATCH 5/8] pair review with ota

---
 .../callbacks/quantization.py | 27 ++++++++++---------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/quantization.py b/src/pytorch_lightning/callbacks/quantization.py
index ee7103c809f7f..9f3358b1575ad 100644
--- a/src/pytorch_lightning/callbacks/quantization.py
+++ b/src/pytorch_lightning/callbacks/quantization.py
@@ -18,7 +18,7 @@
 """
 import copy
 import functools
-from typing import Any, Callable, Dict, Optional, Sequence, Union
+from typing import Any, Callable, Dict, Optional, OrderedDict, Sequence, Union

 import torch
 from torch import Tensor
@@ -41,7 +41,7 @@


 def wrap_qat_forward_context(
-    quant_cb,
+    quant_cb: "QuantizationAwareTraining",
     model: "pl.LightningModule",
     func: Callable,
     trigger_condition: Optional[Union[Callable, int]] = None,
@@ -73,7 +73,7 @@ def wrap_quantize_forward_context(model: "pl.LightningModule", func: Callable) -
     compatibility."""
     # todo: consider using registering hook before/after forward
     @functools.wraps(func)
-    def wrapper(data) -> Any:
+    def wrapper(data: Any) -> Any:
         data = model.quant(data)  # type: ignore[operator]
         data = func(data)
         data = model.dequant(data)  # type: ignore[operator]
@@ -205,7 +205,7 @@ def __init__(
         self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages

         self._forward_calls = 0
-        self._fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, Tensor] = {}
+        self._fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, OrderedDict[str, Tensor]] = {}
         self._last_fake_quant_to_observer_enabled: Dict[FakeQuantizeBase, Tensor] = {}
         self._module_prepared = False

@@ -232,7 +232,7 @@ def _restore_last_observer_enabled(self) -> None:
         for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items():
             fake_quant.observer_enabled.copy_(observer_enabled)

-    def _prepare_model(self, model: torch.nn.Module) -> None:
+    def _prepare_model(self, model: "pl.LightningModule") -> None:
         if self._module_prepared:
             return
         # QuantStub converts tensors from floating point to quantized
@@ -242,7 +242,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
         # manually specify where tensors will be converted from quantized
         # to floating point in the quantized model
         self.__module_forward = model.forward
-        model.forward = wrap_qat_forward_context(
+        model.forward = wrap_qat_forward_context(  # type: ignore [assignment]
             quant_cb=self, model=model, func=model.forward, trigger_condition=self._collect_quantization
         )

@@ -252,7 +252,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
         if self._observer_type == "histogram":
             model.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
         elif self._observer_type == "average":
-            extra_kwargs = {}
+            extra_kwargs: Dict[str, Optional[int]] = {}
             if _TORCH_GREATER_EQUAL_1_12:
                 extra_kwargs["version"] = 0
             # version=None corresponds to using FakeQuantize rather than
@@ -263,7 +263,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
             model.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)

         elif isinstance(self._qconfig, QConfig):
-            model.qconfig = self._qconfig
+            model.qconfig = self._qconfig  # type: ignore [assignment]

         if self._check_feasible_fuse(model):
             fuse_modules(model, self._modules_to_fuse, inplace=True)
@@ -283,7 +283,7 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -

     def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if not self._convert_on_fit_end:
-            pl_module.forward = self.__module_forward
+            pl_module.forward = self.__module_forward  # type: ignore [assignment]
             return
         pl_module.eval()
         # Convert the observed model to a quantized model. This does several things:
@@ -293,9 +293,12 @@ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") ->
         torch.quantization.convert(pl_module, inplace=True)
         # check we shall preserve wrapper
         if self._input_compatible:
-            pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward)
+            pl_module.forward = wrap_quantize_forward_context(
+                model=pl_module,
+                func=self.__module_forward,
+            )  # type: ignore [assignment]
         else:
-            pl_module.forward = self.__module_forward
+            pl_module.forward = self.__module_forward  # type: ignore [assignment]

     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if "train" in self._observer_disabled_stages:
@@ -341,7 +344,7 @@ def state_dict(self) -> Dict[str, Any]:
         keys = {"_qconfig", "_observer_type", "_collect_quantization", "_modules_to_fuse", "_input_compatible"}
         return {n: getattr(self, n) for n in keys}

-    def _load_before_model(self, model: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
+    def _load_before_model(self, model: "pl.LightningModule", state_dict: Dict[str, Any]) -> None:
         """Special hook that gets called by the CheckpointConnector *before* the model gets loaded.

         This hook replaces the :meth:`on_load_checkpoint` and :meth:`load_state_dict` callback methods which get called

From b5a4e2fc045ae61005c606f196411fa1ed2b40bf Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Wed, 10 Aug 2022 04:44:41 -0400
Subject: [PATCH 6/8] pair review with ota

---
 src/pytorch_lightning/callbacks/quantization.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/quantization.py b/src/pytorch_lightning/callbacks/quantization.py
index 9f3358b1575ad..25870217222f9 100644
--- a/src/pytorch_lightning/callbacks/quantization.py
+++ b/src/pytorch_lightning/callbacks/quantization.py
@@ -293,10 +293,10 @@ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") ->
         torch.quantization.convert(pl_module, inplace=True)
         # check we shall preserve wrapper
         if self._input_compatible:
-            pl_module.forward = wrap_quantize_forward_context(
+            pl_module.forward = wrap_quantize_forward_context(  # type: ignore [assignment]
                 model=pl_module,
                 func=self.__module_forward,
-            )  # type: ignore [assignment]
+            )
         else:
             pl_module.forward = self.__module_forward  # type: ignore [assignment]

From ff6e7d45f7333c1246e6b9fb5f62a10899f8ffc8 Mon Sep 17 00:00:00 2001
From: otaj
Date: Wed, 10 Aug 2022 09:34:39 -0400
Subject: [PATCH 7/8] final fixes

---
 src/pytorch_lightning/callbacks/quantization.py | 4 ++--
 .../trainer/connectors/checkpoint_connector.py  | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/pytorch_lightning/callbacks/quantization.py b/src/pytorch_lightning/callbacks/quantization.py
index 25870217222f9..89c751ec74faf 100644
--- a/src/pytorch_lightning/callbacks/quantization.py
+++ b/src/pytorch_lightning/callbacks/quantization.py
@@ -18,7 +18,7 @@
 """
 import copy
 import functools
-from typing import Any, Callable, Dict, Optional, OrderedDict, Sequence, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Union

 import torch
 from torch import Tensor
@@ -205,7 +205,7 @@ def __init__(
         self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages

         self._forward_calls = 0
-        self._fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, OrderedDict[str, Tensor]] = {}
+        self._fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, Dict[str, Any]] = {}
         self._last_fake_quant_to_observer_enabled: Dict[FakeQuantizeBase, Tensor] = {}
         self._module_prepared = False

diff --git a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 22f61c845360d..e1dccd11a0d0a 100644
--- a/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -245,7 +245,7 @@ def _restore_quantization_callbacks(self) -> None:
             if state:
                 # The Quantization callbacks have a special method that must be called before restoring the weights
                 # of the model
-                callback._load_before_model(self.trainer.model, deepcopy(state))
+                callback._load_before_model(self.trainer.lightning_module, deepcopy(state))

     def restore_callbacks(self) -> None:
         """Restores all callbacks from the pre-loaded checkpoint."""

From f575a4fe6fe7aa41de4445fcfbdd9a3b8481df68 Mon Sep 17 00:00:00 2001
From: Krishna Kalyan
Date: Wed, 10 Aug 2022 19:08:30 +0200
Subject: [PATCH 8/8] Update src/pytorch_lightning/callbacks/quantization.py

Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>
---
 src/pytorch_lightning/callbacks/quantization.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/callbacks/quantization.py b/src/pytorch_lightning/callbacks/quantization.py
index 89c751ec74faf..d89bed0394105 100644
--- a/src/pytorch_lightning/callbacks/quantization.py
+++ b/src/pytorch_lightning/callbacks/quantization.py
@@ -185,7 +185,7 @@ def __init__(
         self._observer_type = observer_type

         if collect_quantization is not None and not (
-            isinstance(collect_quantization, int) and callable(collect_quantization)
+            isinstance(collect_quantization, int) or callable(collect_quantization)
         ):
             raise MisconfigurationException(
                 f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.'