
Commit f49c2e3

fixes mypy errors in trainer/supporters.py (#14633)
* fixes mypy errors in trainer/supporters.py
* Fixes mypy error when accessing "__init__" directly
* add an assertion in lr_finder.py
* Make init call `reset` in `TensorRunningAccum`
* Fixes formatting
* Add `self.window_length` to `__init__`

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent: cd67124

3 files changed: +41 / -32 lines


pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -54,7 +54,6 @@ module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",
-    "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",
     "pytorch_lightning.utilities.data",

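Context for the supporters.py diff that follows: the "__init__" bullet in the commit message refers to mypy rejecting `reset` calling `self.__init__(window_length)` directly (mypy reports accessing `__init__` on an instance as unsound, since the instance could belong to an incompatible subclass). The commit inverts the relationship so that `__init__` delegates to `reset`. Below is a minimal, self-contained sketch of that inversion, assuming only that `torch` is installed; the class name is hypothetical and it mirrors the idea rather than reproducing the repository code.

```python
from typing import Optional

import torch


class RunningAccumSketch:
    """Illustrative stand-in for ``TensorRunningAccum`` (hypothetical name)."""

    def __init__(self, window_length: int) -> None:
        self.window_length = window_length
        # Construction delegates to ``reset`` ...
        self.reset(window_length)

    def reset(self, window_length: Optional[int] = None) -> None:
        """Empty the accumulator, optionally resizing the window."""
        # ... rather than ``reset`` calling ``self.__init__(window_length)``,
        # which mypy flags as an unsound access of "__init__" on an instance.
        if window_length is not None:
            self.window_length = window_length
        # Annotating the attributes here keeps their types visible to mypy
        # even though they are (re)assigned outside ``__init__``.
        self.memory: Optional[torch.Tensor] = None
        self.current_idx: int = 0
        self.last_idx: Optional[int] = None
        self.rotated: bool = False
```

The same inversion is what makes the extra `assert isinstance(self.memory, torch.Tensor)` checks necessary further down: once `memory` is typed as `Optional[torch.Tensor]`, mypy needs the `None` case excluded before the tensor is indexed or aggregated.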
src/pytorch_lightning/trainer/supporters.py

Lines changed: 38 additions & 30 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import Sized
 from dataclasses import asdict, dataclass, field
 from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, Union
 
@@ -53,23 +54,24 @@ class TensorRunningAccum:
 
     def __init__(self, window_length: int):
         self.window_length = window_length
-        self.memory = None
-        self.current_idx: int = 0
-        self.last_idx: Optional[int] = None
-        self.rotated: bool = False
+        self.reset(window_length)
 
     def reset(self, window_length: Optional[int] = None) -> None:
         """Empty the accumulator."""
-        if window_length is None:
-            window_length = self.window_length
-        self.__init__(window_length)
+        if window_length is not None:
+            self.window_length = window_length
+        self.memory: Optional[torch.Tensor] = None
+        self.current_idx: int = 0
+        self.last_idx: Optional[int] = None
+        self.rotated: bool = False
 
-    def last(self):
+    def last(self) -> Optional[torch.Tensor]:
         """Get the last added element."""
         if self.last_idx is not None:
+            assert isinstance(self.memory, torch.Tensor)
             return self.memory[self.last_idx].float()
 
-    def append(self, x):
+    def append(self, x: torch.Tensor) -> None:
         """Add an element to the accumulator."""
         if self.memory is None:
             # tradeoff memory for speed by keeping the memory on device
@@ -88,20 +90,21 @@ def append(self, x):
         if self.current_idx == 0:
             self.rotated = True
 
-    def mean(self):
+    def mean(self) -> Optional[torch.Tensor]:
         """Get mean value from stored elements."""
         return self._agg_memory("mean")
 
-    def max(self):
+    def max(self) -> Optional[torch.Tensor]:
         """Get maximal value from stored elements."""
         return self._agg_memory("max")
 
-    def min(self):
+    def min(self) -> Optional[torch.Tensor]:
         """Get minimal value from stored elements."""
         return self._agg_memory("min")
 
-    def _agg_memory(self, how: str):
+    def _agg_memory(self, how: str) -> Optional[torch.Tensor]:
         if self.last_idx is not None:
+            assert isinstance(self.memory, torch.Tensor)
             if self.rotated:
                 return getattr(self.memory.float(), how)()
             return getattr(self.memory[: self.current_idx].float(), how)()
@@ -139,7 +142,7 @@ def done(self) -> bool:
 class CycleIterator:
     """Iterator for restarting a dataloader if it runs out of samples."""
 
-    def __init__(self, loader: Any, length: Optional[int] = None, state: SharedCycleIteratorState = None):
+    def __init__(self, loader: Any, length: Optional[Union[int, float]] = None, state: SharedCycleIteratorState = None):
         """
         Args:
             loader: the loader to restart for cyclic (and optionally infinite) sampling
@@ -184,6 +187,8 @@ def __next__(self) -> Any:
         Raises:
             StopIteration: if more then :attr:`length` batches have been returned
         """
+        assert isinstance(self._loader_iter, Iterator)
+
         # Note: if self.length is `inf`, then the iterator will never stop
         if self.counter >= self.__len__() or self.state.done:
             raise StopIteration
@@ -257,13 +262,13 @@ def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union
         Returns:
             length: the length of `CombinedDataset`
         """
-        if mode not in CombinedDataset.COMPUTE_FUNCS.keys():
+        if mode not in self.COMPUTE_FUNCS.keys():
             raise MisconfigurationException(f"Invalid Mode: {mode}")
 
         # extract the lengths
         all_lengths = self._get_len_recursive(datasets)
 
-        compute_func = CombinedDataset.COMPUTE_FUNCS[mode]
+        compute_func = self.COMPUTE_FUNCS[mode]
 
         if isinstance(all_lengths, (int, float)):
             length = all_lengths
@@ -272,8 +277,9 @@ def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union
 
         return length
 
-    def _get_len_recursive(self, data) -> int:
+    def _get_len_recursive(self, data: Any) -> Union[int, float, List, Dict]:
         if isinstance(data, Dataset):
+            assert isinstance(data, Sized)
             return len(data)
 
         if isinstance(data, (float, int)):
@@ -290,13 +296,13 @@ def _get_len_recursive(self, data) -> int:
         return self._get_len(data)
 
     @staticmethod
-    def _get_len(dataset) -> int:
+    def _get_len(dataset: Any) -> Union[int, float]:
         try:
             return len(dataset)
         except (TypeError, NotImplementedError):
             return float("inf")
 
-    def __len__(self) -> int:
+    def __len__(self) -> Union[int, float]:
         """Return the minimum length of the datasets."""
         return self._calc_num_data(self.datasets, self.mode)
@@ -348,8 +354,8 @@ def __init__(self, loaders: Any, mode: str = "min_size"):
         if self.mode == "max_size_cycle":
             self._wrap_loaders_max_size_cycle()
 
-        self._loaders_iter_state_dict = None
-        self._iterator = None  # assigned in __iter__
+        self._loaders_iter_state_dict: Optional[Dict] = None
+        self._iterator: Optional[Iterator] = None  # assigned in __iter__
 
     @staticmethod
     def _state_dict_fn(iterator: Optional[Iterator], has_completed: int) -> Dict:
@@ -384,7 +390,7 @@ def state_dict(self, has_completed: bool = False) -> Dict:
             has_completed=has_completed,
         )
 
-    def load_state_dict(self, state_dict) -> None:
+    def load_state_dict(self, state_dict: Dict) -> None:
         # store the samplers state.
         # They would be reloaded once the `CombinedIterator` as been created
         # and the workers are created.
@@ -482,18 +488,18 @@ def __iter__(self) -> Any:
 
         # prevent `NotImplementedError` from PyTorch:
         # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
-        def __getstate__patch__(*_):
+        def __getstate__patch__(*_: Any) -> Dict:
            return {}
 
-        _BaseDataLoaderIter.__getstate__ = __getstate__patch__
+        _BaseDataLoaderIter.__getstate__ = __getstate__patch__  # type: ignore[assignment]
         iterator = CombinedLoaderIterator(self.loaders)
         # handle fault tolerant restart logic.
         self.on_restart(iterator)
         self._iterator = iterator
         return iterator
 
     @staticmethod
-    def _calc_num_batches(loaders: Any, mode="min_size") -> Union[int, float]:
+    def _calc_num_batches(loaders: Any, mode: str = "min_size") -> Union[int, float]:
         """Compute the length (aka the number of batches) of `CombinedLoader`.
 
         Args:
@@ -509,16 +515,16 @@ def _calc_num_batches(loaders: Any, mode="min_size") -> Union[int, float]:
            return all_lengths
        return _nested_calc_num_data(all_lengths, max if mode == "max_size_cycle" else min)
 
-    def __len__(self) -> int:
+    def __len__(self) -> Union[int, float]:
        return self._calc_num_batches(self.loaders, mode=self.mode)
 
     @staticmethod
-    def _shutdown_workers_and_reset_iterator(dataloader) -> None:
+    def _shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None:
        if hasattr(dataloader, "_iterator") and isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter):
            dataloader._iterator._shutdown_workers()
            dataloader._iterator = None
 
-    def reset(self):
+    def reset(self) -> None:
        if self._iterator:
            self._iterator._loader_iters = None
        if self.loaders is not None:
@@ -535,7 +541,7 @@ def __init__(self, loaders: Any):
             loaders: the loaders to sample from. Can be all kind of collection
         """
         self.loaders = loaders
-        self._loader_iters = None
+        self._loader_iters: Any = None
 
     @property
     def loader_iters(self) -> Any:
@@ -584,7 +590,9 @@ def create_loader_iters(
         return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))
 
 
-def _nested_calc_num_data(data: Union[Mapping, Sequence], compute_func: Callable):
+def _nested_calc_num_data(
+    data: Union[Mapping, Sequence], compute_func: Callable[[List[Union[int, float]]], Union[int, float]]
+) -> Union[int, float]:
 
     if isinstance(data, (float, int)):
         return data
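A second recurring pattern in the diff above is a runtime `assert isinstance(...)` used purely to narrow a type for mypy. `torch.utils.data.Dataset` does not declare `__len__`, so `len(data)` on a value typed as `Dataset` does not type-check even though map-style datasets implement it; asserting against `collections.abc.Sized` narrows the type and documents the runtime expectation. A minimal sketch under that assumption (the helper name is illustrative, not repository code):

```python
from collections.abc import Sized

from torch.utils.data import Dataset


def dataset_length(data: Dataset) -> int:
    """Length of a map-style dataset, as mypy sees it.

    Without the assertion, mypy rejects ``len(data)`` because the
    ``Dataset`` base class does not guarantee ``__len__``.
    """
    # ``Sized`` is runtime-checkable, so this both narrows the type for
    # mypy and fails fast on datasets that genuinely have no length.
    assert isinstance(data, Sized)
    return len(data)
```

Note that `assert` statements are stripped when Python runs with `-O`, so these checks act mainly as type-narrowing hints plus a debug-time guard rather than production validation.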

src/pytorch_lightning/tuner/lr_finder.py

Lines changed: 3 additions & 1 deletion
@@ -356,7 +356,9 @@ def on_train_batch_end(
         if self.progress_bar:
             self.progress_bar.update()
 
-        current_loss = trainer.fit_loop.running_loss.last().item()
+        loss_tensor = trainer.fit_loop.running_loss.last()
+        assert loss_tensor is not None
+        current_loss = loss_tensor.item()
         current_step = trainer.global_step
 
         # Avg loss (loss with momentum) + smoothing
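The lr_finder.py change is a knock-on effect of typing `TensorRunningAccum.last()` as returning `Optional[torch.Tensor]`: calling `.item()` on the result no longer type-checks until the `None` case is excluded. A small self-contained sketch of the same narrowing pattern (the helper name and message are illustrative, not from the repository):

```python
from typing import Optional

import torch


def latest_loss(running_loss_last: Optional[torch.Tensor]) -> float:
    """Convert the most recent running-loss tensor to a Python float."""
    # ``last()`` yields ``None`` until a value has been appended, so the
    # ``None`` case must be ruled out before ``.item()`` for mypy to pass.
    assert running_loss_last is not None, "no loss values recorded yet"
    return float(running_loss_last.item())


# Example usage with a freshly recorded loss value:
print(latest_loss(torch.tensor(0.25)))  # -> 0.25
```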
