From f06ddc0db86b80420346f5598c886a70bf1a3e54 Mon Sep 17 00:00:00 2001
From: Cyprien Ricque <48893621+Cyprien-Ricque@users.noreply.github.com>
Date: Tue, 5 Jul 2022 13:39:00 +0200
Subject: [PATCH] Revert "fix mypy typing errors in pytorch_lightning/strategies/single_tpu.py (#13534)"

This reverts commit 61473c2290a625b83febace61c7071f8584440ef.

---
 pyproject.toml                                 |  1 +
 src/pytorch_lightning/strategies/single_tpu.py | 11 +++++++++--
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index c08d2c99bf3f5..5667f0824cce8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -72,6 +72,7 @@ module = [
     "pytorch_lightning.strategies.parallel",
     "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.strategies.sharded_spawn",
+    "pytorch_lightning.strategies.single_tpu",
     "pytorch_lightning.strategies.tpu_spawn",
     "pytorch_lightning.strategies.strategy",
     "pytorch_lightning.profilers.advanced",
diff --git a/src/pytorch_lightning/strategies/single_tpu.py b/src/pytorch_lightning/strategies/single_tpu.py
index e65078efc67ee..f4d3234e9e695 100644
--- a/src/pytorch_lightning/strategies/single_tpu.py
+++ b/src/pytorch_lightning/strategies/single_tpu.py
@@ -19,6 +19,7 @@
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
+from pytorch_lightning.utilities.model_helpers import is_overridden
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
@@ -54,10 +55,13 @@ def is_distributed(self) -> bool:
         return False
 
     def setup(self, trainer: "pl.Trainer") -> None:
-        assert self.model, "self.model must be set before find_shared_parameters(self.model)"
         shared_params = find_shared_parameters(self.model)
         self.model_to_device()
-        set_shared_parameters(self.model, shared_params)
+        if is_overridden("on_post_move_to_device", self.lightning_module):
+            self.model.on_post_move_to_device()
+        else:
+            set_shared_parameters(self.model, shared_params)
+
         super().setup(trainer)
 
         if self.debug:
@@ -66,6 +70,9 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.tpu_local_core_rank = xm.get_local_ordinal()
         self.tpu_global_core_rank = xm.get_ordinal()
 
+    def model_to_device(self) -> None:
+        self.model.to(self.root_device)
+
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register(