
Commit 8245540

kaushikb11 authored and SeanNaren committed
Fix TPU Spawn gather (#6896)
(cherry picked from commit 5552503)
1 parent f895e9f commit 8245540

File tree

4 files changed: +70 −39 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -228,6 +228,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))


+- Fixed TPU Spawn all gather ([#6896](https://github.com/PyTorchLightning/pytorch-lightning/pull/6896))
+
+
 - Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))



pytorch_lightning/accelerators/tpu.py

Lines changed: 27 additions & 19 deletions
@@ -1,6 +1,18 @@
-from typing import Any, Callable, Optional, Union
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, TYPE_CHECKING, Union
 
-import torch
 from torch.optim import Optimizer
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
@@ -16,10 +28,19 @@
 
 xla_clip_grad_norm_ = clip_grad_norm_
 
+if TYPE_CHECKING:
+    from pytorch_lightning.core.lightning import LightningModule
+    from pytorch_lightning.trainer.trainer import Trainer
+
 
 class TPUAccelerator(Accelerator):
 
-    def setup(self, trainer, model):
+    def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
+        """
+        Raises:
+            MisconfigurationException:
+                If AMP is used with TPU, or if TPUs are not using a single TPU core or TPU spawn training.
+        """
         if isinstance(self.precision_plugin, MixedPrecisionPlugin):
             raise MisconfigurationException(
                 "amp + tpu is not supported. "
@@ -30,24 +51,11 @@ def setup(self, trainer, model):
             raise MisconfigurationException("TPUs only support a single tpu core or tpu spawn training.")
         return super().setup(trainer, model)
 
-    def run_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs):
+    def run_optimizer_step(
+        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
+    ) -> None:
         xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})
 
-    def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False):
-        """
-        Function to gather a tensor from several distributed processes
-        Args:
-            tensor: tensor of shape (batch, ...)
-            group: not available with TPUs
-            sync_grads: not available with TPUs
-        Return:
-            A tensor of shape (world_size, batch, ...)
-        """
-        # todo: Add support for backward with all_gather
-        if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
-            return xm.all_gather(tensor).view(-1, *tensor.shape)
-        return tensor
-
     def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):
 
         model = self.lightning_module
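
The new `TYPE_CHECKING` block above only matters for static analysis: `Trainer` and `LightningModule` are imported solely so the annotations on `setup` resolve for type checkers, while at runtime the import is skipped and no circular dependency between the accelerator and trainer modules is created. A minimal, standalone sketch of the same pattern (the `configure` function is a hypothetical stand-in, not part of this diff):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright, IDEs);
    # never executed at runtime, so importing Trainer here cannot
    # trigger a circular import.
    from pytorch_lightning.trainer.trainer import Trainer


def configure(trainer: 'Trainer') -> None:  # string annotation = forward reference
    """Hypothetical helper showing how the guarded import is used."""
    print(type(trainer).__name__)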

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 15 additions & 3 deletions
@@ -192,14 +192,14 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         return obj
 
     def reduce_boolean_decision(self, decision: bool) -> bool:
-        decision = torch.tensor(int(decision), device=self.device)
-        decision = self.reduce(decision, "sum")
+        decision = torch.tensor(int(decision), device=self.lightning_module.device)
+        decision = self.reduce(decision, reduce_op="sum")
         decision = bool(decision == self.world_size)
         return decision
 
     def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
         if not isinstance(output, torch.Tensor):
-            output = torch.tensor(output, device=self.device)
+            output = torch.tensor(output, device=self.lightning_module.device)
 
         _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
         _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
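
The fix above reads the device from `self.lightning_module.device` instead of `self.device` and passes `reduce_op` by keyword. The agreement logic itself is unchanged: each process contributes `int(decision)`, the values are sum-reduced, and the result is `True` only when the sum equals the world size. A minimal, TPU-free sketch of that logic (`reduce_boolean_decision_sketch` is an illustrative helper, not code from this commit):

import torch


def reduce_boolean_decision_sketch(decisions):
    """Simulate the agreement logic: every process contributes 0 or 1, the
    values are sum-reduced, and the decision is True only if the reduced
    sum equals the world size, i.e. all processes agreed."""
    world_size = len(decisions)
    reduced = torch.tensor([int(d) for d in decisions]).sum()
    return bool(reduced == world_size)


assert reduce_boolean_decision_sketch([True, True, True]) is True
assert reduce_boolean_decision_sketch([True, False, True]) is False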
@@ -265,3 +265,15 @@ def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None:
         if _OMEGACONF_AVAILABLE:
             checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
         self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
+
+    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
+        """
+        Function to gather a tensor from several distributed processes
+        Args:
+            tensor: tensor of shape (batch, ...)
+            group: not available with TPUs
+            sync_grads: not available with TPUs
+        Return:
+            A tensor of shape (world_size, batch, ...)
+        """
+        return xm.all_gather(tensor.unsqueeze(0))
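
The relocated `all_gather` unsqueezes a leading dimension before calling `xm.all_gather`, so a per-process tensor of shape `(batch, ...)` comes back as `(world_size, batch, ...)`. A TPU-free sketch that imitates the shape behaviour with a plain `torch.cat` (`simulated_all_gather` is illustrative only; the real call runs on XLA devices):

import torch


def simulated_all_gather(per_process_tensors):
    """Stand-in for xm.all_gather(tensor.unsqueeze(0)): each process
    contributes a tensor of shape (batch, ...); unsqueezing adds a leading
    dim and concatenating across processes yields (world_size, batch, ...)."""
    return torch.cat([t.unsqueeze(0) for t in per_process_tensors], dim=0)


world_size, batch = 8, 4
outputs = [torch.randn(batch, 32) for _ in range(world_size)]
gathered = simulated_all_gather(outputs)
assert gathered.shape == (world_size, batch, 32)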

tests/models/test_tpu.py

Lines changed: 25 additions & 17 deletions
@@ -54,7 +54,7 @@ def val_dataloader(self):
         return DataLoader(RandomDataset(32, 2000), batch_size=32)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_tpu_cores_1(tmpdir):
     """Make sure model trains on TPU."""
@@ -73,7 +73,7 @@ def test_model_tpu_cores_1(tmpdir):
 
 
 @pytest.mark.parametrize('tpu_core', [1, 5])
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_tpu_index(tmpdir, tpu_core):
     """Make sure model trains on TPU."""
@@ -92,7 +92,7 @@ def test_model_tpu_index(tmpdir, tpu_core):
     assert torch_xla._XLAC._xla_get_default_device() == f'xla:{tpu_core}'
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_tpu_cores_8(tmpdir):
     """Make sure model trains on TPU."""
@@ -111,7 +111,7 @@ def test_model_tpu_cores_8(tmpdir):
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False, min_acc=0.05)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_16bit_tpu_cores_1(tmpdir):
     """Make sure model trains on TPU."""
@@ -132,7 +132,7 @@ def test_model_16bit_tpu_cores_1(tmpdir):
 
 
 @pytest.mark.parametrize('tpu_core', [1, 5])
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_16bit_tpu_index(tmpdir, tpu_core):
     """Make sure model trains on TPU."""
@@ -153,7 +153,7 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
     assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_16bit_tpu_cores_8(tmpdir):
     """Make sure model trains on TPU."""
@@ -173,7 +173,7 @@ def test_model_16bit_tpu_cores_8(tmpdir):
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False, min_acc=0.05)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_model_tpu_early_stop(tmpdir):
     """Test if single TPU core training works"""
@@ -200,7 +200,7 @@ def validation_step(self, *args, **kwargs):
     trainer.test(test_dataloaders=DataLoader(RandomDataset(32, 2000), batch_size=32))
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_grad_norm(tmpdir):
     """Test if grad_norm works on TPU."""
@@ -219,16 +219,24 @@ def test_tpu_grad_norm(tmpdir):
     tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_dataloaders_passed_to_fit(tmpdir):
     """Test if dataloaders passed to trainer works on TPU"""
 
     tutils.reset_seed()
     model = BoringModel()
 
-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, tpu_cores=8)
-    trainer.fit(model, train_dataloader=model.train_dataloader(), val_dataloaders=model.val_dataloader())
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        tpu_cores=8,
+    )
+    trainer.fit(
+        model,
+        train_dataloader=model.train_dataloader(),
+        val_dataloaders=model.val_dataloader(),
+    )
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
@@ -237,7 +245,7 @@ def test_dataloaders_passed_to_fit(tmpdir):
     [pytest.param(1, None), pytest.param(8, None),
      pytest.param([1], 1), pytest.param([8], 8)],
 )
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires missing TPU")
+@RunIf(tpu=True)
 def test_tpu_id_to_be_as_expected(tpu_cores, expected_tpu_id):
     """Test if trainer.tpu_id is set as expected"""
     assert Trainer(tpu_cores=tpu_cores).accelerator_connector.tpu_id == expected_tpu_id
@@ -258,13 +266,13 @@ def test_exception_when_no_tpu_found(tmpdir):
 
 
 @pytest.mark.parametrize('tpu_cores', [1, 8, [1]])
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores):
     """Test if distributed_backend is set to `tpu` when tpu_cores is not None"""
     assert Trainer(tpu_cores=tpu_cores).distributed_backend == "tpu"
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_broadcast_on_tpu():
     """ Checks if an object from the master process is broadcasted to other processes correctly"""
@@ -296,7 +304,7 @@ def test_broadcast(rank):
         pytest.param(10, None, True),
     ],
 )
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
     if error_expected:
@@ -312,7 +320,7 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
     [pytest.param('--tpu_cores=8', {'tpu_cores': 8}),
      pytest.param("--tpu_cores=1,", {'tpu_cores': '1,'})]
 )
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_cores_with_argparse(cli_args, expected):
     """Test passing tpu_cores in command line"""
@@ -327,7 +335,7 @@ def test_tpu_cores_with_argparse(cli_args, expected):
     assert Trainer.from_argparse_args(args)
 
 
-@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@RunIf(tpu=True)
 @pl_multi_process_test
 def test_tpu_reduce():
     """Test tpu spawn reduce operation """
