From 2ad3d290dc4298d1478fd5a5f7a4acdfb84e0cc2 Mon Sep 17 00:00:00 2001
From: Cyprien-Ricque
Date: Mon, 4 Jul 2022 11:30:11 +0200
Subject: [PATCH 1/2] fix typing in strategies/single_tpu.py

---
 pyproject.toml                                 |  1 -
 src/pytorch_lightning/strategies/single_tpu.py | 10 ++--------
 2 files changed, 2 insertions(+), 9 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 51781d4953935..f8630957befd8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -75,7 +75,6 @@ module = [
     "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.strategies.sharded_spawn",
     "pytorch_lightning.strategies.single_device",
-    "pytorch_lightning.strategies.single_tpu",
     "pytorch_lightning.strategies.tpu_spawn",
     "pytorch_lightning.strategies.strategy",
     "pytorch_lightning.profilers.advanced",
diff --git a/src/pytorch_lightning/strategies/single_tpu.py b/src/pytorch_lightning/strategies/single_tpu.py
index f4d3234e9e695..ba9b654b91796 100644
--- a/src/pytorch_lightning/strategies/single_tpu.py
+++ b/src/pytorch_lightning/strategies/single_tpu.py
@@ -55,13 +55,10 @@ def is_distributed(self) -> bool:
         return False
 
     def setup(self, trainer: "pl.Trainer") -> None:
+        assert self.model, "self.model must be set before find_shared_parameters(self.model)"
         shared_params = find_shared_parameters(self.model)
         self.model_to_device()
-        if is_overridden("on_post_move_to_device", self.lightning_module):
-            self.model.on_post_move_to_device()
-        else:
-            set_shared_parameters(self.model, shared_params)
-
+        set_shared_parameters(self.model, shared_params)
         super().setup(trainer)
 
         if self.debug:
@@ -70,9 +67,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.tpu_local_core_rank = xm.get_local_ordinal()
         self.tpu_global_core_rank = xm.get_ordinal()
 
-    def model_to_device(self) -> None:
-        self.model.to(self.root_device)
-
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register(

From 36e21c97f7e1309e47449d00a5cde4e47cc4691e Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Tue, 5 Jul 2022 08:19:24 +0200
Subject: [PATCH 2/2] remove unused import

---
 src/pytorch_lightning/strategies/single_tpu.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/pytorch_lightning/strategies/single_tpu.py b/src/pytorch_lightning/strategies/single_tpu.py
index ba9b654b91796..e65078efc67ee 100644
--- a/src/pytorch_lightning/strategies/single_tpu.py
+++ b/src/pytorch_lightning/strategies/single_tpu.py
@@ -19,7 +19,6 @@
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
 from pytorch_lightning.utilities import _TPU_AVAILABLE, find_shared_parameters, set_shared_parameters
-from pytorch_lightning.utilities.model_helpers import is_overridden
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
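
For context on why the added assert fixes the typing: removing `pytorch_lightning.strategies.single_tpu` from the mypy ignore list in `pyproject.toml` means the module must now type-check, and `assert self.model` is what lets `find_shared_parameters(self.model)` pass, since the assert narrows `self.model` from an optional type to a concrete module at that point. Below is a minimal sketch of that narrowing pattern in plain Python (not Lightning code); `find_shared_parameters` and `StrategySketch` here are stand-ins with assumed signatures, used only to illustrate the idea:

# Minimal sketch of the Optional-narrowing that the added `assert self.model` relies on.
from typing import List, Optional

import torch.nn as nn


def find_shared_parameters(module: nn.Module) -> List[str]:
    # Stand-in for pytorch_lightning.utilities.find_shared_parameters; only the
    # non-Optional `nn.Module` parameter type matters for this illustration.
    return [name for name, _ in module.named_parameters()]


class StrategySketch:
    def __init__(self, model: Optional[nn.Module] = None) -> None:
        self.model = model  # mypy infers Optional[nn.Module] here

    def setup(self) -> None:
        # Without the assert, mypy rejects the call below because an
        # Optional[nn.Module] is passed where an nn.Module is expected.
        assert self.model, "self.model must be set before setup()"
        shared_params = find_shared_parameters(self.model)  # narrowed to nn.Module
        print(shared_params)


if __name__ == "__main__":
    StrategySketch(model=nn.Linear(2, 2)).setup()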