Commit ecae7d2

replaced all instances of XLA_AVAILABLE
1 parent 1d39d92 commit ecae7d2

File tree

12 files changed: +83 -91 lines changed

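Every touched module applies the same change: the per-file try/except ImportError block that defined XLA_AVAILABLE is removed, and the torch_xla imports are instead guarded by a single shared TPU_AVAILABLE flag. Schematically (a condensed sketch of the pattern, not any one file verbatim):

    # Before: each module probed for torch_xla itself.
    try:
        import torch_xla
    except ImportError:
        XLA_AVAILABLE = False
    else:
        XLA_AVAILABLE = True

    # After: one shared flag, defined once in pytorch_lightning.utilities.xla_device_utils.
    from pytorch_lightning.utilities.xla_device_utils import TPU_AVAILABLE

    if TPU_AVAILABLE:
        import torch_xla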

pytorch_lightning/accelerators/tpu_backend.py

Lines changed: 4 additions & 8 deletions
@@ -22,16 +22,12 @@
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.accelerators.base_backend import Accelerator
+from pytorch_lightning.utilities.xla_device_utils import TPU_AVAILABLE

-try:
-    import torch_xla
+if TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
     import torch_xla.distributed.parallel_loader as xla_pl
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True


 class TPUBackend(Accelerator):
@@ -44,7 +40,7 @@ def __init__(self, trainer):
     def setup(self, model):
         rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

-        if not XLA_AVAILABLE:
+        if not TPU_AVAILABLE:
             raise MisconfigurationException('PyTorch XLA not installed.')

         # see: https://discuss.pytorch.org/t/segfault-with-multiprocessing-queue/81292/2
@@ -164,7 +160,7 @@ def to_device(self, batch):
         See Also:
             - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
         """
-        if not XLA_AVAILABLE:
+        if not TPU_AVAILABLE:
             raise MisconfigurationException(
                 'Requested to transfer batch to TPU but XLA is not available.'
                 ' Are you sure this machine has TPUs?'
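The xla_device_utils module itself is not part of this diff. A minimal, hypothetical stand-in that is consistent with how the guarded imports above use the flag (the real utility may additionally check that a TPU device is actually reachable, not just that torch_xla is importable):

    import importlib.util

    # Hypothetical simplification of pytorch_lightning/utilities/xla_device_utils.py:
    # expose a module-level boolean that is True only when torch_xla can be imported.
    TPU_AVAILABLE = importlib.util.find_spec("torch_xla") is not None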

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 6 additions & 7 deletions
@@ -26,17 +26,16 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.xla_device_utils import TPU_AVAILABLE
 import os

+
 torch_inf = torch.tensor(np.Inf)

-try:
-    import torch_xla
+if TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True
+    import torch_xla
+


 class EarlyStopping(Callback):
@@ -201,7 +200,7 @@ def _run_early_stopping_check(self, trainer, pl_module):
         if not isinstance(current, torch.Tensor):
             current = torch.tensor(current, device=pl_module.device)

-        if trainer.use_tpu and XLA_AVAILABLE:
+        if trainer.use_tpu and TPU_AVAILABLE:
             current = current.cpu()

         if self.monitor_op(current - self.min_delta, self.best_score):
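Aside from the flag rename, the point of this hunk is that the monitored value is moved back to the CPU before it is compared against the best score when training on TPU. A condensed sketch of that check (the function name is illustrative; monitor_op and min_delta mirror the callback's configuration):

    import torch

    def improvement(current: torch.Tensor, best_score: torch.Tensor,
                    min_delta: float, monitor_op, on_tpu: bool) -> bool:
        if on_tpu:
            # XLA tensors are pulled to the CPU before the elementwise comparison.
            current = current.cpu()
        return bool(monitor_op(current - min_delta, best_score))

    # e.g. with monitor_op=torch.lt (lower is better):
    # improvement(val_loss, best_score, min_delta=0.0, monitor_op=torch.lt, on_tpu=False)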

pytorch_lightning/core/lightning.py

Lines changed: 2 additions & 5 deletions
@@ -39,13 +39,10 @@
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
 from pytorch_lightning.core.step_result import TrainResult, EvalResult
+from pytorch_lightning.utilities.xla_device_utils import TPU_AVAILABLE

-try:
+if TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True


 class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, ModelHooks, Module):

pytorch_lightning/trainer/data_loading.py

Lines changed: 3 additions & 8 deletions
@@ -27,21 +27,16 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.model_utils import is_overridden
-
+from pytorch_lightning.utilities.xla_device_utils import TPU_AVAILABLE

 try:
     from apex import amp
 except ImportError:
     amp = None

-try:
+if TPU_AVAILABLE:
     import torch_xla
     import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.xla_multiprocessing as xmp
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True

 try:
     import horovod.torch as hvd
@@ -336,7 +331,7 @@ def request_dataloader(self, dataloader_fx: Callable) -> DataLoader:
             torch_distrib.barrier()

         # data download/load on TPU
-        elif self.use_tpu and XLA_AVAILABLE:
+        elif self.use_tpu and TPU_AVAILABLE:
             # all processes wait until data download has happened
             torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders')

pytorch_lightning/trainer/distrib_data_parallel.py

Lines changed: 2 additions & 8 deletions
@@ -141,6 +141,7 @@ def train_fx(trial_hparams, cluster_manager, _):
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.xla_device_utils import TPU_AVAILABLE

 try:
     from apex import amp
@@ -155,13 +156,6 @@ def train_fx(trial_hparams, cluster_manager, _):
     HOROVOD_AVAILABLE = True


-try:
-    import torch_xla
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True
-

 class TrainerDDPMixin(ABC):

@@ -303,7 +297,7 @@ def set_distributed_mode(self, distributed_backend):

         rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}')
         num_cores = self.tpu_cores if self.tpu_cores is not None else 0
-        rank_zero_info(f'TPU available: {XLA_AVAILABLE}, using: {num_cores} TPU cores')
+        rank_zero_info(f'TPU available: {TPU_AVAILABLE}, using: {num_cores} TPU cores')

         if torch.cuda.is_available() and not self.on_gpu:
             rank_zero_warn('GPU available but not used. Set the --gpus flag when calling the script.')

pytorch_lightning/trainer/distrib_parts.py

Lines changed: 0 additions & 6 deletions
@@ -42,12 +42,6 @@
 except ImportError:
     amp = None

-try:
-    import torch_xla.core.xla_model as xm
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True

 try:
     import horovod.torch as hvd

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 0 additions & 8 deletions
@@ -135,14 +135,6 @@
 from pytorch_lightning.core.step_result import EvalResult, Result
 from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop

-try:
-    import torch_xla.distributed.parallel_loader as xla_pl
-    import torch_xla.core.xla_model as xm
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True
-
 try:
     import horovod.torch as hvd
 except (ModuleNotFoundError, ImportError):

pytorch_lightning/trainer/trainer.py

Lines changed: 5 additions & 9 deletions
@@ -59,6 +59,8 @@
 from pytorch_lightning.utilities.model_utils import is_overridden

 # warnings to ignore in trainer
+from pytorch_lightning.utilities.xla_device_utils import TPU_AVAILABLE
+
 warnings.filterwarnings(
     'ignore', message='torch.distributed.reduce_op is deprecated, ' 'please use torch.distributed.ReduceOp instead'
 )
@@ -68,14 +70,8 @@
 except ImportError:
     amp = None

-try:
+if TPU_AVAILABLE:
     import torch_xla
-    import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.xla_multiprocessing as xmp
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True

 try:
     import horovod.torch as hvd
@@ -1131,7 +1127,7 @@ def run_pretrain_routine(self, model: LightningModule):
             torch_distrib.barrier()

         # wait for all models to restore weights
-        if self.on_tpu and XLA_AVAILABLE:
+        if self.on_tpu and TPU_AVAILABLE:
             # wait for all processes to catch up
             torch_xla.core.xla_model.rendezvous("pl.Trainer.run_pretrain_routine")

@@ -1389,7 +1385,7 @@ def barrier(self, name):
             pass
             # torch_distrib.barrier()

-        if self.on_tpu and XLA_AVAILABLE:
+        if self.on_tpu and TPU_AVAILABLE:
             # wait for all processes to catch up
             torch_xla.core.xla_model.rendezvous(f'pl.Trainer.{name}')
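Both trainer hunks use torch_xla's rendezvous as a cross-process barrier. A minimal sketch of that pattern, assuming torch_xla is installed and the code runs inside a TPU process group:

    import torch_xla.core.xla_model as xm

    def barrier(name: str) -> None:
        # Every TPU process blocks here until all ordinals have reached the same tag.
        xm.rendezvous(f'pl.Trainer.{name}')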

pytorch_lightning/trainer/training_io.py

Lines changed: 3 additions & 8 deletions
@@ -118,15 +118,10 @@
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.cloud_io import makedirs
 from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
+from pytorch_lightning.utilities.xla_device_utils import TPU_AVAILABLE

-try:
+if TPU_AVAILABLE:
     import torch_xla
-    import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.xla_multiprocessing as xmp
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True

 try:
     from apex import amp
@@ -209,7 +204,7 @@ def restore_weights(self, model: LightningModule):
             torch_distrib.barrier()

         # wait for all models to restore weights
-        if self.on_tpu and XLA_AVAILABLE:
+        if self.on_tpu and TPU_AVAILABLE:
             # wait for all processes to catch up
             torch_xla.core.xla_model.rendezvous("pl.TrainerIOMixin.restore_weights")

pytorch_lightning/trainer/training_loop.py

Lines changed: 2 additions & 8 deletions
@@ -182,19 +182,13 @@ def training_step(self, batch, batch_idx):
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.parsing import AttributeDict
 from pytorch_lightning.utilities.model_utils import is_overridden
+from pytorch_lightning.utilities.xla_device_utils import TPU_AVAILABLE

 try:
     from apex import amp
 except ImportError:
     amp = None

-try:
-    import torch_xla.distributed.parallel_loader as xla_pl
-    import torch_xla.core.xla_model as xm
-except ImportError:
-    XLA_AVAILABLE = False
-else:
-    XLA_AVAILABLE = True

 try:
     import horovod.torch as hvd
@@ -957,7 +951,7 @@ def call_optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
             ).loss

         # apply TPU optimizer
-        if self.use_tpu and XLA_AVAILABLE:
+        if self.use_tpu and TPU_AVAILABLE:
             model.optimizer_step(self.current_epoch, batch_idx,
                                  optimizer, opt_idx, lambda_closure, on_tpu=True)
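The on_tpu=True branch defers to the model's optimizer_step hook, which is not shown in this diff; on TPU that hook is expected to route the step through torch_xla's optimizer wrapper. A hedged sketch of that path (the helper name is illustrative):

    import torch_xla.core.xla_model as xm

    def tpu_optimizer_step(optimizer):
        # xm.optimizer_step reduces gradients across TPU cores and calls optimizer.step();
        # barrier=True also marks the XLA step so the pending graph is executed.
        xm.optimizer_step(optimizer, barrier=True)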
