Skip to content

Commit 0930b1f

Browse files
committed
fix typing
1 parent 197ff6b commit 0930b1f

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

pytorch_lightning/strategies/launchers/spawn.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,13 @@ def _wrapping_function(
100100
def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None:
101101
# transfer back the best path to the trainer
102102
if trainer.checkpoint_callback:
103-
trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path
103+
trainer.checkpoint_callback.best_model_path = str(spawn_output.best_model_path)
104104

105105
# TODO: pass also best score
106106
# load last weights
107107
if spawn_output.weights_path is not None:
108-
ckpt = self._strategy.checkpoint_io.load_checkpoint(
109-
spawn_output.weights_path, map_location=(lambda storage, loc: storage)
110-
)
111-
trainer.lightning_module.load_state_dict(ckpt)
108+
ckpt = self._strategy.checkpoint_io.load_checkpoint(spawn_output.weights_path)
109+
trainer.lightning_module.load_state_dict(ckpt) # type: ignore[arg-type]
112110
self._strategy.checkpoint_io.remove_checkpoint(spawn_output.weights_path)
113111

114112
trainer.state = spawn_output.trainer_state
@@ -129,7 +127,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
129127
state_dict = trainer.lightning_module.state_dict()
130128

131129
if self._strategy.global_rank != 0:
132-
return
130+
return None
133131

134132
# save the last weights
135133
weights_path = None

pytorch_lightning/strategies/launchers/xla_spawn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
116116

117117
# We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training
118118
if self._strategy.local_rank != 0:
119-
return
119+
return None
120120

121121
# adds the `callback_metrics` to the queue
122122
extra = _FakeQueue()

0 commit comments

Comments
 (0)