diff --git a/dockers/tpu-tests/tpu_test_cases.jsonnet b/dockers/tpu-tests/tpu_test_cases.jsonnet
index bf2aab9317fcf..e2fdeac6a0f02 100644
--- a/dockers/tpu-tests/tpu_test_cases.jsonnet
+++ b/dockers/tpu-tests/tpu_test_cases.jsonnet
@@ -42,7 +42,8 @@ local tputests = base.BaseTest {
           profilers/test_xla_profiler.py \
           accelerators/test_tpu.py \
           models/test_tpu.py \
-          plugins/environments/test_xla_environment.py
+          plugins/environments/test_xla_environment.py \
+          utilities/test_xla_device_utils.py
       test_exit_code=$?
       echo "\n||| END PYTEST LOGS |||\n"
       coverage xml
diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 340b6ae15e2ce..5015582ded220 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -272,6 +272,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed deprecated `get_progress_bar_dict` property from `LightningModule` ([#12839](https://github.com/PyTorchLightning/pytorch-lightning/pull/12839))
 
+
 - Removed sanity check for multi-optimizer support with habana backends ([#13217](https://github.com/PyTorchLightning/pytorch-lightning/pull/13217))
 
 
@@ -302,6 +303,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed deprecated `LightningModule.on_post_move_to_device` ([#13548](https://github.com/Lightning-AI/lightning/pull/13548))
 
+- Removed `TPUSpawnStrategy.{tpu_local_core_rank,tpu_global_core_rank}` attributes in favor of `TPUSpawnStrategy.{local_rank,global_rank}` ([#11163](https://github.com/PyTorchLightning/pytorch-lightning/pull/11163))
+
+
+- Removed `SingleTPUStrategy.{tpu_local_core_rank,tpu_global_core_rank}` attributes in favor of `SingleTPUStrategy.{local_rank,global_rank}`([#11163](https://github.com/PyTorchLightning/pytorch-lightning/pull/11163))
+
+
+
 ### Fixed
 
 
diff --git a/src/pytorch_lightning/strategies/single_tpu.py b/src/pytorch_lightning/strategies/single_tpu.py
index e65078efc67ee..e24f3a6732570 100644
--- a/src/pytorch_lightning/strategies/single_tpu.py
+++ b/src/pytorch_lightning/strategies/single_tpu.py
@@ -44,10 +44,7 @@ def __init__(
             checkpoint_io=checkpoint_io,
             precision_plugin=precision_plugin,
         )
-
         self.debug = debug
-        self.tpu_local_core_rank = 0
-        self.tpu_global_core_rank = 0
 
     @property
     def is_distributed(self) -> bool:
@@ -63,9 +60,6 @@ def setup(self, trainer: "pl.Trainer") -> None:
         if self.debug:
             os.environ["PT_XLA_DEBUG"] = str(1)
 
-        self.tpu_local_core_rank = xm.get_local_ordinal()
-        self.tpu_global_core_rank = xm.get_ordinal()
-
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register(
diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py
index 8f818a34bfcc7..178fe638cc0a3 100644
--- a/src/pytorch_lightning/strategies/tpu_spawn.py
+++ b/src/pytorch_lightning/strategies/tpu_spawn.py
@@ -72,8 +72,6 @@ def __init__(
             precision_plugin=precision_plugin,
         )
         self.debug = debug
-        self.tpu_local_core_rank = 0
-        self.tpu_global_core_rank = 0
         self.start_method = "fork"
 
     @property
@@ -152,12 +150,6 @@ def process_dataloader(self, dataloader: DataLoader) -> MpDeviceLoader:
     def configure_ddp(self) -> None:
         pass
 
-    def init_dist_connection(self, global_rank: int, world_size: int) -> None:
-        pass
-
-    def set_world_ranks(self, process_idx: int = 0) -> None:
-        pass
-
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
 
@@ -203,9 +195,7 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
 
     def _worker_setup(self, process_idx: int):
         reset_seed()
-        self._local_rank = xm.get_local_ordinal()
-        self.tpu_local_core_rank = xm.get_local_ordinal()
-        self.tpu_global_core_rank = xm.get_ordinal()
+        self.set_world_ranks(process_idx)
         rank_zero_only.rank = self.global_rank
 
     def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
@@ -237,7 +227,7 @@ def _pod_progress_bar_force_stdout(self) -> None:
         # from different vms to the main worker doesn't work well with tqdm
         # Ref: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_dist.py#L140
         # The print statement seems to force tqdm to flush stdout.
-        if self.tpu_global_core_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
+        if self.global_rank == 0 and int(os.getenv(xenv.TPUVM_MODE, 0)) == 1:
             print()
 
     def save_checkpoint(
@@ -276,6 +266,10 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
             tensor = tensor.unsqueeze(0)
         return xm.all_gather(tensor)
 
+    def teardown(self) -> None:
+        super().teardown()
+        os.environ.pop("PT_XLA_DEBUG", None)
+
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register(
@@ -287,7 +281,3 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
             cls,
             description=f"{cls.__class__.__name__}",
         )
-
-    def teardown(self) -> None:
-        super().teardown()
-        os.environ.pop("PT_XLA_DEBUG", None)
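For code that still reads the removed attributes, the CHANGELOG entries above point to the inherited `local_rank` / `global_rank` properties as the replacement. The snippet below is a minimal, hypothetical migration sketch, not part of this diff: it assumes a TPU-enabled environment and that `TPUSpawnStrategy` can be constructed with its default arguments.

# Hypothetical migration sketch; names outside the CHANGELOG entries are assumptions.
from pytorch_lightning.strategies import TPUSpawnStrategy

strategy = TPUSpawnStrategy()

# Before this change:
#   strategy.tpu_local_core_rank, strategy.tpu_global_core_rank
# After this change, per the CHANGELOG entries in this diff:
local_rank = strategy.local_rank
global_rank = strategy.global_rank
print(f"local_rank={local_rank}, global_rank={global_rank}")

The same substitution applies to `SingleTPUStrategy`, whose rank attributes are removed by the single_tpu.py hunks above.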