Commit cd25a6a

add tpu spawn
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent ede6716 commit cd25a6a

1 file changed: 186 additions, 0 deletions

@@ -0,0 +1,186 @@
import io
import os
from typing import Any, Dict, Iterable, Optional, Sequence, Union

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.seed import seed_everything

if _TPU_AVAILABLE:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as xla_pl
    import torch_xla.distributed.xla_multiprocessing as xmp


class TPUSpawnPlugin(DDPSpawnPlugin):

    def __init__(self, parallel_devices: Sequence, num_nodes: int = 1, **kwargs: Dict[str, Any]) -> None:
        # map integer core indices to XLA devices; pass already-resolved devices through unchanged
        parallel_devices = [xm.xla_device(device) if isinstance(device, int) else device for device in parallel_devices]
        super().__init__(parallel_devices, num_nodes=num_nodes, cluster_environment=None, sync_batchnorm=False, **kwargs)
        self.tpu_local_core_rank = 0
        self.start_method = None

    @property
    def distributed_sampler_kwargs(self) -> dict:
        return dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())

    def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> xla_pl.ParallelLoader:
        device = xm.xla_device(self.trainer.tpu_id)
        dataloader = xla_pl.ParallelLoader(dataloader, [device])
        dataloader = dataloader.per_device_loader(device)
        return dataloader

    def configure_ddp(self) -> None:
        pass

    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
        pass

    def set_world_ranks(self, process_idx: int) -> None:
        self.tpu_local_core_rank = xm.get_local_ordinal()
        self.tpu_global_core_rank = xm.get_ordinal()
        self.global_rank = self.tpu_local_core_rank
        self.world_size = self.num_nodes * self.num_processes

    def new_process(self, process_idx: int, trainer: Trainer) -> None:
        seed = os.environ.get("PL_GLOBAL_SEED")
        if seed is not None:
            seed_everything(int(seed))

        self.set_world_ranks(process_idx)

        # set warning rank
        rank_zero_only.rank = self.global_rank

        if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
            trainer.progress_bar_callback.disable()

        self.model_to_device()
        self.barrier()

        if trainer.testing:
            results = trainer.run_test()
        else:
            results = trainer.train()

        self.__save_end_of_training_weights(self.lightning_module)
        self.transfer_distrib_spawn_state_on_fit_end(results)

    def __save_end_of_training_weights(self, model: LightningModule) -> None:
        # when training ends on these platforms dump weights to get out of the main process
        if self.on_colab_kaggle:
            rank_zero_warn("cleaning up... please do not interrupt")
            self.save_spawn_weights(model)

    def model_to_device(self) -> None:
        pass

    def barrier(self, name: Optional[str] = None) -> None:
        torch_xla.core.xla_model.rendezvous(f"pl.Trainer.{name}")

    def on_save(self, checkpoint: dict) -> dict:
        """
        Move XLA tensors to CPU before saving
        Recommended on XLA Guide:
        https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
        """
        return move_data_to_device(checkpoint, torch.device("cpu"))

    @property
    def on_colab_kaggle(self) -> bool:
        return bool(os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE'))

    def broadcast(self, obj: object, src: int = 0) -> object:
        # serialize the object to a byte tensor, all-gather it across cores and deserialize
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
        data_tensor = torch.tensor(data).to(xm.xla_device(), dtype=torch.float)
        data = xm.all_gather(data_tensor)
        buffer = io.BytesIO(data.cpu().byte().numpy())
        obj = torch.load(buffer)
        return obj

    def load_spawn_weights(self, original_model: LightningModule) -> LightningModule:
        """
        Load the temp weights saved in the process
        To recover the trained model from the ddp process we load the saved weights
        """
        loaded_model = original_model

        if self.is_global_zero:
            # load weights saved in ddp
            path = os.path.join(original_model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
            loaded_model = original_model.__class__.load_from_checkpoint(path)

            # copy loaded weights to old model
            original_model.load_state_dict(loaded_model.state_dict())

            # remove ddp weights
            os.remove(path)

        return loaded_model

    def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
        """
        Dump a temporary checkpoint after ddp ends to get weights out of the process
        """
        if model.trainer.is_global_zero:
            path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
            model.trainer.save_checkpoint(path)
            return path

    def reduce_early_stopping_decision(self, should_stop: bool) -> bool:
        should_stop = torch.tensor(int(should_stop), device=self.lightning_module.device)
        stop = xm.mesh_reduce('stop_signal', should_stop, sum)
        torch_xla.core.xla_model.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
        should_stop = int(stop.item()) == self.world_size
        return should_stop

    def post_training(self) -> None:
        # TODO: Check if trainer references can be resolved otherwise
        model = self.lightning_module

        # restore main state with best weights
        best_path = self.mp_queue.get()
        results = self.mp_queue.get()
        last_path = self.mp_queue.get()

        # transfer back the best path to the trainer
        if self.lightning_module.trainer.checkpoint_callback is not None:
            self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
        # TODO: also pass back the best score

        # load last weights
        if last_path and not self.lightning_module.trainer.testing:
            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
            model.load_state_dict(ckpt)

        self.lightning_module = model

        # when training completes, load the weights back in main process
        self.__load_weights_on_main_process()

    def __load_weights_on_main_process(self) -> None:
        model = self.lightning_module

        # load weights if not interrupted
        # TODO: check for trainer reference
        if self.on_colab_kaggle and not model.trainer.testing:
            self.load_spawn_weights(model)

        self.lightning_module = model

    def start_training(self, trainer: Trainer) -> None:
        # xmp.spawn prepends the process index to the arguments passed to new_process
        xmp.spawn(self.new_process, args=(trainer, ), nprocs=len(self.parallel_devices), start_method=self.start_method)

    def start_testing(self, trainer: Trainer) -> None:
        xmp.spawn(self.new_process, args=(trainer, ), nprocs=len(self.parallel_devices), start_method=self.start_method)
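
For orientation, a minimal usage sketch of the new plugin. The module path, the eight-core device list, and passing the plugin to a Trainer via its `plugins` argument are assumptions for illustration; none of that wiring is shown in this commit.

# Hypothetical sketch, assuming the new file is importable as
# pytorch_lightning.plugins.training_type.tpu_spawn and that the Trainer
# on this branch accepts training-type plugins through `plugins`.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin

# __init__ maps integer core indices to XLA devices via xm.xla_device(...)
plugin = TPUSpawnPlugin(parallel_devices=list(range(8)))

trainer = Trainer(max_epochs=1, plugins=[plugin])
# trainer.fit(model)  # start_training() would then xmp.spawn new_process over the 8 cores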

0 commit comments
