 from unittest import mock
 
 import pytest
+import torch
 from torch.utils.data import DataLoader
 
 import tests.helpers.pipelines as tpipes
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import TPUAccelerator
 from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.plugins import TPUSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _TPU_AVAILABLE
@@ -397,3 +399,26 @@ def test_if_test_works_with_checkpoint_false(tmpdir):
     trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
     trainer.fit(model)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
+
+
+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_tpu_sync_dist():
+    """Test TPU spawn sync_dist operation."""
+
+    def test_sync_dist(rank):
+        tensor = torch.tensor([1.0])
+        training_type_plugin = TPUSpawnPlugin()
+
+        res = Result()
+        res.log(
+            "test_tensor",
+            tensor,
+            sync_fn=training_type_plugin.reduce,
+            sync_dist=True,
+            sync_dist_op=torch.distributed.ReduceOp.SUM
+        )
+
+        assert res["test_tensor"].item() == 8, "Result-Log does not work properly with TPU Spawn and Tensors"
+
+    xmp.spawn(test_sync_dist, nprocs=8, start_method='fork')
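
For readers of the diff: the new test forks eight TPU processes with xmp.spawn; each process logs a one-element tensor of 1.0 into a Result with sync_dist=True, using TPUSpawnPlugin.reduce as the sync function and a SUM reduce op, so after the cross-process reduction every rank should see 1.0 * 8 = 8.0, which is exactly what the assertion checks. The sketch below is a rough illustration of the user-facing counterpart of that plugin-level reduce, assuming the 1.2-era LightningModule.log signature with sync_dist/sync_dist_op; the module, metric name, and dataloader are hypothetical and not part of this commit.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule


class SyncedLoggingModel(LightningModule):
    """Hypothetical module: logs a per-process constant and lets Lightning reduce it across TPU cores."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).sum()
        # With sync_dist=True, the logged value is reduced across processes by the
        # training-type plugin's `reduce` (TPUSpawnPlugin.reduce on TPU). With a SUM
        # op and tpu_cores=8, a constant 1.0 is reported as 8.0, the same arithmetic
        # the test above asserts.
        self.log("ones", torch.tensor(1.0), sync_dist=True, sync_dist_op="sum")
        return loss

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

To exercise the same path end to end, one would pass tpu_cores=8 to the Trainer, mirroring the hardware assumption (nprocs=8) baked into the test's assertion.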