
Commit e534e9c

init commit
1 parent 2845e75 commit e534e9c

2 files changed (+5, -6 lines)

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -49,7 +49,6 @@ warn_no_return = "False"
 module = [
     "pytorch_lightning.callbacks.model_checkpoint",
     "pytorch_lightning.callbacks.progress.rich_progress",
-    "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.callbacks.stochastic_weight_avg",
     "pytorch_lightning.core.datamodule",
     "pytorch_lightning.core.decorators",

src/pytorch_lightning/callbacks/quantization.py

Lines changed: 5 additions & 5 deletions
@@ -41,14 +41,14 @@
 
 
 def wrap_qat_forward_context(
-    quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
+    quant_cb: Any, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
 ) -> Callable:
     """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out
     compatibility Moreover this version has the (de)quantization conditional as it may not be needed for the
     training all the time."""
     # todo: consider using registering hook before/after forward
     @functools.wraps(func)
-    def wrapper(data) -> Any:
+    def wrapper(data: Any) -> Any:
         _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer)
         _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition
         _quant_run = trigger_condition is None or _is_func_true or _is_count_true
@@ -200,8 +200,8 @@ def __init__(
         self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages
 
         self._forward_calls = 0
-        self._fake_quant_to_initial_state_dict = {}
-        self._last_fake_quant_to_observer_enabled = {}
+        self._fake_quant_to_initial_state_dict: Dict[FakeQuantizeBase, Tensor] = {}
+        self._last_fake_quant_to_observer_enabled: Dict[FakeQuantizeBase, Tensor] = {}
         self._module_prepared = False
 
     def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
@@ -273,7 +273,7 @@ def _prepare_model(self, model: torch.nn.Module) -> None:
         }
         self._module_prepared = True
 
-    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
+    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._prepare_model(pl_module)
 
     def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
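
The docstring above describes the wrapped forward path: inputs are quantized and outputs dequantized only when the trigger condition says quantization should run. For orientation, a short usage sketch of the callback this helper serves, using the public pytorch_lightning.callbacks.QuantizationAwareTraining name; the Trainer arguments are illustrative, not taken from this commit:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import QuantizationAwareTraining

# Illustrative only: the QAT callback patches the LightningModule's forward via
# wrap_qat_forward_context so the (de)quantization steps run conditionally.
trainer = pl.Trainer(callbacks=[QuantizationAwareTraining()], max_epochs=1)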
