@@ -23,10 +23,11 @@
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.apply_func import apply_to_collection

 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
@@ -37,6 +38,10 @@
 else:
     xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5


+if _OMEGACONF_AVAILABLE:
+    from omegaconf import OmegaConf
+    from omegaconf import DictConfig, ListConfig
+

 class TPUSpawnPlugin(DDPSpawnPlugin):

@@ -304,4 +309,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
             filepath: write-target file's path
         """
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
+        if _OMEGACONF_AVAILABLE:
+            checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
         self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath)
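For context, a minimal sketch of what the added `save_checkpoint` lines do, assuming `omegaconf` and `pytorch-lightning` are installed: `apply_to_collection` replaces every `DictConfig`/`ListConfig` value in the checkpoint dict with a plain Python container via `OmegaConf.to_container`, so the TPU checkpoint writer only sees serializable primitives. The `hyper_parameters` key and values below are purely illustrative.

```python
# Illustrative only: mirrors the conversion performed by the added lines above.
from omegaconf import DictConfig, ListConfig, OmegaConf

from pytorch_lightning.utilities.apply_func import apply_to_collection

# Hypothetical checkpoint dict; "hyper_parameters" holds an OmegaConf DictConfig.
checkpoint = {
    "epoch": 3,
    "hyper_parameters": OmegaConf.create({"lr": 1e-3, "layers": [64, 32]}),
}

# Replace every DictConfig/ListConfig in the collection with a plain dict/list.
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)

assert isinstance(checkpoint["hyper_parameters"], dict)  # now plain data, safe to save
```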