
Commit bf59ab2

Merge branch 'master' into bugfix/should_stop
2 parents 4fc80df + f581411 · commit bf59ab2

9 files changed: +79 -131 lines changed


CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -200,6 +200,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))


+- Fixed TPU Colab hang issue, post training ([#6816](https://github.com/PyTorchLightning/pytorch-lightning/pull/6816))
+
+
 - Enforce an epoch scheduler interval when using SWA ([#6588](https://github.com/PyTorchLightning/pytorch-lightning/pull/6588))


@@ -209,6 +212,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))


+- Fixed a bug where `TensorBoardLogger` would give a warning and not log correctly to a symbolic link `save_dir` ([#6730](https://github.com/PyTorchLightning/pytorch-lightning/pull/6730))
+

 ## [1.2.6] - 2021-03-30

pytorch_lightning/plugins/training_type/single_tpu.py

Lines changed: 9 additions & 28 deletions
@@ -11,15 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
-from typing import Optional, Union
-
 import torch

-from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
-from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import _TPU_AVAILABLE
 from pytorch_lightning.utilities.apply_func import move_data_to_device

 if _TPU_AVAILABLE:
@@ -28,17 +23,22 @@

 class SingleTPUPlugin(SingleDevicePlugin):

-    def __init__(self, device: Union[torch.device, int]):
-        if isinstance(device, int):
-            device = xm.xla_device(device)
+    def __init__(self, device: int):
+
+        device = xm.xla_device(device)
         super().__init__(device)

         self.tpu_local_core_rank = 0
         self.tpu_global_core_rank = 0

+    @property
     def on_tpu(self) -> bool:
         return True

+    @property
+    def is_distributed(self) -> bool:
+        return False
+
     def model_to_device(self) -> None:
         self.model.to(self.root_device)

@@ -49,29 +49,10 @@ def pre_dispatch(self) -> None:
         self.tpu_local_core_rank = xm.get_local_ordinal()
         self.tpu_global_core_rank = xm.get_ordinal()

-    def post_dispatch(self) -> None:
-        model = self.lightning_module
-
-        if on_colab_kaggle():
-            rank_zero_warn("cleaning up... please do not interrupt")
-            self.save_spawn_weights(model)
-
-    def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
-        """
-        Dump a temporary checkpoint after ddp ends to get weights out of the process
-        """
-        path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
-        model.trainer.save_checkpoint(path)
-        return path
-
     def on_save(self, checkpoint: dict) -> dict:
         """
         Move XLA tensors to CPU before saving
         Recommended on XLA Guide:
         https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
         """
         return move_data_to_device(checkpoint, torch.device("cpu"))
-
-    @property
-    def is_distributed(self):
-        return False
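
Note: the constructor above now takes a bare integer core index and resolves the XLA device itself, and both on_tpu and is_distributed are exposed as read-only properties. A minimal sketch of the resulting usage, assuming a TPU runtime with torch_xla installed:

from pytorch_lightning.plugins import SingleTPUPlugin

# A ready-made torch.device is no longer accepted; the int index is passed to xm.xla_device() inside __init__.
plugin = SingleTPUPlugin(device=1)

assert plugin.on_tpu              # always True for this plugin
assert not plugin.is_distributed  # single-core training is never distributed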

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 7 additions & 88 deletions
@@ -20,9 +20,7 @@
 import torch
 import torch.multiprocessing as mp

-from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
-from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -32,28 +30,20 @@

 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.parallel_loader as xla_pl
     import torch_xla.distributed.xla_multiprocessing as xmp
     from torch_xla.core.xla_model import rendezvous
-    from torch_xla.distributed.parallel_loader import ParallelLoader
+    from torch_xla.distributed.parallel_loader import MpDeviceLoader
 else:
-    xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5
+    xm, xmp, MpDeviceLoader, rendezvous = [None] * 4

 if _OMEGACONF_AVAILABLE:
     from omegaconf import DictConfig, ListConfig, OmegaConf


 class TPUSpawnPlugin(DDPSpawnPlugin):

-    def __init__(
-        self,
-        parallel_devices: Optional[List[torch.device]] = None,
-        num_nodes: int = 1,
-        **kwargs: Dict[str, Any]
-    ) -> None:
-        super().__init__(
-            parallel_devices, num_nodes=num_nodes, cluster_environment=None, sync_batchnorm=False, **kwargs
-        )
+    def __init__(self, parallel_devices: Optional[List[int]] = None, **kwargs: Dict[str, Any]) -> None:
+        super().__init__(parallel_devices, num_nodes=1, cluster_environment=None, sync_batchnorm=False)
         self.tpu_local_core_rank = 0
         self.start_method = None
@@ -74,10 +64,9 @@ def distributed_sampler_kwargs(self) -> dict:
     def is_distributed(self):
         return self.world_size != 1

-    def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> ParallelLoader:
+    def process_dataloader(self, dataloader: Union[Iterable, torch.utils.data.DataLoader]) -> MpDeviceLoader:
         device = xm.xla_device()
-        dataloader = xla_pl.ParallelLoader(dataloader, [device])
-        dataloader = dataloader.per_device_loader(device)
+        dataloader = MpDeviceLoader(dataloader, device)
         return dataloader

     def configure_ddp(self) -> None:
@@ -115,7 +104,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:

         results = trainer.run_stage()

-        self.__save_end_of_training_weights(self.lightning_module)
         self.transfer_distrib_spawn_state_on_fit_end(results)

         # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
@@ -125,12 +113,6 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         if self.global_rank == 0:
             time.sleep(2)

-    def __save_end_of_training_weights(self, model: LightningModule) -> None:
-        # when training ends on these platforms dump weights to get out of the main process
-        if on_colab_kaggle():
-            rank_zero_warn("cleaning up... please do not interrupt")
-            self.save_spawn_weights(model)
-
     def model_to_device(self) -> None:
         self._model.to(xm.xla_device())

@@ -172,37 +154,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:
         obj = torch.load(buffer)
         return obj

-    def load_spawn_weights(self, original_model: LightningModule) -> LightningModule:
-        """
-        Load the temp weights saved in the process
-        To recover the trained model from the ddp process we load the saved weights
-        """
-
-        loaded_model = original_model
-
-        if self.is_global_zero:
-            # load weights saved in ddp
-            path = os.path.join(original_model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
-            loaded_model = original_model.__class__.load_from_checkpoint(path)
-
-            # copy loaded weights to old model
-            original_model.load_state_dict(loaded_model.state_dict())
-
-            # remove ddp weights
-            os.remove(path)
-
-        return loaded_model
-
-    def save_spawn_weights(self, model: LightningModule) -> Optional[str]:
-        """
-        Dump a temporary checkpoint after ddp ends to get weights out of the process
-        """
-        if model.trainer.is_global_zero:
-            path = os.path.join(model.trainer.default_root_dir, "__temp_weight_distributed_end.ckpt")
-            model.trainer.save_checkpoint(path)
-        return path
-
-    def reduce_decision(self, decision: bool) -> bool:
+    def reduce_boolean_decision(self, decision: bool) -> bool:
         decision = torch.tensor(int(decision), device=self.device)
         decision = self.reduce(decision, "sum")
         decision = bool(decision == self.world_size)
@@ -226,39 +178,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[

         return output

-    def post_dispatch(self) -> None:
-        # TODO: Check if trainer references can be resolved otherwise
-        model = self.lightning_module
-
-        # restore main state with best weights
-        best_path = self.mp_queue.get()
-        last_path = self.mp_queue.get()
-        self._results = self.mp_queue.get()
-
-        # transfer back the best path to the trainer
-        if self.lightning_module.trainer.checkpoint_callback is not None:
-            self.lightning_module.trainer.checkpoint_callback.best_model_path = best_path
-        # todo, pass also bets score
-
-        # load last weights
-        if last_path and model.trainer.state == TrainerState.FITTING:
-            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
-            model.load_state_dict(ckpt)
-
-        self._model = model
-
-        # when training completes, load the weights back in main process
-        self.__load_weights_on_main_process()
-
-    def __load_weights_on_main_process(self) -> None:
-        model = self.lightning_module
-
-        # load weights if not interrupted
-        if on_colab_kaggle() and model.trainer.state == TrainerState.FITTING:
-            self.load_spawn_weights(model)
-
-        self._model = model
-
     def _close_logger(self, trainer) -> None:
         if trainer.logger is not None:
             trainer.logger.finalize("success")
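
Note: process_dataloader above swaps the two-step ParallelLoader(...).per_device_loader(device) pattern for MpDeviceLoader, which wraps a plain DataLoader and moves each batch to the XLA device as it is iterated. A small sketch of the difference, assuming torch_xla is installed and an XLA device is available:

import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
from torch_xla.distributed.parallel_loader import MpDeviceLoader

loader = DataLoader(TensorDataset(torch.randn(8, 3), torch.randn(8, 1)), batch_size=4)
device = xm.xla_device()

# Before: xla_pl.ParallelLoader(loader, [device]).per_device_loader(device)
# After: a single wrapper that yields batches already placed on the TPU device
tpu_loader = MpDeviceLoader(loader, device)
for x, y in tpu_loader:
    pass  # x and y already live on the XLA device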

pytorch_lightning/trainer/connectors/accelerator_connector.py

Lines changed: 11 additions & 10 deletions
@@ -257,7 +257,7 @@ def use_dp(self) -> bool:
     def use_ddp(self) -> bool:
         return self._distrib_type in (
             DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED,
-            DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED
+            DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED, DistributedType.TPU_SPAWN
         )

     @property
@@ -297,7 +297,8 @@ def parallel_devices(self) -> List[Union[torch.device, int]]:
         elif self.on_tpu:
             # explicitly don't make a tpu device here!
             # https://github.com/PyTorchLightning/pytorch-lightning/issues/3169
-            devices = [i for i in self.parallel_device_ids]
+            if isinstance(self.tpu_cores, int):
+                devices = list(range(self.tpu_cores))
         else:
             devices = [torch.device("cpu")] * self.num_processes
         return devices
@@ -376,6 +377,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
             use_torchelastic_ddp = self.use_ddp and self.is_using_torchelastic
             use_ddp_spawn = self._distrib_type == DistributedType.DDP_SPAWN
             use_ddp_cpu_spawn = self.use_ddp and self.on_cpu
+            use_tpu_spawn = self.on_tpu and self._distrib_type == DistributedType.TPU_SPAWN
             use_ddp_cpu_torch_elastic = use_ddp_cpu_spawn and self.is_using_torchelastic
             use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks
             use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED
@@ -386,7 +388,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
             if os.environ.get("PL_IN_DDP_SUBPROCESS", False):
                 use_torchelastic_ddp = False

-            if self.on_tpu:
+            if use_tpu_spawn:
                 ddp_plugin_cls = TPUSpawnPlugin
             elif use_ddp_sharded:
                 ddp_plugin_cls = DDPShardedPlugin
@@ -409,11 +411,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin:
             plugin = DataParallelPlugin(parallel_devices=self.parallel_devices)
         elif self.use_horovod:
             plugin = HorovodPlugin(parallel_devices=self.parallel_devices)
-        elif self.on_tpu:
-            if isinstance(self.tpu_cores, list):
-                plugin = SingleTPUPlugin(self.tpu_id)
-            else:
-                plugin = TPUSpawnPlugin(parallel_devices=list(range(self.tpu_cores)))
+        elif self.on_tpu and isinstance(self.tpu_cores, list):
+            plugin = SingleTPUPlugin(self.tpu_id)
         else:
             single_gpu_ordinal = device_parser.determine_root_gpu_device(self.parallel_device_ids)
             plugin = SingleDevicePlugin(device=torch.device(f"cuda:{single_gpu_ordinal}" if self.on_gpu else "cpu"))
@@ -507,6 +506,8 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
         # special case with TPUs
         elif self.distributed_backend == 'tpu' or self.tpu_cores is not None:
             self._device_type = DeviceType.TPU
+            if isinstance(self.tpu_cores, int):
+                self._distrib_type = DistributedType.TPU_SPAWN
         elif self.distributed_backend and self._distrib_type is None:
             self._distrib_type = DistributedType(self.distributed_backend)
@@ -515,9 +516,9 @@ def set_distributed_mode(self, distributed_backend: Optional[str] = None):
         if self.num_gpus > 0 and not _on_cpu:
             self._device_type = DeviceType.GPU

-        _distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
+        _gpu_distrib_types = (DistributedType.DP, DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP2)
         # DP and DDP2 cannot run without GPU
-        if self.num_gpus == 0 and self._distrib_type in _distrib_types and not _on_cpu:
+        if self.num_gpus == 0 and self._distrib_type in _gpu_distrib_types and not _on_cpu:
             rank_zero_warn(
                 'You requested distributed training on GPUs, but none is available, so we set backend to `ddp_cpu`.'
             )
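
Note: taken together, these changes route an integer tpu_cores value through the new DistributedType.TPU_SPAWN path, while a one-element list still pins training to a single core. A rough sketch of the resulting plugin selection, assuming a TPU runtime and that the chosen plugin is reachable via trainer.training_type_plugin:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import SingleTPUPlugin, TPUSpawnPlugin

# An integer core count sets _distrib_type = DistributedType.TPU_SPAWN, which selects TPUSpawnPlugin.
spawn_trainer = Trainer(tpu_cores=8)
assert isinstance(spawn_trainer.training_type_plugin, TPUSpawnPlugin)

# A one-element list still selects SingleTPUPlugin pinned to that core index.
single_trainer = Trainer(tpu_cores=[1])
assert isinstance(single_trainer.training_type_plugin, SingleTPUPlugin)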

pytorch_lightning/tuner/tuning.py

Lines changed: 7 additions & 1 deletion
@@ -61,7 +61,13 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule):

         # Run learning rate finder:
         if self.trainer.auto_lr_find:
-            self.lr_find(model, update_attr=True)
+            self.lr_find(
+                model,
+                update_attr=True,
+                train_dataloader=train_dataloader,
+                val_dataloaders=val_dataloaders,
+                datamodule=datamodule,
+            )

         self.trainer.state = TrainerState.FINISHED
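
Note: this closes a gap where Trainer.tune() dropped the dataloaders and datamodule before calling the learning rate finder. A minimal sketch of a call that now works end to end (TinyModel is an illustration only, not part of the patch):

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr  # updated in place by lr_find because update_attr=True
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)


train_loader = DataLoader(TensorDataset(torch.randn(32, 4), torch.randn(32, 1)), batch_size=8)

trainer = pl.Trainer(auto_lr_find=True, max_epochs=1)
# The dataloader passed to tune() is now forwarded to lr_find instead of being ignored.
trainer.tune(TinyModel(), train_dataloader=train_loader)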

pytorch_lightning/utilities/cloud_io.py

Lines changed: 15 additions & 1 deletion
@@ -12,15 +12,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import io
 from distutils.version import LooseVersion
 from pathlib import Path
 from typing import IO, Union

 import fsspec
+from fsspec.implementations.local import LocalFileSystem
+
 import torch


+class _LightningLocalFileSystem(LocalFileSystem):
+    """Extension of ``fsspec.implementations.local.LocalFileSystem`` where ``LightningLocalFileSystem.isdir`` behaves
+    the same as ``os.isdir``.
+
+    To be removed when https://github.com/intake/filesystem_spec/issues/591 is fixed.
+    """
+
+    def isdir(self, path: str) -> bool:
+        return os.path.isdir(path)  # follows symlinks
+
+
 def load(path_or_url: Union[str, IO, Path], map_location=None):
     if not isinstance(path_or_url, (str, Path)):
         # any sort of BytesIO or similiar
@@ -39,7 +53,7 @@ def get_filesystem(path: Union[str, Path]):
         return fsspec.filesystem(path.split(":", 1)[0])
     else:
         # use local filesystem
-        return fsspec.filesystem("file")
+        return _LightningLocalFileSystem()


 def atomic_save(checkpoint, filepath: str):
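
Note: the isdir override exists because os.path.isdir follows symbolic links, while the stock fsspec LocalFileSystem.isdir of that era did not (see the linked filesystem_spec issue). A small sketch of the behaviour the patched filesystem guarantees:

import os
import tempfile

from pytorch_lightning.utilities.cloud_io import get_filesystem

real_dir = tempfile.mkdtemp()
link = real_dir + "_link"
os.symlink(real_dir, link)  # a symlink that points at a real directory

fs = get_filesystem(link)   # local paths now resolve to _LightningLocalFileSystem
assert os.path.isdir(link)  # os follows the link...
assert fs.isdir(link)       # ...and so does the patched filesystem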

pytorch_lightning/utilities/enums.py

Lines changed: 4 additions & 1 deletion
@@ -62,7 +62,9 @@ class DistributedType(LightningEnum):
     @staticmethod
     def interactive_compatible_types() -> List['DistributedType']:
         """Returns a list containing interactive compatible DistributeTypes"""
-        return [DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN]
+        return [
+            DistributedType.DP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED_SPAWN, DistributedType.TPU_SPAWN
+        ]

     def is_interactive_compatible(self) -> bool:
         """Returns whether self is interactive compatible"""
@@ -72,6 +74,7 @@ def is_interactive_compatible(self) -> bool:
     DDP = 'ddp'
     DDP2 = 'ddp2'
     DDP_SPAWN = 'ddp_spawn'
+    TPU_SPAWN = 'tpu_spawn'
     DEEPSPEED = 'deepspeed'
     HOROVOD = 'horovod'
     DDP_SHARDED = 'ddp_sharded'
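
Note: a quick check of the new enum member and its interactive compatibility, runnable with only pytorch_lightning installed:

from pytorch_lightning.utilities.enums import DistributedType

assert DistributedType("tpu_spawn") is DistributedType.TPU_SPAWN
assert DistributedType.TPU_SPAWN in DistributedType.interactive_compatible_types()
assert DistributedType.TPU_SPAWN.is_interactive_compatible()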
