|
9 | 9 | from pytorch_lightning.core.lightning import LightningModule |
10 | 10 | from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin |
11 | 11 | from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle |
12 | | -from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn |
| 12 | +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn |
| 13 | +from pytorch_lightning.utilities.apply_func import apply_to_collection |
13 | 14 | from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp |
14 | 15 | from pytorch_lightning.utilities.exceptions import MisconfigurationException |
15 | 16 | from pytorch_lightning.utilities.seed import seed_everything |
|
23 | 24 | else: |
24 | 25 | xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5 |
25 | 26 |
|
| 27 | +if _OMEGACONF_AVAILABLE: |
| 28 | + from omegaconf import DictConfig, ListConfig, OmegaConf |
| 29 | + |
26 | 30 |
|
27 | 31 | class TPUSpawnPlugin(DDPSpawnPlugin): |
28 | 32 |
|
@@ -294,4 +298,6 @@ def save_checkpoint(self, filepath, weights_only: bool = False): |
294 | 298 | # dump states as a checkpoint dictionary object |
295 | 299 | _checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only) |
296 | 300 | # Todo: TypeError: 'mappingproxy' object does not support item assignment |
297 | | - self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath) |
| 301 | + if _OMEGACONF_AVAILABLE: |
| 302 | + _checkpoint = apply_to_collection(_checkpoint, (DictConfig, ListConfig), OmegaConf.to_container) 
| 303 | + self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath) 
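A minimal sketch of what the added OmegaConf conversion does, assuming omegaconf and pytorch_lightning are installed; the checkpoint contents below are made up for illustration, not taken from the diff:

from omegaconf import DictConfig, ListConfig, OmegaConf
from pytorch_lightning.utilities.apply_func import apply_to_collection

# A checkpoint dict whose hyperparameters are stored as a DictConfig
checkpoint = {
    "state_dict": {},
    "hyper_parameters": OmegaConf.create({"lr": 1e-3}),
}

# Recursively replace any DictConfig/ListConfig values with plain dicts/lists
# so the checkpoint can be serialized (e.g. by xm.save / torch.save) without
# tripping over OmegaConf container types
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)

assert isinstance(checkpoint["hyper_parameters"], dict)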