@@ -95,13 +95,20 @@ def set_world_ranks(self, process_idx):
9595 self .global_rank = self .node_rank * self .num_processes + self .local_rank
9696 self .world_size = self .num_nodes * self .num_processes
9797
@property
def mp_spawn_kwargs(self):
    """Keyword arguments passed to ``mp.spawn`` together with ``self.new_process``.

    Returns:
        A dict with two entries:

        * ``"args"`` -- the tuple ``(trainer, mp_queue)``, where ``trainer`` is
          read from ``self.lightning_module.trainer``; these line up with the
          ``trainer`` and ``mp_queue`` parameters of ``new_process``.
        * ``"nprocs"`` -- ``self.num_processes``.
    """
    return {
        "args": (self.lightning_module.trainer, self.mp_queue),
        "nprocs": self.num_processes,
    }
104+
def start_training(self, trainer):
    """Spawn the worker processes for training, then clear the trainer's optimizers.

    Args:
        trainer: trainer whose ``optimizers`` attribute is replaced with an
            empty list after the ``mp.spawn`` call.
    """
    # Single spawn call; process count and extra args come from ``mp_spawn_kwargs``.
    mp.spawn(self.new_process, **self.mp_spawn_kwargs)
    # reset optimizers, since main process is never used for training and thus does not have a valid optim state
    trainer.optimizers = []
102109
def start_testing(self, trainer):
    """Spawn the worker processes for testing.

    Args:
        trainer: not used directly in this method; the trainer handed to the
            spawned processes is the one provided by ``mp_spawn_kwargs``.
    """
    # Single spawn call; process count and extra args come from ``mp_spawn_kwargs``.
    mp.spawn(self.new_process, **self.mp_spawn_kwargs)
105112
106113 def new_process (self , process_idx , trainer , mp_queue ):
107114 self .mp_queue = mp_queue
@@ -173,7 +180,6 @@ def pre_configure_ddp(self):
173180 self ._ddp_kwargs ["find_unused_parameters" ] = True
174181
175182 def configure_ddp (self ):
176-
177183 self .pre_configure_ddp ()
178184 self ._model = DistributedDataParallel (
179185 LightningDistributedModule (self .model ),
@@ -197,6 +203,9 @@ def determine_ddp_device_ids(self):
197203 return None
198204 return [self .root_device .index ]
199205
def on_save(self, checkpoint: dict) -> dict:
    """Hook applied to a checkpoint dict before it is saved.

    This implementation is a pass-through: it returns the very same object it
    receives. In this file it is called on the module's ``state_dict()`` right
    before ``atomic_save`` in ``transfer_distrib_spawn_state_on_fit_end``.

    Args:
        checkpoint: the dict about to be saved.

    Returns:
        The ``checkpoint`` argument, unmodified.
    """
    return checkpoint
208+
200209 def transfer_distrib_spawn_state_on_fit_end (self , results ):
201210 # TODO: is there a better way than accessing callback through model -> trainer -> callback?
202211 checkpoint_callback = self .lightning_module .trainer .checkpoint_callback
@@ -210,7 +219,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, results):
210219 # TODO: is there a better way than accessing trainer through model -> trainer?
211220 if not self .lightning_module .trainer .testing and best_model_path is not None and len (best_model_path ) > 0 :
212221 last_path = re .sub (".ckpt" , ".tmp_end.ckpt" , best_model_path )
213- atomic_save (self .lightning_module .state_dict (), last_path )
222+ atomic_save (self .on_save ( self . lightning_module .state_dict () ), last_path )
214223
215224 # todo, pass complete checkpoint as state dictionary
216225 self .mp_queue .put (best_model_path )
0 commit comments