File tree Expand file tree Collapse file tree 1 file changed +41
-0
lines changed
pytorch_lightning/plugins/training_type Expand file tree Collapse file tree 1 file changed +41
-0
lines changed Original file line number Diff line number Diff line change 1+ import io
2+ import os
3+ from typing import Optional
4+
5+ import torch
6+
7+ from pytorch_lightning .plugins .training_type .single_device import SingleDevicePlugin
8+ from pytorch_lightning .utilities import _TPU_AVAILABLE , rank_zero_warn
9+
10+ if _TPU_AVAILABLE :
11+ import torch_xla
12+ import torch_xla .core .xla_model as xm
13+
14+
15+ class SingleTPUPlugin (SingleDevicePlugin ):
16+ def __init__ (self , device : torch .device ):
17+ super ().__init__ (device )
18+
19+ self .tpu_local_core_rank = 0
20+ self .tpu_global_core_rank = 0
21+
22+ def on_tpu (self ) -> bool :
23+ return True
24+
25+ def pre_training (self ) -> None :
26+ if isinstance (self .device , int ):
27+ self .device = xm .xla_device (self .device )
28+
29+ self .tpu_local_core_rank = xm .get_local_ordinal ()
30+ self .tpu_global_core_rank = xm .get_ordinal ()
31+
32+ def post_training (self ) -> None :
33+ model = self .lightning_module
34+
35+ if self .on_colab_kaggle :
36+ rank_zero_warn ("cleaning up... please do not interrupt" )
37+ self .save_spawn_weights (model )
38+
39+ @property
40+ def on_colab_kaggle (self ) -> bool :
41+ return bool (os .getenv ("COLAB_GPU" ) or os .getenv ("KAGGLE_URL_BASE" ))
You can’t perform that action at this time.
0 commit comments