Skip to content

Commit 3643954

Browse files
committed
[Fix] TPU Training Type Plugin (Lightning-AI#6816)
1 parent 215a9c9 commit 3643954

File tree

5 files changed

+55
-129
lines changed

5 files changed

+55
-129
lines changed
Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
1-
import os
2-
from typing import Optional, Union
3-
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
414
import torch
515

6-
from pytorch_lightning.core.lightning import LightningModule
716
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
8-
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
9-
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
17+
from pytorch_lightning.utilities import _TPU_AVAILABLE
1018
from pytorch_lightning.utilities.apply_func import move_data_to_device
1119

1220
if _TPU_AVAILABLE:
@@ -15,21 +23,26 @@
1523

1624
class SingleTPUPlugin(SingleDevicePlugin):
1725

18-
def __init__(self, device: Union[torch.device, int]):
19-
if isinstance(device, int):
20-
device = xm.xla_device(device)
26+
def __init__(self, device: int):
27+
28+
device = xm.xla_device(device)
2129
super().__init__(device)
2230

2331
self.tpu_local_core_rank = 0
2432
self.tpu_global_core_rank = 0
2533

34+
@property
2635
def on_tpu(self) -> bool:
2736
return True
2837

2938
def connect(self, model: torch.nn.Module) -> torch.nn.Module:
3039
self._model = model
3140
self.model_to_device()
3241
return self._model
42+
43+
@property
44+
def is_distributed(self) -> bool:
45+
return False
3346

3447
def model_to_device(self) -> None:
3548
self._model.to(self.root_device)
@@ -41,29 +54,10 @@ def pre_dispatch(self) -> None:
4154
self.tpu_local_core_rank = xm.get_local_ordinal()
4255
self.tpu_global_core_rank = xm.get_ordinal()
4356

44-
def post_dispatch(self) -> None:
45-
model = self.lightning_module
46-
47-
if on_colab_kaggle():
48-
rank_zero_warn("cleaning up... please do not interrupt")
49-
self.save_spawn_weights(model)
50-
51-
def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
52-
"""
53-
Dump a temporary checkpoint after ddp ends to get weights out of the process
54-
"""
55-
path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
56-
model.trainer.save_checkpoint(path)
57-
return path
58-
5957
def on_save(self, checkpoint: dict) -> dict:
6058
"""
6159
Move XLA tensors to CPU before saving
6260
Recommended on XLA Guide:
6361
https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
6462
"""
6563
return move_data_to_device(checkpoint, torch.device("cpu"))
66-
67-
@property
68-
def is_distributed(self):
69-
return False

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 7 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
import torch
88
import torch.multiprocessing as mp
99

10-
from pytorch_lightning.core.lightning import LightningModule
1110
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
12-
from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
1311
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
1412
from pytorch_lightning.utilities.apply_func import apply_to_collection
1513
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
@@ -18,28 +16,20 @@
1816

1917
if _TPU_AVAILABLE:
2018
import torch_xla.core.xla_model as xm
21-
import torch_xla.distributed.parallel_loader as xla_pl
2219
import torch_xla.distributed.xla_multiprocessing as xmp
2320
from torch_xla.core.xla_model import rendezvous
24-
from torch_xla.distributed.parallel_loader import ParallelLoader
21+
from torch_xla.distributed.parallel_loader import MpDeviceLoader
2522
else:
26-
xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5
23+
xm, xmp, MpDeviceLoader, rendezvous = [None] * 4
2724

2825
if _OMEGACONF_AVAILABLE:
2926
from omegaconf import DictConfig, ListConfig, OmegaConf
3027

3128

3229
class TPUSpawnPlugin(DDPSpawnPlugin):
3330

34-
def __init__(
35-
self,
36-
parallel_devices: Optional[List[torch.device]] = None,
37-
num_nodes: int = 1,
38-
**kwargs: Dict[str, Any]
39-
) -> None:
40-
super().__init__(
41-
parallel_devices, num_nodes=num_nodes, cluster_environment=None, sync_batchnorm=False, **kwargs
42-
)
31+
def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
32+
super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False)
4333
self.tpu_local_core_rank = 0
4434
self.start_method = None
4535

