1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import pickle
1415from copy import deepcopy
1516
1617import torch
@@ -39,9 +40,6 @@ def update(self, x):
3940 def compute (self ):  # the computed value is simply the stored attribute x
4041 return self .x
4142
42- def extra_repr (self ) -> str :
43- return str (self .name ) if self .name else ''
44-
4543
4644def _setup_ddp (rank , worldsize ):
4745 import os
@@ -186,7 +184,11 @@ def lightning_log(fx, *args, **kwargs):
186184 assert result [k ].cumulated_batch_size == torch .tensor (1. ), k
187185
188186
189- def test_result_collection_restoration ():
187+ def my_sync_dist (x ):  # identity function; the restoration test passes it to result.log as sync_dist_fn
188+ return x
189+
190+
191+ def test_result_collection_restoration (tmpdir ):
190192 """
191193 This test makes sure metrics are properly reloaded on failure.
192194 """
@@ -203,7 +205,7 @@ def lightning_log(fx, *args, **kwargs):
203205 nonlocal current_fx_name
204206 if current_fx_name != fx and batch_idx in (None , 0 ):
205207 result .reset (metrics = False , fx = fx )
206- result .log (fx , * args , ** kwargs )
208+ result .log (fx , * args , ** kwargs , sync_dist_fn = my_sync_dist )
207209 current_fx_name = fx
208210
209211 for _ in range (2 ):
@@ -230,38 +232,46 @@ def lightning_log(fx, *args, **kwargs):
230232 batch_log = result .metrics (on_step = True )[MetricSource .LOG ]
231233 assert set (batch_log ) == {"a_step" , "c" , "a_1_step" , "c_1" }
232234 assert set (batch_log ['c_1' ]) == {'1' , '2' }
235+
233236 result_copy = deepcopy (result )
237+ new_result = ResultCollection (True , torch .device ("cpu" ))
234238 state_dict = result .state_dict ()
235-
236- result = ResultCollection (True , torch .device ("cpu" ))
237- result .load_state_dict (state_dict , sync_fn = result_copy ['training_step.a' ].meta .sync .fn )
238-
239- assert result_copy .items () == result .items ()
240- assert result_copy ["training_step.c_1" ].meta == result ["training_step.c_1" ].meta
241-
242- batch_idx = None
239+ # check that the sync fn stored in the state dict is the expected one
240+ assert state_dict ['items' ]['training_step.a' ]['meta' ].sync .fn == my_sync_dist
241+ new_result .load_state_dict (state_dict )
242+ assert result_copy == new_result
243+ # should match
244+ assert result_copy ['training_step.a' ].meta .sync .fn == new_result ['training_step.a' ].meta .sync .fn
243245
244246 epoch_log = result .metrics (on_step = False )[MetricSource .LOG ]
245247 epoch_log_copy = result_copy .metrics (on_step = False )[MetricSource .LOG ]
246248 assert epoch_log == epoch_log_copy
247249
248- assert set (epoch_log ) == {'a_1_epoch' , 'a_epoch' , 'b' , 'b_1' }
249- for k in epoch_log :
250- if k in {'a_epoch' , 'b' }:
251- assert epoch_log [k ] == cumulative_sum
252- else :
253- assert epoch_log [k ] == 1
254-
255250 lightning_log ('train_epoch_end' , 'a' , metric_a , on_step = False , on_epoch = True )
256-
257- result .reset ()
258- result_copy .reset ()
251+ epoch_log = result .metrics (on_step = False )[MetricSource .LOG ]
252+ assert epoch_log == {
253+ 'a_1_epoch' : 1 ,
254+ 'a_epoch' : cumulative_sum ,
255+ 'a' : cumulative_sum ,
256+ 'b' : cumulative_sum ,
257+ 'b_1' : 1
258+ }
259+
260+ # make sure the result collection can be pickled
261+ pickle .loads (pickle .dumps (result ))
262+ # make sure the result collection can be saved and loaded with torch
263+ filepath = str (tmpdir / 'result' )
264+ torch .save (result , filepath )
265+ torch .load (filepath )
259266
260267 # assert metric state reset to default values
268+ result .reset ()
261269 assert metric_a .x == metric_a ._defaults ['x' ]
262270 assert metric_b .x == metric_b ._defaults ['x' ]
263271 assert metric_c .x == metric_c ._defaults ['x' ]
264272
273+ batch_idx = None
274+
265275
266276def test_lightning_module_logging_result_collection (tmpdir ):
267277
@@ -271,21 +281,24 @@ def __init__(self):
271281 super ().__init__ ()
272282 self .metric = DummyMetric ()
273283
274- def training_step (self , batch , batch_idx ):
284+ def validation_step (self , batch , batch_idx ):
275285 v = self .metric (batch_idx )
276286 self .log_dict ({"v" : v , "m" : self .metric })
277- return super ().training_step (batch , batch_idx )
287+ return super ().validation_step (batch , batch_idx )
278288
279289 def on_save_checkpoint (self , checkpoint ) -> None :
280- state_dict = self .trainer .train_loop .results .state_dict ()
281- checkpoint ["result_collections" ] = state_dict
282- self .trainer .train_loop .results .load_state_dict (state_dict )
283- assert self .trainer .train_loop .results ['training_step.v' ].meta .sync .fn is None
284- return super ().on_save_checkpoint (checkpoint )
290+ results = self .trainer ._results
291+ state_dict = results .state_dict ()
292+ # sync fn should be kept
293+ assert results ['validation_step.v' ].meta .sync .fn == self .trainer .training_type_plugin .reduce
294+ assert state_dict ['items' ]['validation_step.v' ]['meta' ].sync .fn == self .trainer .training_type_plugin .reduce
295+ results .load_state_dict (state_dict )
296+ # check if the sync fn was preserved
297+ assert results ['validation_step.v' ].meta .sync .fn == self .trainer .training_type_plugin .reduce
285298
286299 model = LoggingModel ()
287300 ckpt = ModelCheckpoint (dirpath = tmpdir , save_last = True )
288301 trainer = Trainer (
289- default_root_dir = tmpdir , max_epochs = 3 , limit_train_batches = 2 , limit_val_batches = 2 , callbacks = [ckpt ]
302+ default_root_dir = tmpdir , max_epochs = 2 , limit_train_batches = 2 , limit_val_batches = 2 , callbacks = [ckpt ]
290303 )
291304 trainer .fit (model )
0 commit comments