@@ -87,6 +87,8 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         trainer.accelerator.setup_optimizers(trainer)
         trainer.precision_plugin.connect(self._model, None, None)
 
+        # replace trainer save_checkpoint to use `xm.save`
+        trainer.save_checkpoint = self.save_checkpoint
         self.barrier("pre-run-stage")
 
         results = trainer.train_or_test_or_predict()
@@ -201,12 +203,14 @@ def test_step(self, *args, **kwargs):
     def predict(self, *args, **kwargs):
         return self.lightning_module.predict(*args, **kwargs)
 
-    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
+    def save_checkpoint(self, filepath: str, weights_only: bool = False) -> None:
         """Save model/training states as a checkpoint file through state-dump and file-write.
         Args:
-            checkpoint: dict containing model and trainer state
            filepath: write-target file's path
+            weights_only: saving model weights only
        """
+        # dump states as a checkpoint dictionary object
+        checkpoint = self.lightning_module.trainer.checkpoint_connector.dump_checkpoint(weights_only)
         # Todo: TypeError: 'mappingproxy' object does not support item assignment
         if _OMEGACONF_AVAILABLE:
             checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container)