diff --git a/pyproject.toml b/pyproject.toml index c266e0684e974..2471be131c41a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,7 +120,6 @@ module = [ "pytorch_lightning.trainer.connectors.callback_connector", "pytorch_lightning.trainer.connectors.checkpoint_connector", "pytorch_lightning.trainer.connectors.data_connector", - "pytorch_lightning.trainer.connectors.logger_connector.result", "pytorch_lightning.trainer.data_loading", "pytorch_lightning.trainer.optimizers", "pytorch_lightning.trainer.supporters", diff --git a/pytorch_lightning/core/mixins/device_dtype_mixin.py b/pytorch_lightning/core/mixins/device_dtype_mixin.py index e8b122989cd9c..d902958b9bc40 100644 --- a/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -17,6 +17,16 @@ import torch from torch.nn import Module +try: + from typing_extensions import Self +except ImportError: + # workaround for Python 3.6 and 3.7. + # see https://www.python.org/dev/peps/pep-0673/ + from typing import TypeVar + + Self = TypeVar("TDeviceDtypeModuleMixin", bound="DeviceDtypeModuleMixin") + + import pytorch_lightning as pl @@ -47,7 +57,7 @@ def device(self) -> Union[str, torch.device]: return device - def to(self, *args: Any, **kwargs: Any) -> "DeviceDtypeModuleMixin": + def to(self, *args: Any, **kwargs: Any) -> Self: """Moves and/or casts the parameters and buffers. This can be called as @@ -110,7 +120,7 @@ def to(self, *args: Any, **kwargs: Any) -> "DeviceDtypeModuleMixin": self.__update_properties(device=out[0], dtype=out[1]) return super().to(*args, **kwargs) - def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "DeviceDtypeModuleMixin": + def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self: """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized. @@ -127,7 +137,7 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "DeviceDtyp self.__update_properties(device=device) return super().cuda(device=device) - def cpu(self) -> "DeviceDtypeModuleMixin": + def cpu(self) -> Self: """Moves all model parameters and buffers to the CPU. Returns: @@ -136,7 +146,7 @@ def cpu(self) -> "DeviceDtypeModuleMixin": self.__update_properties(device=torch.device("cpu")) return super().cpu() - def type(self, dst_type: Union[str, torch.dtype]) -> "DeviceDtypeModuleMixin": + def type(self, dst_type: Union[str, torch.dtype]) -> Self: """Casts all parameters and buffers to :attr:`dst_type`. Arguments: @@ -148,7 +158,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> "DeviceDtypeModuleMixin": self.__update_properties(dtype=dst_type) return super().type(dst_type=dst_type) - def float(self) -> "DeviceDtypeModuleMixin": + def float(self) -> Self: """Casts all floating point parameters and buffers to ``float`` datatype. Returns: @@ -157,7 +167,7 @@ def float(self) -> "DeviceDtypeModuleMixin": self.__update_properties(dtype=torch.float) return super().float() - def double(self) -> "DeviceDtypeModuleMixin": + def double(self) -> Self: """Casts all floating point parameters and buffers to ``double`` datatype. Returns: @@ -166,7 +176,7 @@ def double(self) -> "DeviceDtypeModuleMixin": self.__update_properties(dtype=torch.double) return super().double() - def half(self) -> "DeviceDtypeModuleMixin": + def half(self) -> Self: """Casts all floating point parameters and buffers to ``half`` datatype. Returns: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index e10360a5fb564..1c27b75854d96 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -211,8 +211,10 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: self.add_state("value", torch.tensor(0.0), dist_reduce_fx=torch.sum) if self.meta.is_mean_reduction: self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum) + # this is defined here only because upstream is missing the type annotation + self._forward_cache: Optional[Any] = None - def update(self, value: _IN_METRIC, batch_size: int) -> None: + def update(self, value: _IN_METRIC, batch_size: int) -> None: # type: ignore[override] if self.is_tensor: if not torch.is_floating_point(value): dtype = torch.get_default_dtype() @@ -225,16 +227,17 @@ def update(self, value: _IN_METRIC, batch_size: int) -> None: if self.meta.on_step: self._forward_cache = self.meta.sync(value.clone()) # `clone` because `sync` is in-place - - # performance: no need to accumulate on values only logged on_step - if not self.meta.on_epoch: - self.value = self._forward_cache - return + # performance: no need to accumulate on values only logged on_step + if not self.meta.on_epoch: + self.value = self._forward_cache + return # perform accumulation with reduction if self.meta.is_mean_reduction: self.value += value.mean() * batch_size - self.cumulated_batch_size += batch_size + # `Metric.add_state` does not work well with mypy, mypy doesn't know this is a `Tensor` + # we could add an assertion, but this is a hot code path + self.cumulated_batch_size += batch_size # type: ignore[operator] elif self.meta.is_max_reduction or self.meta.is_min_reduction: self.value = self.meta.reduce_fx(self.value, value.mean()) elif self.meta.is_sum_reduction: diff --git a/requirements.txt b/requirements.txt index 34879d9290acb..94b7151d73641 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,4 @@ tensorboard>=2.2.0 torchmetrics>=0.4.1 pyDeprecate==0.3.1 packaging>=17.0 -typing-extensions +typing-extensions>=4.0.0