14 | 14 | import io |
15 | 15 | import os |
16 | 16 | import re |
| 17 | +import time |
17 | 18 | from typing import Any, Dict, Iterable, List, Optional, Union |
18 | 19 |
19 | 20 | import torch |
23 | 24 | from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin |
24 | 25 | from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle |
25 | 26 | from pytorch_lightning.trainer.states import TrainerState |
26 | | -from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE |
| 27 | +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn |
| 28 | +from pytorch_lightning.utilities.apply_func import apply_to_collection |
27 | 29 | from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp |
28 | 30 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
29 | 31 | from pytorch_lightning.utilities.seed import seed_everything |
30 | | -from pytorch_lightning.utilities.apply_func import apply_to_collection |
31 | 32 |
32 | 33 | if _TPU_AVAILABLE: |
33 | 34 | import torch_xla.core.xla_model as xm |
39 | 40 | xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5 |
40 | 41 |
41 | 42 | if _OMEGACONF_AVAILABLE: |
42 | | - from omegaconf import OmegaConf |
43 | | - from omegaconf import DictConfig, ListConfig |
| 43 | + from omegaconf import DictConfig, ListConfig, OmegaConf |
44 | 44 |
45 | 45 |
46 | 46 | class TPUSpawnPlugin(DDPSpawnPlugin): |
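The import changes above are cosmetic: `import time` is added for the change further down, the remaining lines are only reordered, and the two `omegaconf` imports are merged; no new names enter the module here. As a hedged sketch of why `apply_to_collection` and the `omegaconf` types travel together in this plugin, the usual pattern is to flatten `DictConfig`/`ListConfig` nodes into plain containers before a checkpoint is serialised on TPU. The helper name and body below are illustrative assumptions, not code from this PR.

```python
# Hedged sketch: the helper name and body are assumptions for illustration,
# not part of this diff.
from typing import Any, Dict

from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection

if _OMEGACONF_AVAILABLE:
    from omegaconf import DictConfig, ListConfig, OmegaConf


def to_plain_containers(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Replace any OmegaConf nodes in a checkpoint dict with plain dicts/lists,
    so the object graph holds only primitives and tensors before saving."""
    if _OMEGACONF_AVAILABLE:
        checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
    return checkpoint
```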
@@ -118,6 +118,9 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None: |
118 | 118 | self.__save_end_of_training_weights(self.lightning_module) |
119 | 119 | self.transfer_distrib_spawn_state_on_fit_end(results) |
120 | 120 |
| 121 | + if self.global_rank == 0: |
| 122 | + time.sleep(2) |
| 123 | + |
121 | 124 | self.barrier("end-process") |
122 | 125 |
123 | 126 | def __save_end_of_training_weights(self, model: LightningModule) -> None: |
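The behavioural change is the two-second pause taken by the global rank-0 process just before the final `"end-process"` barrier in `new_process` (which is why `import time` is added at the top of the file). The hunk does not state the motivation; a plausible reading is that rank 0 should be the last process to reach the end-of-run rendezvous, so the other TPU processes are already parked at the barrier and teardown does not race the result hand-off done in `transfer_distrib_spawn_state_on_fit_end`. The snippet below is a self-contained stand-in that reproduces the pattern with `multiprocessing.Barrier`; it illustrates the idea only and is not PyTorch Lightning or torch_xla code.

```python
# Stand-alone illustration of the "rank 0 sleeps, then everyone meets at the
# barrier" pattern; multiprocessing.Barrier stands in for self.barrier(...) /
# xm.rendezvous. Nothing here is taken from the PR beyond the pattern itself.
import multiprocessing as mp
import time
from multiprocessing.synchronize import Barrier


def worker(rank: int, barrier: Barrier) -> None:
    # ... per-process training and result transfer would happen here ...
    if rank == 0:
        time.sleep(2)  # rank 0 deliberately arrives at the barrier last
    barrier.wait()  # every process blocks here until all ranks have arrived
    print(f"rank {rank} passed the end-process barrier")


if __name__ == "__main__":
    world_size = 4
    barrier = mp.Barrier(world_size)
    processes = [mp.Process(target=worker, args=(rank, barrier)) for rank in range(world_size)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
```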