Skip to content

Commit 2a5d05b

Browse files
authored
Fix tpu spawn plugin test (#11131)
1 parent 3cc69f9 commit 2a5d05b

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

dockers/tpu-tests/tpu_test_cases.jsonnet

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ local tputests = base.BaseTest {
3333
echo $KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS
3434
export XRT_TPU_CONFIG="tpu_worker;0;${KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS:7}"
3535
coverage run --source=pytorch_lightning -m pytest -v --capture=no \
36+
tests/plugins/test_tpu_spawn.py \
3637
tests/profiler/test_xla_profiler.py \
3738
pytorch_lightning/utilities/xla_device.py \
3839
tests/accelerators/test_tpu.py \

tests/plugins/test_tpu_spawn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,20 +85,20 @@ def test_error_process_iterable_dataloader(_):
8585

8686
class BoringModelTPU(BoringModel):
8787
def on_train_start(self) -> None:
88-
assert self.device == torch.device("xla")
88+
assert self.device == torch.device("xla", index=1)
8989
assert os.environ.get("PT_XLA_DEBUG") == "1"
9090

9191

9292
@RunIf(tpu=True)
9393
@pl_multi_process_test
9494
def test_model_tpu_one_core():
9595
"""Tests if device/debug flag is set correctely when training and after teardown for TPUSpawnPlugin."""
96-
trainer = Trainer(tpu_cores=1, fast_dev_run=True, plugin=TPUSpawnPlugin(debug=True))
96+
trainer = Trainer(tpu_cores=1, fast_dev_run=True, strategy=TPUSpawnPlugin(debug=True))
9797
# assert training type plugin attributes for device setting
9898
assert isinstance(trainer.training_type_plugin, TPUSpawnPlugin)
9999
assert not trainer.training_type_plugin.on_gpu
100100
assert trainer.training_type_plugin.on_tpu
101-
assert trainer.training_type_plugin.root_device == torch.device("xla")
101+
assert trainer.training_type_plugin.root_device == torch.device("xla", index=1)
102102
model = BoringModelTPU()
103103
trainer.fit(model)
104104
assert "PT_XLA_DEBUG" not in os.environ

0 commit comments

Comments
 (0)