diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index ab846a562a78b..3a3a409a2d7da 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import contextlib
 from collections import defaultdict
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
+from typing import Any, Callable, DefaultDict, Dict, Generator, Iterable, List, Optional, Union
 
 import torch
 from torch import Tensor
@@ -114,7 +114,7 @@ def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
     def _move_optimizer_state(self) -> None:
         """ Moves the state of the optimizers to the GPU if needed. """
        for opt in self.optimizers:
-            state = defaultdict(dict)
+            state: DefaultDict = defaultdict(dict)
             for p, v in opt.state.items():
                 state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
             opt.state = state
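
For context, here is a minimal sketch of the pattern this patch annotates: rebuilding an optimizer's `state` dict with every tensor moved to a target device. The explicit `DefaultDict` annotation exists to satisfy the type checker, since the fresh `defaultdict(dict)` is reassigned onto `opt.state`. The helper names `move_optimizer_state` and `_move` below are hypothetical, and `_move` is a simplified stand-in for Lightning's `apply_to_collection` + `move_data_to_device` so the sketch runs with plain `torch`:

```python
from collections import defaultdict
from typing import Any, DefaultDict

import torch


def _move(value: Any, device: torch.device) -> Any:
    # Recursively move any tensors nested in dicts/lists/tuples to `device`;
    # a simplified stand-in for apply_to_collection(v, torch.Tensor, move_data_to_device, device).
    if isinstance(value, torch.Tensor):
        return value.to(device)
    if isinstance(value, dict):
        return {k: _move(v, device) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(_move(v, device) for v in value)
    return value


def move_optimizer_state(opt: torch.optim.Optimizer, device: torch.device) -> None:
    # Mirrors Accelerator._move_optimizer_state: rebuild opt.state with every
    # tensor on the target device. The DefaultDict annotation is the one the
    # patch adds to keep mypy happy about the reassignment below.
    state: DefaultDict = defaultdict(dict)
    for p, v in opt.state.items():
        state[p] = _move(v, device)
    opt.state = state


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    opt = torch.optim.Adam(model.parameters())
    model(torch.randn(8, 4)).sum().backward()
    opt.step()  # populates per-parameter state (exp_avg, exp_avg_sq, ...)
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    move_optimizer_state(opt, target)
    assert all(
        t.device.type == target.type
        for s in opt.state.values()
        for t in s.values()
        if isinstance(t, torch.Tensor)
    )
```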