1 change: 0 additions & 1 deletion pyproject.toml
@@ -50,7 +50,6 @@ warn_no_return = "False"
# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
module = [
"pytorch_lightning.callbacks.progress.rich_progress",
"pytorch_lightning.callbacks.quantization",
"pytorch_lightning.core.datamodule",
"pytorch_lightning.core.module",
"pytorch_lightning.demos.boring_classes",
48 changes: 28 additions & 20 deletions src/pytorch_lightning/callbacks/quantization.py
@@ -41,25 +41,28 @@


def wrap_qat_forward_context(
- quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
+ quant_cb: "QuantizationAwareTraining",
+ model: "pl.LightningModule",
+ func: Callable,
+ trigger_condition: Optional[Union[Callable, int]] = None,
) -> Callable:
"""Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out
compatibility Moreover this version has the (de)quantization conditional as it may not be needed for the
training all the time."""
# todo: consider using registering hook before/after forward
@functools.wraps(func)
- def wrapper(data) -> Any:
- _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer)
+ def wrapper(data: Any) -> Any:
+ _is_func_true = callable(trigger_condition) and trigger_condition(model.trainer)
_is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition
_quant_run = trigger_condition is None or _is_func_true or _is_count_true
# apply custom trigger
if _quant_run:
quant_cb._forward_calls += 1
- data = model.quant(data)
+ data = model.quant(data)  # type: ignore[operator]
data = func(data)
# apply custom trigger
if _quant_run:
- data = model.dequant(data)
+ data = model.dequant(data)  # type: ignore[operator]
return data

return wrapper
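
The change from isinstance(trigger_condition, Callable) to callable(trigger_condition) is what lets mypy narrow the Optional[Union[Callable, int]] parameter. Below is a minimal sketch of the gating logic the wrapper applies; the helper name and the trainer placeholder are illustrative, not part of the patch.

from typing import Callable, Optional, Union


def _should_run_quant(
    trigger_condition: Optional[Union[Callable[[object], bool], int]],
    forward_calls: int,
    trainer: object,
) -> bool:
    # None means the quant/dequant stubs always run
    if trigger_condition is None:
        return True
    if callable(trigger_condition):
        # mypy narrows the union to the callable arm here, which
        # isinstance(x, Callable) does not do as cleanly
        return bool(trigger_condition(trainer))
    # remaining arm: an int budget of forward calls
    return forward_calls < trigger_condition
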
@@ -70,10 +73,10 @@ def wrap_quantize_forward_context(model: "pl.LightningModule", func: Callable) -
compatibility."""
# todo: consider using registering hook before/after forward
@functools.wraps(func)
- def wrapper(data) -> Any:
- data = model.quant(data)
+ def wrapper(data: Any) -> Any:
+ data = model.quant(data)  # type: ignore[operator]
data = func(data)
- data = model.dequant(data)
+ data = model.dequant(data)  # type: ignore[operator]
return data

return wrapper
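
For reference, model.quant and model.dequant in both wrappers are the QuantStub/DeQuantStub modules the callback attaches to the LightningModule (per the comments in _prepare_model below). Here is a minimal standalone module showing the same quantize-in, dequantize-out pattern; this toy class is only an illustration, not code from the patch.

import torch
import torch.nn as nn


class TinyQuantModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # converts float tensors to quantized tensors on the way in
        self.quant = torch.quantization.QuantStub()
        self.layer = nn.Linear(4, 4)
        # converts quantized tensors back to float on the way out
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # mirrors wrap_quantize_forward_context: quant -> forward -> dequant
        return self.dequant(self.layer(self.quant(x)))
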
@@ -181,7 +184,9 @@ def __init__(
)
self._observer_type = observer_type

- if collect_quantization is not None and not isinstance(collect_quantization, (int, Callable)):
+ if collect_quantization is not None and not (
+ isinstance(collect_quantization, int) or callable(collect_quantization)
+ ):
raise MisconfigurationException(
f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.'
)
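
The check above accepts an int or any callable for collect_quantization, and the wrapper earlier shows that a callable is invoked with the trainer. A hedged usage sketch, assuming the documented constructor of this callback:

from pytorch_lightning.callbacks import QuantizationAwareTraining

# run the quant/dequant wrapper only for the first 100 forward calls
qat_by_count = QuantizationAwareTraining(collect_quantization=100)

# ...or gate it with a predicate over the trainer, e.g. even epochs only
qat_by_predicate = QuantizationAwareTraining(
    collect_quantization=lambda trainer: trainer.current_epoch % 2 == 0
)
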
@@ -200,8 +205,8 @@ def __init__(
self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages

self._forward_calls = 0
- self._fake_quant_to_initial_state_dict = {}
- self._last_fake_quant_to_observer_enabled = {}
+ self._fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, Dict[str, Any]] = {}
+ self._last_fake_quant_to_observer_enabled: Dict[FakeQuantizeBase, Tensor] = {}
self._module_prepared = False

def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
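
The two dicts above need explicit annotations because mypy reports "Need type annotation" for a bare {} attribute; the chosen value types also require FakeQuantizeBase and Tensor to be imported in the module. A generic sketch of the same pattern, with illustrative names:

from typing import Any, Dict

import torch


class StateTracker:
    def __init__(self) -> None:
        # without the annotations, mypy cannot infer the key/value types
        # of an empty dict literal, so they are spelled out here
        self.module_to_state: Dict[torch.nn.Module, Dict[str, Any]] = {}
        self.module_to_flag: Dict[torch.nn.Module, torch.Tensor] = {}
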
@@ -227,7 +232,7 @@ def _restore_last_observer_enabled(self) -> None:
for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items():
fake_quant.observer_enabled.copy_(observer_enabled)

- def _prepare_model(self, model: torch.nn.Module) -> None:
+ def _prepare_model(self, model: "pl.LightningModule") -> None:
if self._module_prepared:
return
# QuantStub converts tensors from floating point to quantized
@@ -237,7 +242,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
# manually specify where tensors will be converted from quantized
# to floating point in the quantized model
self.__module_forward = model.forward
- model.forward = wrap_qat_forward_context(
+ model.forward = wrap_qat_forward_context(  # type: ignore [assignment]
quant_cb=self, model=model, func=model.forward, trigger_condition=self._collect_quantization
)

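
The # type: ignore [assignment] is needed because mypy treats forward as a method of the module's class and rejects rebinding it on an instance. A minimal reproduction of the monkey-patching pattern, independent of the patch:

from typing import Any, Callable

import torch.nn as nn


def patch_forward(module: nn.Module, func: Callable[..., Any]) -> None:
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # place for pre/post processing around the original forward
        return func(*args, **kwargs)

    # rebinding a method on an instance is flagged by mypy, hence the ignore
    module.forward = wrapper  # type: ignore[assignment]
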
@@ -247,7 +252,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
if self._observer_type == "histogram":
model.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
elif self._observer_type == "average":
- extra_kwargs = {}
+ extra_kwargs: Dict[str, Optional[int]] = {}
if _TORCH_GREATER_EQUAL_1_12:
extra_kwargs["version"] = 0
# version=None corresponds to using FakeQuantize rather than
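
extra_kwargs is typed Dict[str, Optional[int]] because version may be an int (0, which per the comment above keeps the plain FakeQuantize observers) or be left unset/None on older torch versions. A hedged sketch of the branch, assuming torch.quantization.get_default_qat_qconfig accepts the version keyword as the patch relies on for torch >= 1.12, and using "fbgemm" only as the common default backend:

from typing import Dict, Optional

import torch

extra_kwargs: Dict[str, Optional[int]] = {}
# on torch >= 1.12 the patch sets version=0 to keep plain FakeQuantize observers
extra_kwargs["version"] = 0

qconfig = torch.quantization.get_default_qat_qconfig("fbgemm", **extra_kwargs)
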
@@ -258,7 +263,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
model.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)

elif isinstance(self._qconfig, QConfig):
- model.qconfig = self._qconfig
+ model.qconfig = self._qconfig  # type: ignore [assignment]

if self._check_feasible_fuse(model):
fuse_modules(model, self._modules_to_fuse, inplace=True)
@@ -273,12 +278,12 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
}
self._module_prepared = True

- def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
+ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._prepare_model(pl_module)

def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not self._convert_on_fit_end:
- pl_module.forward = self.__module_forward
+ pl_module.forward = self.__module_forward  # type: ignore [assignment]
return
pl_module.eval()
# Convert the observed model to a quantized model. This does several things:
@@ -288,9 +293,12 @@ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") ->
torch.quantization.convert(pl_module, inplace=True)
# check we shall preserve wrapper
if self._input_compatible:
- pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward)
+ pl_module.forward = wrap_quantize_forward_context(  # type: ignore [assignment]
+ model=pl_module,
+ func=self.__module_forward,
+ )
else:
- pl_module.forward = self.__module_forward
+ pl_module.forward = self.__module_forward  # type: ignore [assignment]

def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if "train" in self._observer_disabled_stages:
@@ -336,7 +344,7 @@ def state_dict(self) -> Dict[str, Any]:
keys = {"_qconfig", "_observer_type", "_collect_quantization", "_modules_to_fuse", "_input_compatible"}
return {n: getattr(self, n) for n in keys}

- def _load_before_model(self, model: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
+ def _load_before_model(self, model: "pl.LightningModule", state_dict: Dict[str, Any]) -> None:
"""Special hook that gets called by the CheckpointConnector *before* the model gets loaded.

This hook replaces the :meth:`on_load_checkpoint` and :meth:`load_state_dict` callback methods which get called
@@ -245,7 +245,7 @@ def _restore_quantization_callbacks(self) -> None:
if state:
# The Quantization callbacks have a special method that must be called before restoring the weights
# of the model
- callback._load_before_model(self.trainer.model, deepcopy(state))
+ callback._load_before_model(self.trainer.lightning_module, deepcopy(state))

def restore_callbacks(self) -> None:
"""Restores all callbacks from the pre-loaded checkpoint."""
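
The connector now hands the callback trainer.lightning_module rather than trainer.model: depending on the strategy, trainer.model can be a wrapped module (e.g. DistributedDataParallel), while lightning_module is always the bare LightningModule the callback prepared. A small illustrative helper capturing that distinction; the function itself is not part of the patch:

import pytorch_lightning as pl


def module_for_restore(trainer: "pl.Trainer") -> "pl.LightningModule":
    # trainer.model may be the strategy-wrapped module; the quantization
    # callback patched the underlying LightningModule, so restore on that
    return trainer.lightning_module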