From a2a53fbab69aca7b407727f39ac814371c4f8221 Mon Sep 17 00:00:00 2001 From: Jim Plotts Date: Thu, 24 Feb 2022 12:16:19 -0500 Subject: [PATCH 1/3] Update test_catalyst. The API was changed from save_n_best= to save_best=. --- tests/test_catalyst.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_catalyst.py b/tests/test_catalyst.py index 48b2934b..d2944abd 100644 --- a/tests/test_catalyst.py +++ b/tests/test_catalyst.py @@ -141,7 +141,7 @@ def test_mnist(self): logdir=logdir, num_epochs=num_epochs, verbose=False, - callbacks=[CheckpointCallback(save_n_best=3, use_runner_logdir=True)] + callbacks=[CheckpointCallback(save_best=True, use_runner_logdir=True)] ) with open('./logs/_metrics.json') as f: From 26bf81140e22c21d8f452abb0676c6cc539df9a9 Mon Sep 17 00:00:00 2001 From: Jim Plotts Date: Thu, 24 Feb 2022 13:05:36 -0500 Subject: [PATCH 2/3] Remove deprecated parameter use_runner_logdir. --- tests/test_catalyst.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_catalyst.py b/tests/test_catalyst.py index d2944abd..aa902ce4 100644 --- a/tests/test_catalyst.py +++ b/tests/test_catalyst.py @@ -141,7 +141,7 @@ def test_mnist(self): logdir=logdir, num_epochs=num_epochs, verbose=False, - callbacks=[CheckpointCallback(save_best=True, use_runner_logdir=True)] + callbacks=[CheckpointCallback(save_best=True)] ) with open('./logs/_metrics.json') as f: From 83b5a09144377c7ee2e25cd973c79b80eb9f9e03 Mon Sep 17 00:00:00 2001 From: Jim Plotts Date: Fri, 25 Feb 2022 11:14:13 -0500 Subject: [PATCH 3/3] Fix catalyst unit tests. --- tests/test_catalyst.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/test_catalyst.py b/tests/test_catalyst.py index aa902ce4..3b9c97d4 100644 --- a/tests/test_catalyst.py +++ b/tests/test_catalyst.py @@ -141,10 +141,18 @@ def test_mnist(self): logdir=logdir, num_epochs=num_epochs, verbose=False, - callbacks=[CheckpointCallback(save_best=True)] + callbacks=[CheckpointCallback( + logdir, + topk=3, + save_best=True, + loader_key="valid", + metric_key="loss", + minimize=True)] ) - - with open('./logs/_metrics.json') as f: + + with open('./logs/model.storage.json') as f: metrics = json.load(f) - self.assertTrue(metrics['train.3']['valid']['loss'] < metrics['train.1']['valid']['loss']) - self.assertTrue(metrics['best']['valid']['loss'] < 0.35) + storage = metrics['storage'] + self.assertEqual(3, len(storage)) + self.assertTrue(storage[0]['metric'] < storage[2]['metric']) + self.assertTrue(storage[0]['metric']< 0.35)