
Commit 03e605f

Update tests
1 parent 9ced763 commit 03e605f

File tree

1 file changed (+44, -31)


tests/core/test_metric_result_integration.py

Lines changed: 44 additions & 31 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import pickle
 from copy import deepcopy
 
 import torch
@@ -39,9 +40,6 @@ def update(self, x):
     def compute(self):
         return self.x
 
-    def extra_repr(self) -> str:
-        return str(self.name) if self.name else ''
-
 
 def _setup_ddp(rank, worldsize):
     import os
@@ -186,7 +184,11 @@ def lightning_log(fx, *args, **kwargs):
         assert result[k].cumulated_batch_size == torch.tensor(1.), k
 
 
-def test_result_collection_restoration():
+def my_sync_dist(x):
+    return x
+
+
+def test_result_collection_restoration(tmpdir):
     """
     This test makes sure metrics are properly reloaded on failure.
     """
@@ -203,7 +205,7 @@ def lightning_log(fx, *args, **kwargs):
         nonlocal current_fx_name
         if current_fx_name != fx and batch_idx in (None, 0):
             result.reset(metrics=False, fx=fx)
-        result.log(fx, *args, **kwargs)
+        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
         current_fx_name = fx
 
     for _ in range(2):
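
A note on the `sync_dist_fn` change above: in a distributed run this callable would reduce the logged value across processes, so a plain identity function like `my_sync_dist` is enough for a single-process test. A minimal sketch of the concept (hypothetical names, not the Lightning API):

import torch

def identity_sync(tensor: torch.Tensor) -> torch.Tensor:
    # single-process stand-in for an all-reduce across ranks
    return tensor

def log_value(value: torch.Tensor, sync_fn=identity_sync) -> torch.Tensor:
    # a logged value passes through the sync fn before aggregation
    return sync_fn(value)

assert torch.equal(log_value(torch.tensor(3.0)), torch.tensor(3.0))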
@@ -230,38 +232,46 @@ def lightning_log(fx, *args, **kwargs):
         batch_log = result.metrics(on_step=True)[MetricSource.LOG]
         assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
         assert set(batch_log['c_1']) == {'1', '2'}
+
     result_copy = deepcopy(result)
+    new_result = ResultCollection(True, torch.device("cpu"))
     state_dict = result.state_dict()
-
-    result = ResultCollection(True, torch.device("cpu"))
-    result.load_state_dict(state_dict, sync_fn=result_copy['training_step.a'].meta.sync.fn)
-
-    assert result_copy.items() == result.items()
-    assert result_copy["training_step.c_1"].meta == result["training_step.c_1"].meta
-
-    batch_idx = None
+    # check the sync fn is the expected one
+    assert state_dict['items']['training_step.a']['meta'].sync.fn == my_sync_dist
+    new_result.load_state_dict(state_dict)
+    assert result_copy == new_result
+    # the sync fn should match after reloading
+    assert result_copy['training_step.a'].meta.sync.fn == new_result['training_step.a'].meta.sync.fn
 
     epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
     epoch_log_copy = result_copy.metrics(on_step=False)[MetricSource.LOG]
     assert epoch_log == epoch_log_copy
 
-    assert set(epoch_log) == {'a_1_epoch', 'a_epoch', 'b', 'b_1'}
-    for k in epoch_log:
-        if k in {'a_epoch', 'b'}:
-            assert epoch_log[k] == cumulative_sum
-        else:
-            assert epoch_log[k] == 1
-
     lightning_log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True)
-
-    result.reset()
-    result_copy.reset()
+    epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
+    assert epoch_log == {
+        'a_1_epoch': 1,
+        'a_epoch': cumulative_sum,
+        'a': cumulative_sum,
+        'b': cumulative_sum,
+        'b_1': 1
+    }
+
+    # make sure the result collection can be pickled
+    pickle.loads(pickle.dumps(result))
+    # make sure it can be saved and reloaded with torch
+    filepath = str(tmpdir / 'result')
+    torch.save(result, filepath)
+    torch.load(filepath)
 
     # assert metric state reset to default values
+    result.reset()
     assert metric_a.x == metric_a._defaults['x']
     assert metric_b.x == metric_b._defaults['x']
     assert metric_c.x == metric_c._defaults['x']
 
+    batch_idx = None
+
 
 def test_lightning_module_logging_result_collection(tmpdir):
 
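The pickling checks in this hunk are likely the reason `my_sync_dist` is defined at module level earlier in the diff: `pickle` stores functions by qualified name, so a module-level function survives `pickle.dumps` and `torch.save`, while a lambda or a function defined inside the test body would not. A small standalone demonstration, independent of the test:

import pickle

def module_level_fn(x):
    # picklable: resolvable by its qualified name at load time
    return x

# round-trips fine, and the identical function object comes back
restored = pickle.loads(pickle.dumps({'sync_fn': module_level_fn}))
assert restored['sync_fn'] is module_level_fn

# a lambda has no importable qualified name, so pickling it fails
try:
    pickle.dumps(lambda x: x)
except (pickle.PicklingError, AttributeError):
    pass  # expected
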
@@ -271,21 +281,24 @@ def __init__(self):
             super().__init__()
             self.metric = DummyMetric()
 
-        def training_step(self, batch, batch_idx):
+        def validation_step(self, batch, batch_idx):
             v = self.metric(batch_idx)
             self.log_dict({"v": v, "m": self.metric})
-            return super().training_step(batch, batch_idx)
+            return super().validation_step(batch, batch_idx)
 
         def on_save_checkpoint(self, checkpoint) -> None:
-            state_dict = self.trainer.train_loop.results.state_dict()
-            checkpoint["result_collections"] = state_dict
-            self.trainer.train_loop.results.load_state_dict(state_dict)
-            assert self.trainer.train_loop.results['training_step.v'].meta.sync.fn is None
-            return super().on_save_checkpoint(checkpoint)
+            results = self.trainer._results
+            state_dict = results.state_dict()
+            # the sync fn should be kept in the state dict
+            assert results['validation_step.v'].meta.sync.fn == self.trainer.training_type_plugin.reduce
+            assert state_dict['items']['validation_step.v']['meta'].sync.fn == self.trainer.training_type_plugin.reduce
+            results.load_state_dict(state_dict)
+            # check that the sync fn was preserved after reloading
+            assert results['validation_step.v'].meta.sync.fn == self.trainer.training_type_plugin.reduce
 
     model = LoggingModel()
     ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True)
     trainer = Trainer(
-        default_root_dir=tmpdir, max_epochs=3, limit_train_batches=2, limit_val_batches=2, callbacks=[ckpt]
+        default_root_dir=tmpdir, max_epochs=2, limit_train_batches=2, limit_val_batches=2, callbacks=[ckpt]
     )
     trainer.fit(model)
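
Taken together, the new assertions in both tests pin down one contract: metadata, including the sync fn, must survive a `state_dict()` / `load_state_dict()` round trip unchanged. A toy stand-in (not the real `ResultCollection` API) that illustrates the shape of that contract:

def my_sync(x):
    return x

class TinyCollection:
    # hypothetical stand-in for ResultCollection, for illustration only
    def __init__(self):
        self.items = {}  # key -> {'value': ..., 'sync_fn': ...}

    def log(self, key, value, sync_fn):
        self.items[key] = {'value': value, 'sync_fn': sync_fn}

    def state_dict(self):
        # the sync fn travels inside the state dict rather than being dropped
        return {'items': dict(self.items)}

    def load_state_dict(self, state):
        self.items = dict(state['items'])

src = TinyCollection()
src.log('validation_step.v', 1.0, sync_fn=my_sync)
dst = TinyCollection()
dst.load_state_dict(src.state_dict())
# the same callable comes back after reloading
assert dst.items['validation_step.v']['sync_fn'] is my_sync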
