
Commit aace276

Authored by tchaton, lezwon and Your Name
[wip] Fix some bugs for TPU [skip ci] (#5878)
* fixed for single tpu
* fixed spawn
* fixed spawn
* update
* update
* wip
* resolve bugs
* resolve bug
* update on comment
* removed decorator
* resolve comments
* set to 4
* update
* update
* need cleaning
* update
* update
* update
* resolve flake8
* resolve bugs
* exclude broadcast
* resolve bugs
* change test
* update
* update
* skip if meet fails
* properly raise trace
* update
* add catch
* wrap test
* resolve typo
* update
* typo

Co-authored-by: Lezwon Castelino <[email protected]>
Co-authored-by: Your Name <[email protected]>
1 parent: 236009e · commit: aace276

File tree: 20 files changed (+201, -108 lines)


dockers/tpu-tests/tpu_test_cases.jsonnet

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ local tputests = base.BaseTest {
   command: utils.scriptCommand(
     |||
       cd pytorch-lightning
-      coverage run --source=pytorch_lightning -m pytest -v \
+      coverage run --source=pytorch_lightning -m pytest -v --capture=no \
          pytorch_lightning/utilities/xla_device_utils.py \
          tests/accelerators/legacy/test_tpu_backend.py \
          tests/models/test_tpu.py

pytorch_lightning/accelerators/accelerator.py

Lines changed: 3 additions & 3 deletions
@@ -76,7 +76,7 @@ def setup(self, trainer: "Trainer", model: LightningModule) -> None:
            model: the model to train
        """
        self.connect_training_type_plugin(self.training_type_plugin, model)
-        self.setup_optimizers(trainer, model)
+        self.setup_optimizers(trainer)
        self.connect_precision_plugin(self.precision_plugin)

    @property
@@ -306,7 +306,7 @@ def on_train_end(self) -> None:
        """Hook to do something at the end of the training"""
        pass

-    def setup_optimizers(self, trainer: "Trainer", model: LightningModule):
+    def setup_optimizers(self, trainer: "Trainer"):
        """creates optimizers and schedulers

        Args:
@@ -315,7 +315,7 @@ def setup_optimizers(self, trainer: "Trainer", model: LightningModule):
        """
        if trainer.testing is True:
            return
-        optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(model)
+        optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(self.lightning_module)
        self.optimizers = optimizers
        self.lr_schedulers = lr_schedulers
        self.optimizer_frequencies = optimizer_frequencies
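
For context, the hunks above drop the `model` argument because the accelerator can already reach the module through its training-type plugin. A minimal sketch of that delegation pattern, using hypothetical `Toy*` classes rather than the real Lightning ones:

# Hypothetical toy classes; only the delegation pattern mirrors the diff above.
class ToyPlugin:
    def __init__(self, model):
        self._model = model

    @property
    def lightning_module(self):
        return self._model


class ToyAccelerator:
    def __init__(self, plugin):
        self.training_type_plugin = plugin

    @property
    def lightning_module(self):
        # single source of truth: the plugin owns the module reference
        return self.training_type_plugin.lightning_module

    def setup_optimizers(self, trainer):
        # no `model` parameter: the module is resolved internally
        return trainer.init_optimizers(self.lightning_module)


class ToyTrainer:
    testing = False

    def init_optimizers(self, model):
        return ([f"optimizer({model})"], [], [])


accelerator = ToyAccelerator(ToyPlugin("my_module"))
print(accelerator.setup_optimizers(ToyTrainer()))  # (['optimizer(my_module)'], [], [])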

pytorch_lightning/accelerators/accelerator_connector.py

Lines changed: 5 additions & 2 deletions
@@ -227,7 +227,7 @@ def on_tpu(self):

    @property
    def tpu_id(self):
-        if self.on_tpu:
+        if self.on_tpu and isinstance(self.tpu_cores, list):
            return self.tpu_cores[0]

        return None
@@ -380,7 +380,10 @@ def select_training_type_plugin(self):
        elif self.use_horovod:
            plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
        elif self.on_tpu:
-            plugin = SingleTPUPlugin(self.tpu_id)
+            if isinstance(self.tpu_cores, list):
+                plugin = SingleTPUPlugin(self.tpu_id)
+            else:
+                plugin = TPUSpawnPlugin(parallel_devices=list(range(self.tpu_cores)))
        else:
            single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
            plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
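
A hedged sketch of the selection rule the hunk above introduces, extracted into a standalone helper (not the real connector): `tpu_cores` given as a list such as `[5]` pins training to that one core, while an int such as `8` means one spawned process per core.

from typing import List, Union


def pick_tpu_plugin(tpu_cores: Union[int, List[int]]) -> str:
    # List -> the user asked for one specific core, e.g. tpu_cores=[5]
    if isinstance(tpu_cores, list):
        return f"SingleTPUPlugin(device={tpu_cores[0]})"
    # Int -> spawn across that many cores, e.g. tpu_cores=8
    return f"TPUSpawnPlugin(parallel_devices={list(range(tpu_cores))})"


print(pick_tpu_plugin([5]))  # SingleTPUPlugin(device=5)
print(pick_tpu_plugin(8))    # TPUSpawnPlugin(parallel_devices=[0, 1, 2, 3, 4, 5, 6, 7])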

pytorch_lightning/accelerators/legacy/tpu_accelerator.py

Lines changed: 0 additions & 25 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 import io
 import os
-import re
 from typing import Any, Callable, Optional, Union

 import torch
@@ -31,7 +30,6 @@
    rank_zero_only,
    rank_zero_warn,
)
-from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _TPU_AVAILABLE:
@@ -307,29 +305,6 @@ def load_spawn_weights(self, original_model):

        return loaded_model

-    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
-        if self.trainer.distributed_backend not in ("ddp_spawn", "ddp_cpu", "tpu"):
-            return
-
-        # track the best model path
-        best_model_path = None
-        if self.trainer.checkpoint_callback is not None:
-            best_model_path = self.trainer.checkpoint_callback.best_model_path
-
-        if self.trainer.global_rank == 0 and mp_queue is not None:
-            rank_zero_warn('cleaning up ddp environment...')
-            # todo, pass complete checkpoint as state dictionary
-            mp_queue.put(best_model_path)
-            mp_queue.put(results)
-
-            # save the last weights
-            last_path = None
-            if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
-                last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
-                state_dict = move_data_to_device(model.state_dict(), torch.device("cpu"))
-                atomic_save(state_dict, last_path)
-            mp_queue.put(last_path)
-
    def broadcast(self, obj, src=0):
        if self.trainer.tpu_id is not None:
            # running on a single core

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 2 additions & 4 deletions
@@ -520,11 +520,9 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics):
            trainer,
        )

-        accelerator_backend = trainer.accelerator_backend
-
-        if accelerator_backend.training_type_plugin.rpc_enabled:
+        if trainer.training_type_plugin.rpc_enabled:
            # RPCPlugin manages saving all model states
-            accelerator_backend.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module)
+            trainer.training_type_plugin.rpc_save_model(self._save_model, last_filepath, trainer, pl_module)
        else:
            self._save_model(last_filepath, trainer, pl_module)
        if (
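
The callback now assumes the trainer exposes `training_type_plugin` directly instead of reaching through `accelerator_backend`. A hedged sketch of that delegating shortcut, with hypothetical `Toy*` names:

# Hypothetical Toy* classes; only the delegating property mirrors the diff above.
class ToyPlugin:
    rpc_enabled = False


class ToyAccelerator:
    training_type_plugin = ToyPlugin()


class ToyTrainer:
    accelerator_backend = ToyAccelerator()

    @property
    def training_type_plugin(self):
        # shortcut the checkpoint callback uses instead of trainer.accelerator_backend.*
        return self.accelerator_backend.training_type_plugin


print(ToyTrainer().training_type_plugin.rpc_enabled)  # False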

pytorch_lightning/core/step_result.py

Lines changed: 3 additions & 0 deletions
@@ -148,6 +148,9 @@ def log(
            value = torch.tensor(value, device=device, dtype=torch.float)
            value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)

+        if value.device.type == "xla":
+            value = value.cpu()
+
        if 'meta' not in self:
            self.__setitem__('meta', {})
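
A hedged reading of the three added lines, pulled out into a standalone helper (the function name is ours, and the stated reason is an assumption: keeping logged values off the XLA device avoids holding TPU memory for logging bookkeeping):

import torch


def to_loggable(value: torch.Tensor) -> torch.Tensor:
    # Mirrors the guard above: values produced on an XLA device are copied to CPU
    # before they are stored; CPU tensors pass through unchanged.
    if value.device.type == "xla":
        value = value.cpu()
    return value


print(to_loggable(torch.tensor(1.0)).device)  # cpu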

pytorch_lightning/plugins/precision/tpu_bfloat.py

Lines changed: 1 addition & 1 deletion
@@ -25,4 +25,4 @@ class TPUHalfPrecisionPlugin(PrecisionPlugin):

    def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
        os.environ["XLA_USE_BF16"] = str(1)
-        return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
+        return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
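
For background, `XLA_USE_BF16=1` is the PyTorch/XLA environment flag that maps float32 tensors to bfloat16 on the TPU, and it has to be set before the XLA device is first touched, which is presumably why the plugin sets it in `connect`. A minimal sketch of the flag itself (no TPU required to run this):

import os

# Same assignment the plugin makes; torch_xla reads it when the XLA device is initialised.
os.environ["XLA_USE_BF16"] = str(1)
print(os.environ["XLA_USE_BF16"])  # "1"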

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 13 additions & 4 deletions
@@ -95,13 +95,20 @@ def set_world_ranks(self, process_idx):
        self.global_rank = self.node_rank * self.num_processes + self.local_rank
        self.world_size = self.num_nodes * self.num_processes

+    @property
+    def mp_spawn_kwargs(self):
+        return {
+            "args": (self.lightning_module.trainer, self.mp_queue),
+            "nprocs": self.num_processes,
+        }
+
    def start_training(self, trainer):
-        mp.spawn(self.new_process, nprocs=self.num_processes, args=(trainer, self.mp_queue))
+        mp.spawn(self.new_process, **self.mp_spawn_kwargs)
        # reset optimizers, since main process is never used for training and thus does not have a valid optim state
        trainer.optimizers = []

    def start_testing(self, trainer):
-        mp.spawn(self.new_process, nprocs=self.num_processes, args=(trainer, self.mp_queue))
+        mp.spawn(self.new_process, **self.mp_spawn_kwargs)

    def new_process(self, process_idx, trainer, mp_queue):
        self.mp_queue = mp_queue
@@ -173,7 +180,6 @@ def pre_configure_ddp(self):
            self._ddp_kwargs["find_unused_parameters"] = True

    def configure_ddp(self):
-
        self.pre_configure_ddp()
        self._model = DistributedDataParallel(
            LightningDistributedModule(self.model),
@@ -197,6 +203,9 @@ def determine_ddp_device_ids(self):
            return None
        return [self.root_device.index]

+    def on_save(self, checkpoint: dict) -> dict:
+        return checkpoint
+
    def transfer_distrib_spawn_state_on_fit_end(self, results):
        # TODO: is there a better way than accessing callback through model -> trainer -> callback?
        checkpoint_callback = self.lightning_module.trainer.checkpoint_callback
@@ -210,7 +219,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
        # TODO: is there a better way than accessing trainer through model -> trainer?
        if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
            last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-            atomic_save(self.lightning_module.state_dict(), last_path)
+            atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)

        # todo, pass complete checkpoint as state dictionary
        self.mp_queue.put(best_model_path)
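
The hunks above add two extension points: `mp_spawn_kwargs`, so a subclass can change how worker processes are spawned, and `on_save`, so a subclass can rewrite the state dict just before `atomic_save`. A minimal sketch with toy classes (the real TPU spawn plugin is not shown; the `start_method` detail is an assumption used only for illustration):

class ToySpawnPlugin:
    num_processes = 2

    @property
    def mp_spawn_kwargs(self):
        return {"nprocs": self.num_processes}

    def on_save(self, checkpoint: dict) -> dict:
        # default: pass the checkpoint through untouched
        return checkpoint


class ToyTPUSpawnPlugin(ToySpawnPlugin):
    @property
    def mp_spawn_kwargs(self):
        # a TPU flavour could, for example, force a specific start method
        return {**super().mp_spawn_kwargs, "start_method": "fork"}

    def on_save(self, checkpoint: dict) -> dict:
        # a TPU flavour could move XLA tensors to CPU here before saving
        return dict(checkpoint)


print(ToyTPUSpawnPlugin().mp_spawn_kwargs)  # {'nprocs': 2, 'start_method': 'fork'}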

pytorch_lightning/plugins/training_type/single_tpu.py

Lines changed: 30 additions & 2 deletions
@@ -1,11 +1,13 @@
 import io
 import os
-from typing import Optional
+from typing import Optional, Union

 import torch

+from pytorch_lightning import LightningModule
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
 from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
+from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn

 if _TPU_AVAILABLE:
@@ -15,7 +17,9 @@

class SingleTPUPlugin(SingleDevicePlugin):

-    def __init__(self, device: torch.device):
+    def __init__(self, device: Union[torch.device, int]):
+        if isinstance(device, int):
+            device = xm.xla_device(device)
        super().__init__(device)

        self.tpu_local_core_rank = 0
@@ -24,6 +28,14 @@ def __init__(self, device: torch.device):
    def on_tpu(self) -> bool:
        return True

+    def connect(self, model: torch.nn.Module) -> torch.nn.Module:
+        self._model = model
+        self.model_to_device()
+        return self._model
+
+    def model_to_device(self) -> None:
+        self._model.to(self.root_device)
+
    def pre_training(self) -> None:
        if isinstance(self.device, int):
            self.device = xm.xla_device(self.device)
@@ -37,3 +49,19 @@ def post_training(self) -> None:
        if on_colab_kaggle():
            rank_zero_warn("cleaning up... please do not interrupt")
            self.save_spawn_weights(model)
+
+    def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
+        """
+        Dump a temporary checkpoint after ddp ends to get weights out of the process
+        """
+        path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
+        model.trainer.save_checkpoint(path)
+        return path
+
+    def on_save(self, checkpoint: dict) -> dict:
+        """
+        Move XLA tensors to CPU before saving
+        Recommended on XLA Guide:
+        https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
+        """
+        return move_data_to_device(checkpoint, torch.device("cpu"))
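
The new `on_save` hook moves every tensor in the checkpoint to CPU before it is written, following the XLA guide linked in the docstring. A simplified stand-in for `move_data_to_device`, handling only dicts and tensors (an assumption made for illustration):

import torch


def checkpoint_to_cpu(obj):
    # Recursively copy tensors to CPU; other values pass through unchanged.
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {key: checkpoint_to_cpu(value) for key, value in obj.items()}
    return obj


ckpt = {"state_dict": {"weight": torch.ones(2, 2)}, "epoch": 3}
print(checkpoint_to_cpu(ckpt)["state_dict"]["weight"].device)  # cpu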
