Skip to content

Commit caa2275

Browse files
committed
Fix save checkpoint logic for TPUs
1 parent beda8e8 commit caa2275

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
8787
trainer.accelerator.setup_optimizers(trainer)
8888
trainer.precision_plugin.connect(self._model, None, None)
8989

90+
# replace trainer save_checkpoint to use `xm.save`
91+
trainer.save_checkpoint = self.save_checkpoint
9092
self.barrier("pre-run-stage")
9193

9294
results = trainer.train_or_test_or_predict()
@@ -201,12 +203,14 @@ def test_step(self, *args, **kwargs):
201203
def predict(self, *args, **kwargs):
202204
return self.lightning_module.predict(*args, **kwargs)
203205

204-
def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
206+
def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None:
205207
"""Save model/training states as a checkpoint file through state-dump and file-write.
206208
Args:
207-
checkpoint: dict containing model and trainer state
208209
filepath: write-target file's path
210+
weights_only: saving model weights only
209211
"""
212+
# dump states as a checkpoint dictionary object
213+
checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
210214
# Todo: TypeError: 'mappingproxy' object does not support item assignment
211215
if _OMEGACONF_AVAILABLE:
212216
checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)

tests/models/test_tpu.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,8 @@ def test_tpu_grad_norm(tmpdir):
210210
progress_bar_refresh_rate=0,
211211
max_epochs=4,
212212
tpu_cores=1,
213-
limit_train_batches=0.7,
214-
limit_val_batches=0.7,
213+
limit_train_batches=10,
214+
limit_val_batches=10,
215215
gradient_clip_val=0.5,
216216
)
217217

0 commit comments

Comments
 (0)