
Commit 78bfc1d

tchaton authored and Your Name committed

[bugfix] Add support for omegaconf and tpu (#6741)

* fix_hydra
* update changelog

Co-authored-by: Your Name <[email protected]>

1 parent b9e9743 · commit 78bfc1d

File tree

2 files changed: +10 −2 lines changed


CHANGELOG.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -38,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added Autocast in validation, test and predict modes for Native AMP ([#6565](https://github.com/PyTorchLightning/pytorch-lightning/pull/6565))
 
 
+- Fixed a bug with omegaconf and `xm.save` ([#6741](https://github.com/PyTorchLightning/pytorch-lightning/pull/6741))
+
 ## [1.2.4] - 2021-03-16
 
 ### Changed
```

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 8 additions & 2 deletions

```diff
@@ -9,7 +9,8 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
@@ -23,6 +24,9 @@
 else:
     xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5
 
+if _OMEGACONF_AVAILABLE:
+    from omegaconf import DictConfig, ListConfig, OmegaConf
+
 
 class TPUSpawnPlugin(DDPSpawnPlugin):
 
@@ -294,4 +298,6 @@ def save_checkpoint(self, filepath, weights_only: bool = False):
         # dump states as a checkpoint dictionary object
         _checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
-        self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
+        if _OMEGACONF_AVAILABLE:
+            _checkpoint = apply_to_collection(_checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)
+        self.save({k: v for k, v in _checkpoint.items() if k != "callbacks"}, filepath)
```
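The fix works because `apply_to_collection` walks the checkpoint and converts every `DictConfig`/`ListConfig` into plain Python containers before `xm.save` serializes it. A stdlib-only sketch of that recursive traversal, using a hypothetical `FrozenDict` class as a stand-in for omegaconf's config types (both the stand-in and this simplified signature are assumptions for illustration, not the library's actual implementation):

```python
from typing import Any, Callable, Tuple, Type


class FrozenDict(dict):
    """Stand-in for omegaconf's DictConfig: a mapping type the TPU saver chokes on."""


def apply_to_collection(data: Any, dtype: Tuple[Type, ...], function: Callable) -> Any:
    """Recursively apply `function` to every element of `data` that matches `dtype`."""
    if isinstance(data, dtype):
        return function(data)
    if isinstance(data, dict):
        return {k: apply_to_collection(v, dtype, function) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(apply_to_collection(v, dtype, function) for v in data)
    return data


# A checkpoint whose hyper-parameters were stored as a config object
checkpoint = {"hyper_parameters": FrozenDict(lr=0.1), "epoch": 3}
# Convert every FrozenDict into a plain dict, leaving everything else untouched
plain = apply_to_collection(checkpoint, (FrozenDict,), dict)
```

After the conversion, `plain["hyper_parameters"]` is an ordinary `dict`, so the subsequent save sees only standard containers.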