@@ -61,10 +51,9 @@ def distributed_sampler_kwargs(self) -> dict:
6151
def is_distributed(self):
6252
return self.world_size != 1
6353

64-
def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> ParallelLoader:
54+
def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader:
6555
device = xm.xla_device()
66-
dataloader = xla_pl.ParallelLoader(dataloader, [device])
67-
dataloader = dataloader.per_device_loader(device)
56+
dataloader = MpDeviceLoader(dataloader, device)
6857
return dataloader
6958

7059
def configure_ddp(self) -> None:
@@ -104,7 +93,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
10493

10594
results = trainer.train_or_test_or_predict()
10695

107-
self.__save_end_of_training_weights(self.lightning_module)
10896
self.transfer_distrib_spawn_state_on_fit_end(results)
10997

11098
# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
@@ -114,12 +102,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
114102
if self.global_rank == 0:
115103
time.sleep(2)
116104

117-
def __save_end_of_training_weights(self, model: LightningModule) -> None:
118-
# when training ends on these platforms dump weights to get out of the main process
119-
if on_colab_kaggle():
120-
rank_zero_warn("cleaning up... please do not interrupt")
121-
self.save_spawn_weights(model)
122-
123105
def model_to_device(self) -> None:
124106
self._model.to(xm.xla_device())
125107

@@ -159,37 +141,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
159141
obj = torch.load(buffer)
160142
return obj
161143

