
Commit 0ca3b5a

krishnakalyan3, otaj, and awaelchli authored
Fix mypy errors attributed to pytorch_lightning.callbacks.quantization (#13782)
Co-authored-by: otaj <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 0b1a29b commit 0ca3b5a

File tree

3 files changed: 29 additions & 22 deletions


pyproject.toml

Lines changed: 0 additions & 1 deletion

@@ -50,7 +50,6 @@ warn_no_return = "False"
 # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
-    "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.core.datamodule",
     "pytorch_lightning.demos.boring_classes",
     "pytorch_lightning.demos.mnist_datamodule",

src/pytorch_lightning/callbacks/quantization.py

Lines changed: 28 additions & 20 deletions

@@ -41,25 +41,28 @@
 
 
 def wrap_qat_forward_context(
-    quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
+    quant_cb: "QuantizationAwareTraining",
+    model: "pl.LightningModule",
+    func: Callable,
+    trigger_condition: Optional[Union[Callable, int]] = None,
 ) -> Callable:
     """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out
     compatibility Moreover this version has the (de)quantization conditional as it may not be needed for the
     training all the time."""
     # todo: consider using registering hook before/after forward
     @functools.wraps(func)
-    def wrapper(data) -> Any:
-        _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer)
+    def wrapper(data: Any) -> Any:
+        _is_func_true = callable(trigger_condition) and trigger_condition(model.trainer)
         _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition
         _quant_run = trigger_condition is None or _is_func_true or _is_count_true
         # apply custom trigger
         if _quant_run:
             quant_cb._forward_calls += 1
-            data = model.quant(data)
+            data = model.quant(data)  # type: ignore[operator]
         data = func(data)
         # apply custom trigger
         if _quant_run:
-            data = model.dequant(data)
+            data = model.dequant(data)  # type: ignore[operator]
         return data
 
     return wrapper
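
Swapping isinstance(trigger_condition, Callable) for callable(trigger_condition) keeps the same runtime behaviour while giving recent mypy versions a check they can narrow on. A small self-contained illustration of the pattern (the names here are invented for the example, not taken from the callback):

from typing import Callable, Optional, Union


def should_trigger(trigger: Optional[Union[Callable[[int], bool], int]], count: int) -> bool:
    if trigger is None:
        return True
    if callable(trigger):
        # mypy narrows `trigger` to the callable member of the union here
        return trigger(count)
    # ...and to `int` in this branch, so the comparison type-checks
    return count < trigger
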
@@ -70,10 +73,10 @@ def wrap_quantize_forward_context(model: "pl.LightningModule", func: Callable) -
     compatibility."""
     # todo: consider using registering hook before/after forward
     @functools.wraps(func)
-    def wrapper(data) -> Any:
-        data = model.quant(data)
+    def wrapper(data: Any) -> Any:
+        data = model.quant(data)  # type: ignore[operator]
         data = func(data)
-        data = model.dequant(data)
+        data = model.dequant(data)  # type: ignore[operator]
         return data
 
     return wrapper
@@ -181,7 +184,9 @@ def __init__(
         )
         self._observer_type = observer_type
 
-        if collect_quantization is not None and not isinstance(collect_quantization, (int, Callable)):
+        if collect_quantization is not None and not (
+            isinstance(collect_quantization, int) or callable(collect_quantization)
+        ):
             raise MisconfigurationException(
                 f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.'
             )
@@ -200,8 +205,8 @@ def __init__(
         self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages
 
         self._forward_calls = 0
-        self._fake_quant_to_initial_state_dict = {}
-        self._last_fake_quant_to_observer_enabled = {}
+        self._fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, Dict[str, Any]] = {}
+        self._last_fake_quant_to_observer_enabled: Dict[FakeQuantizeBase, Tensor] = {}
         self._module_prepared = False
 
     def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
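
The two dictionary attributes are annotated because mypy cannot infer key/value types from a bare {} and reports "Need type annotation". A tiny standalone example of the same fix (generic placeholder types, not the callback's real ones):

from typing import Any, Dict


class Recorder:
    def __init__(self) -> None:
        # `self.history = {}` alone would trigger mypy's "Need type annotation" error;
        # the annotation states what the mapping will hold.
        self.history: Dict[str, Any] = {}

    def record(self, key: str, value: Any) -> None:
        self.history[key] = value
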
@@ -227,7 +232,7 @@ def _restore_last_observer_enabled(self) -> None:
         for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items():
             fake_quant.observer_enabled.copy_(observer_enabled)
 
-    def _prepare_model(self, model: torch.nn.Module) -> None:
+    def _prepare_model(self, model: "pl.LightningModule") -> None:
         if self._module_prepared:
             return
         # QuantStub converts tensors from floating point to quantized
@@ -237,7 +242,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
         # manually specify where tensors will be converted from quantized
         # to floating point in the quantized model
         self.__module_forward = model.forward
-        model.forward = wrap_qat_forward_context(
+        model.forward = wrap_qat_forward_context(  # type: ignore [assignment]
             quant_cb=self, model=model, func=model.forward, trigger_condition=self._collect_quantization
         )
 
@@ -247,7 +252,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
             if self._observer_type == "histogram":
                 model.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
             elif self._observer_type == "average":
-                extra_kwargs = {}
+                extra_kwargs: Dict[str, Optional[int]] = {}
                 if _TORCH_GREATER_EQUAL_1_12:
                     extra_kwargs["version"] = 0
                 # version=None corresponds to using FakeQuantize rather than
@@ -258,7 +263,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
                 model.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)
 
         elif isinstance(self._qconfig, QConfig):
-            model.qconfig = self._qconfig
+            model.qconfig = self._qconfig  # type: ignore [assignment]
 
         if self._check_feasible_fuse(model):
             fuse_modules(model, self._modules_to_fuse, inplace=True)
@@ -273,12 +278,12 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
         }
         self._module_prepared = True
 
-    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._prepare_model(pl_module)
 
     def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if not self._convert_on_fit_end:
-            pl_module.forward = self.__module_forward
+            pl_module.forward = self.__module_forward  # type: ignore [assignment]
             return
         pl_module.eval()
         # Convert the observed model to a quantized model. This does several things:
@@ -288,9 +293,12 @@ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") ->
         torch.quantization.convert(pl_module, inplace=True)
         # check we shall preserve wrapper
         if self._input_compatible:
-            pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward)
+            pl_module.forward = wrap_quantize_forward_context(  # type: ignore [assignment]
+                model=pl_module,
+                func=self.__module_forward,
+            )
         else:
-            pl_module.forward = self.__module_forward
+            pl_module.forward = self.__module_forward  # type: ignore [assignment]
 
     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if "train" in self._observer_disabled_stages:
@@ -336,7 +344,7 @@ def state_dict(self) -> Dict[str, Any]:
         keys = {"_qconfig", "_observer_type", "_collect_quantization", "_modules_to_fuse", "_input_compatible"}
         return {n: getattr(self, n) for n in keys}
 
-    def _load_before_model(self, model: torch.nn.Module, state_dict: Dict[str, Any]) -> None:
+    def _load_before_model(self, model: "pl.LightningModule", state_dict: Dict[str, Any]) -> None:
         """Special hook that gets called by the CheckpointConnector *before* the model gets loaded.
 
         This hook replaces the :meth:`on_load_checkpoint` and :meth:`load_state_dict` callback methods which get called

src/pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 1 addition & 1 deletion

@@ -245,7 +245,7 @@ def _restore_quantization_callbacks(self) -> None:
             if state:
                 # The Quantization callbacks have a special method that must be called before restoring the weights
                 # of the model
-                callback._load_before_model(self.trainer.model, deepcopy(state))
+                callback._load_before_model(self.trainer.lightning_module, deepcopy(state))
 
     def restore_callbacks(self) -> None:
         """Restores all callbacks from the pre-loaded checkpoint."""
