diff --git a/pyproject.toml b/pyproject.toml
index ddc903d6af9d7..7446af9579578 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -41,7 +41,6 @@ warn_no_return = "False"
 # the list can be generated with:
 # mypy | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g' | sed 's|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
-    "pytorch_lightning.callbacks.finetuning",
     "pytorch_lightning.callbacks.model_checkpoint",
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.callbacks.quantization",
diff --git a/src/pytorch_lightning/callbacks/finetuning.py b/src/pytorch_lightning/callbacks/finetuning.py
index 4a7067f56c697..ad45ff8b0591b 100644
--- a/src/pytorch_lightning/callbacks/finetuning.py
+++ b/src/pytorch_lightning/callbacks/finetuning.py
@@ -32,8 +32,8 @@
 log = logging.getLogger(__name__)
 
 
-def multiplicative(epoch):
-    return 2
+def multiplicative(epoch: int) -> float:
+    return 2.0
 
 
 class BaseFinetuning(Callback):
@@ -79,7 +79,7 @@ class BaseFinetuning(Callback):
        ...     )
    """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._internal_optimizer_metadata: Dict[int, List[Dict[str, Any]]] = {}
         self._restarting = False
 
@@ -94,7 +94,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             self._internal_optimizer_metadata = state_dict["internal_optimizer_metadata"]
         else:
             # compatibility to load from old checkpoints before PR #11887
-            self._internal_optimizer_metadata = state_dict
+            self._internal_optimizer_metadata = state_dict  # type: ignore[assignment]
 
     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         # restore the param_groups created during the previous training.
@@ -122,10 +122,11 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
             modules = modules.values()
 
         if isinstance(modules, Iterable):
-            _modules = []
+            _flatten_modules = []
             for m in modules:
-                _modules.extend(BaseFinetuning.flatten_modules(m))
+                _flatten_modules.extend(BaseFinetuning.flatten_modules(m))
+            _modules = iter(_flatten_modules)
         else:
             _modules = modules.modules()