diff --git a/tests/test_catalyst.py b/tests/test_catalyst.py index 48b2934b..3b9c97d4 100644 --- a/tests/test_catalyst.py +++ b/tests/test_catalyst.py @@ -141,10 +141,18 @@ def test_mnist(self): logdir=logdir, num_epochs=num_epochs, verbose=False, - callbacks=[CheckpointCallback(save_n_best=3, use_runner_logdir=True)] + callbacks=[CheckpointCallback( + logdir, + topk=3, + save_best=True, + loader_key="valid", + metric_key="loss", + minimize=True)] ) - - with open('./logs/_metrics.json') as f: + + with open('./logs/model.storage.json') as f: metrics = json.load(f) - self.assertTrue(metrics['train.3']['valid']['loss'] < metrics['train.1']['valid']['loss']) - self.assertTrue(metrics['best']['valid']['loss'] < 0.35) + storage = metrics['storage'] + self.assertEqual(3, len(storage)) + self.assertTrue(storage[0]['metric'] < storage[2]['metric']) + self.assertTrue(storage[0]['metric']< 0.35)