162-
def load_spawn_weights(self, original_model: LightningModule) -> LightningModule:
163-
"""
164-
Load the temp weights saved in the process
165-
To recover the trained model from the ddp process we load the saved weights
166-
"""
167-
168-
loaded_model = original_model
169-
170-
if self.is_global_zero:
171-
# load weights saved in ddp
172-
path = os.path.join(original_model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
173-
loaded_model = original_model.__class__.load_from_checkpoint(path)
174-
175-
# copy loaded weights to old model
176-
original_model.load_state_dict(loaded_model.state_dict())
177-
178-
# remove ddp weights
179-
os.remove(path)
180-
181-
return loaded_model
182-
183-
def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
184-
"""
185-
Dump a temporary checkpoint after ddp ends to get weights out of the process
186-
"""
187-
if model.trainer.is_global_zero:
188-
path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
189-
model.trainer.save_checkpoint(path)
190-
return path
191-
192-
def reduce_decision(self, decision: bool) -> bool:
144+
def reduce_boolean_decision(self, decision: bool) -> bool:
193145
decision = torch.tensor(int(decision), device=self.device)
194146
decision = self.reduce(decision, "sum")
195147
decision = bool(decision == self.world_size)
@@ -213,40 +165,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
213165

214166
return output
215167

216-
def post_dispatch(self) -> None:
217-
# TODO: Check if trainer references can be resolved otherwise
218-
model = self.lightning_module
219-
220-
# restore main state with best weights
221-
best_path = self.mp_queue.get()
222-
last_path = self.mp_queue.get()
223-
self._results = self.mp_queue.get()
224-
225-
# transfer back the best path to the trainer
226-
if self.lightning_module.trainer.checkpoint_callback is not None:
227-
self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
228-
# todo, pass also best score
229-
230-
# load last weights
231-
if last_path and not self.lightning_module.trainer.testing:
232-
ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
233-
model.load_state_dict(ckpt)
234-
235-
self._model = model
236-
237-
# when training completes, load the weights back in main process
238-
self.__load_weights_on_main_process()
239-
240-
def __load_weights_on_main_process(self) -> None:
241-
model = self.lightning_module
242-
243-
# load weights if not interrupted
244-
# TODO: check for trainer reference
245-
if on_colab_kaggle() and not model.trainer.testing:
246-
self.load_spawn_weights(model)
247-
248-
self._model = model
249-
250168
def _close_logger(self, trainer) -> None:
251169
if trainer.logger is not None:
252170
trainer.logger.finalize("success")

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def use_dp(self) -> bool:
251251
def use_ddp(self) -> bool:
252252
return self._distrib_type in (
253253
DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED,
254-
DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED
254+
DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED, DistributedType.TPU_SPAWN
255255
)
256256

257257
@property
@@ -291,7 +291,8 @@ def parallel_devices(self) -> Union[List[torch.device], int]:
291291
elif self.on_tpu:
292292
# explicitly don't make a tpu device here!
293293
# https://github.com/PyTorchLightning/pytorch-lightning/issues/3169
294-
devices = [i for i in self.parallel_device_ids]
294+
if isinstance(self.tpu_cores, int):
295+
devices = list(range(self.tpu_cores))
295296
else:
296297
devices = [torch.device("cpu")] * self.num_processes
297298
return devices
@@ -369,6 +370,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
369370
use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
370371
use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
371372
use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
373+
use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
372374
use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
373375
use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
374376
use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
@@ -379,7 +381,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
379381
if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
380382
use_torchelastic_ddp = False
381383

382-
if self.on_tpu:
384+
if use_tpu_spawn:
383385
ddp_plugin_cls = TPUSpawnPlugin
384386
elif use_ddp_sharded:
385387
ddp_plugin_cls = DDPShardedPlugin
@@ -402,11 +404,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
402404
plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
403405
elif self.use_horovod:
404406
plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
405-
elif self.on_tpu:
406-
if isinstance(self.tpu_cores, list):
407-
plugin = SingleTPUPlugin(self.tpu_id)
408-
else:
409-
plugin = TPUSpawnPlugin(parallel_devices=list(range(self.tpu_cores)))
407+
elif self.on_tpu and isinstance(self.tpu_cores, list):
408+
plugin = SingleTPUPlugin(self.tpu_id)
410409
else:
411410
single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
412411
plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
@@ -507,6 +506,8 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
507506
# special case with TPUs
508507
elif self.distributed_backend == 'tpu' or self.tpu_cores is not None:
509508
self._device_type = DeviceType.TPU
509+
if isinstance(self.tpu_cores, int):
510+
self._distrib_type = DistributedType.TPU_SPAWN
510511
elif self.distributed_backend and self._distrib_type is None:
511512
self._distrib_type = DistributedType(self.distributed_backend)
512513

@@ -515,9 +516,9 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
515516
if self.num_gpus > 0 and not _on_cpu:
516517
self._device_type = DeviceType.GPU
517518

518-
_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
519+
_gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
519520
# DP and DDP2 cannot run without GPU
520-
if self.num_gpus == 0 and self._distrib_type in _distrib_types and not _on_cpu:
521+
if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _on_cpu:
521522
rank_zero_warn(
522523
'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
523524
)

pytorch_lightning/utilities/enums.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,23 @@ class DistributedType(LightningEnum):
5858
>>> DistributedType.DDP2 in ('ddp2', )
5959
True
6060
"""
61+
62+
@staticmethod
63+
def interactive_compatible_types() -> List['DistributedType']:
64+
"""Returns a list containing interactive compatible DistributedTypes"""
65+
return [
66+
DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN, DistributedType.TPU_SPAWN
67+
]
68+
69+
def is_interactive_compatible(self) -> bool:
70+
"""Returns whether self is interactive compatible"""
71+
return self in DistributedType.interactive_compatible_types()
72+
6173
DP = 'dp'
6274
DDP = 'ddp'
6375
DDP2 = 'ddp2'
6476
DDP_SPAWN = 'ddp_spawn'
77+
TPU_SPAWN = 'tpu_spawn'
6578
DEEPSPEED = 'deepspeed'
6679
HOROVOD = 'horovod'
6780
DDP_SHARDED = 'ddp_sharded'

tests/models/test_tpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ def test_tpu_grad_norm(tmpdir):
210210
progress_bar_refresh_rate=0,
211211
max_epochs=4,
212212
tpu_cores=1,
213-
limit_train_batches=4,
214-
limit_val_batches=4,
213+
limit_train_batches=0.4,
214+
limit_val_batches=0.4,
215215
gradient_clip_val=0.5,
216216
)
217217

0 commit comments

Comments
 (0)