Skip to content

Commit ede6716

Browse files
add single tpu
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 6dd7298 commit ede6716

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import io
2+
import os
3+
from typing import Optional
4+
5+
import torch
6+
7+
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
8+
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
9+
10+
if _TPU_AVAILABLE:
11+
import torch_xla
12+
import torch_xla.core.xla_model as xm
13+
14+
15+
class SingleTPUPlugin(SingleDevicePlugin):
    """Training-type plugin for running on a single TPU core.

    The ``device`` given to the constructor may be a ``torch.device`` or a
    plain integer core index; an integer is resolved to a concrete XLA
    device when ``pre_training`` runs.
    """

    def __init__(self, device: torch.device):
        super().__init__(device)

        # Core ranks are unknown until the XLA runtime is queried in
        # ``pre_training``; initialize them to zero until then.
        self.tpu_local_core_rank = 0
        self.tpu_global_core_rank = 0

    def on_tpu(self) -> bool:
        """This plugin always targets a TPU."""
        return True

    def pre_training(self) -> None:
        """Resolve the XLA device and record the local/global core ranks."""
        # An integer device is interpreted as a TPU core index.
        if isinstance(self.device, int):
            self.device = xm.xla_device(self.device)

        self.tpu_local_core_rank = xm.get_local_ordinal()
        self.tpu_global_core_rank = xm.get_ordinal()

    def post_training(self) -> None:
        """After training on Colab/Kaggle, persist the spawned weights.

        Elsewhere this is a no-op.
        """
        model = self.lightning_module

        if not self.on_colab_kaggle:
            return

        rank_zero_warn("cleaning up... please do not interrupt")
        self.save_spawn_weights(model)

    @property
    def on_colab_kaggle(self) -> bool:
        """Whether the process appears to run inside Google Colab or Kaggle."""
        return any(os.getenv(var) for var in ("COLAB_GPU", "KAGGLE_URL_BASE"))

0 commit comments

Comments
 (0)