Commit 03cbca4

Fix type hints of callbacks/finetuning.py (#13516)

1 parent 12b0ec6
2 files changed: +7 / -7 lines

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -41,7 +41,6 @@ warn_no_return = "False"
 # the list can be generated with:
 # mypy | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g' | sed 's|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
-    "pytorch_lightning.callbacks.finetuning",
     "pytorch_lightning.callbacks.model_checkpoint",
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.callbacks.quantization",

src/pytorch_lightning/callbacks/finetuning.py

Lines changed: 7 additions & 6 deletions
@@ -32,8 +32,8 @@
 log = logging.getLogger(__name__)
 
 
-def multiplicative(epoch):
-    return 2
+def multiplicative(epoch: int) -> float:
+    return 2.0
 
 
 class BaseFinetuning(Callback):
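
Why the return type changed to float: multiplicative is the default lambda_func of BackboneFinetuning, whose per-epoch factor scales the backbone learning rate, so the value is consumed as a float. That call site is not shown in this diff; it is an assumption from the library's public API. A minimal usage sketch:

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import BackboneFinetuning

    def multiplicative(epoch: int) -> float:
        # double the backbone learning rate every epoch once it is unfrozen
        return 2.0

    trainer = pl.Trainer(
        callbacks=[BackboneFinetuning(unfreeze_backbone_at_epoch=5, lambda_func=multiplicative)]
    )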
@@ -79,7 +79,7 @@ class BaseFinetuning(Callback):
         ... )
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._internal_optimizer_metadata: Dict[int, List[Dict[str, Any]]] = {}
         self._restarting = False

@@ -94,7 +94,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
             self._internal_optimizer_metadata = state_dict["internal_optimizer_metadata"]
         else:
             # compatibility to load from old checkpoints before PR #11887
-            self._internal_optimizer_metadata = state_dict
+            self._internal_optimizer_metadata = state_dict  # type: ignore[assignment]
 
     def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         # restore the param_groups created during the previous training.
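
The type: ignore[assignment] is scoped to the backward-compatibility branch only. A sketch of the two checkpoint layouts the if/else distinguishes (the values are invented; the shapes are inferred from the surrounding code):

    # layout written after PR #11887: metadata nested under a string key
    new_style = {"internal_optimizer_metadata": {0: [{"lr": 0.1}]}}

    # layout written before PR #11887: the state_dict *is* the metadata
    old_style = {0: [{"lr": 0.1}]}

Assigning old_style wholesale is correct at runtime, but its keys are ints while the parameter is declared Dict[str, Any], hence a targeted ignore rather than a wider signature change.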
@@ -122,10 +122,11 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -
             modules = modules.values()
 
         if isinstance(modules, Iterable):
-            _modules = []
+            _flatten_modules = []
             for m in modules:
-                _modules.extend(BaseFinetuning.flatten_modules(m))
+                _flatten_modules.extend(BaseFinetuning.flatten_modules(m))
 
+            _modules = iter(_flatten_modules)
         else:
             _modules = modules.modules()
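
The rename plus iter(...) exists for the type checker: Module.modules() in the else branch returns an iterator, so wrapping the accumulated list lets _modules have the single consistent type Iterator[Module] in both branches. A small usage sketch, assuming (per the released API) that flatten_modules returns the leaf modules:

    import torch.nn as nn
    from pytorch_lightning.callbacks import BaseFinetuning

    model = nn.Sequential(nn.Linear(4, 8), nn.Sequential(nn.ReLU(), nn.Linear(8, 2)))
    leaves = BaseFinetuning.flatten_modules(model)
    print([type(m).__name__ for m in leaves])  # ['Linear', 'ReLU', 'Linear']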
