From edcc225213c92ef92f8cc63d780e9bb6405614b0 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Fri, 4 Dec 2020 16:59:23 +0100 Subject: [PATCH 01/11] Added changeable extension variable for model checkpoints --- pytorch_lightning/callbacks/model_checkpoint.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 79feba5a4190d..aeb956d304b13 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -140,6 +140,7 @@ class ModelCheckpoint(Callback): CHECKPOINT_JOIN_CHAR = "-" CHECKPOINT_NAME_LAST = "last" + FILE_EXTENSION = "ckpt" def __init__( self, @@ -442,7 +443,7 @@ def format_checkpoint_name( ) if ver is not None: filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) - ckpt_name = f"{filename}.ckpt" + ckpt_name = f"{filename}.{self.FILE_EXTENSION}" return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name def __resolve_ckpt_dir(self, trainer, pl_module): From 8779d4923cdb5ce2071171ef0a3354779cbd178a Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Fri, 4 Dec 2020 17:16:08 +0100 Subject: [PATCH 02/11] Removed whitespace --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index aeb956d304b13..3c1faccde83a7 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -140,7 +140,7 @@ class ModelCheckpoint(Callback): CHECKPOINT_JOIN_CHAR = "-" CHECKPOINT_NAME_LAST = "last" - FILE_EXTENSION = "ckpt" + FILE_EXTENSION = "ckpt" def __init__( self, From 30f64588b0290460bc7f5714171e9a6d75e3f7db Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Fri, 4 Dec 2020 17:20:46 +0100 Subject: [PATCH 03/11] Removed the last bit of whitespace --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 3c1faccde83a7..131a422896546 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -140,7 +140,7 @@ class ModelCheckpoint(Callback): CHECKPOINT_JOIN_CHAR = "-" CHECKPOINT_NAME_LAST = "last" - FILE_EXTENSION = "ckpt" + FILE_EXTENSION = "ckpt" def __init__( self, From 4c56ff733c6e078f5f2849b9afd19f223ff56ee6 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Sat, 5 Dec 2020 10:29:28 +0100 Subject: [PATCH 04/11] Wrote tests for FILE_EXTENSION --- .../callbacks/model_checkpoint.py | 6 ++-- tests/checkpointing/test_model_checkpoint.py | 34 +++++++++++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 131a422896546..956f29fd10a2f 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -140,7 +140,7 @@ class ModelCheckpoint(Callback): CHECKPOINT_JOIN_CHAR = "-" CHECKPOINT_NAME_LAST = "last" - FILE_EXTENSION = "ckpt" + FILE_EXTENSION = ".ckpt" def __init__( self, @@ -443,7 +443,7 @@ def format_checkpoint_name( ) if ver is not None: filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) - ckpt_name = f"{filename}.{self.FILE_EXTENSION}" + ckpt_name = f"{filename}{self.FILE_EXTENSION}" return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name def __resolve_ckpt_dir(self, trainer, pl_module): @@ -546,7 +546,7 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath) ckpt_name_metrics, prefix=self.prefix ) - last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt") + last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}") self._save_model(last_filepath, trainer, pl_module) if ( diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 33bc19a894d8f..d3a248eafe89b 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -261,6 +261,40 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt' +def test_model_checkpoint_file_extension(tmpdir): + + # tests that format_checkpoint_name uses the user-defined FILE_EXTENSION + ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(0, 1, {}) + ModelCheckpoint.FILE_EXTENSION = '.tpkc' + tpkc_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(0, 1, {}) + assert ckpt_name == str(Path('.').resolve() / 'epoch=0-step=1.ckpt') + assert tpkc_name == str(Path('.').resolve() / 'epoch=0-step=1.tpkc') + + #tests that _save_last_checkpoint uses the user-defined FILE_EXTENSION + seed_everything() + model = LogInTwoMethods() + epochs = 1 + ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}' + model_checkpoint = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=-1, save_last=True) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[model_checkpoint], + max_epochs=epochs, + limit_train_batches=10, + limit_val_batches=10, + logger=False, + ) + trainer.fit(model) + last_filename = model_checkpoint._format_checkpoint_name( + ModelCheckpoint.CHECKPOINT_NAME_LAST, trainer.current_epoch, trainer.global_step, {} + ) + last_filename = last_filename + '.tpkc' + assert str(tmpdir / last_filename) == model_checkpoint.last_model_path + + #Reset model checkpoint file extension so it does not break other tests + ModelCheckpoint.FILE_EXTENSION = '.ckpt' + + def test_model_checkpoint_save_last(tmpdir): """Tests that save_last produces only one last checkpoint.""" seed_everything() From cb9e3811f8995faf74c1033182a2f680d94d4bdf Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Sat, 5 Dec 2020 10:35:20 +0100 Subject: [PATCH 05/11] Fixed formatting issues --- tests/checkpointing/test_model_checkpoint.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index d3a248eafe89b..0913f57483417 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -262,15 +262,15 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): def test_model_checkpoint_file_extension(tmpdir): - + # tests that format_checkpoint_name uses the user-defined FILE_EXTENSION - ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(0, 1, {}) + ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(0, 1, {}) ModelCheckpoint.FILE_EXTENSION = '.tpkc' tpkc_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(0, 1, {}) assert ckpt_name == str(Path('.').resolve() / 'epoch=0-step=1.ckpt') assert tpkc_name == str(Path('.').resolve() / 'epoch=0-step=1.tpkc') - - #tests that _save_last_checkpoint uses the user-defined FILE_EXTENSION + + # tests that _save_last_checkpoint uses the user-defined FILE_EXTENSION seed_everything() model = LogInTwoMethods() epochs = 1 @@ -291,9 +291,9 @@ def test_model_checkpoint_file_extension(tmpdir): last_filename = last_filename + '.tpkc' assert str(tmpdir / last_filename) == model_checkpoint.last_model_path - #Reset model checkpoint file extension so it does not break other tests + # Reset model checkpoint file extension so it does not break other tests ModelCheckpoint.FILE_EXTENSION = '.ckpt' - + def test_model_checkpoint_save_last(tmpdir): """Tests that save_last produces only one last checkpoint.""" From 8b1119e2d15bab40b28af7931e518fadf478c7f5 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Sat, 5 Dec 2020 10:40:10 +0100 Subject: [PATCH 06/11] More formatting issues --- tests/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 0913f57483417..0893e1de774e8 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -290,7 +290,7 @@ def test_model_checkpoint_file_extension(tmpdir): ) last_filename = last_filename + '.tpkc' assert str(tmpdir / last_filename) == model_checkpoint.last_model_path - + # Reset model checkpoint file extension so it does not break other tests ModelCheckpoint.FILE_EXTENSION = '.ckpt' From fe3111d6f143aebf588149d2d713cf9c232039e0 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Sat, 5 Dec 2020 14:56:23 +0100 Subject: [PATCH 07/11] Simplify test by just using defaults --- tests/checkpointing/test_model_checkpoint.py | 31 +++++--------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 0893e1de774e8..4407eb3ef7b2d 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -263,34 +263,19 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): def test_model_checkpoint_file_extension(tmpdir): - # tests that format_checkpoint_name uses the user-defined FILE_EXTENSION - ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(0, 1, {}) + # tests that files get saved with user-defined FILE_EXTENSION ModelCheckpoint.FILE_EXTENSION = '.tpkc' - tpkc_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(0, 1, {}) - assert ckpt_name == str(Path('.').resolve() / 'epoch=0-step=1.ckpt') - assert tpkc_name == str(Path('.').resolve() / 'epoch=0-step=1.tpkc') - - # tests that _save_last_checkpoint uses the user-defined FILE_EXTENSION - seed_everything() - model = LogInTwoMethods() - epochs = 1 - ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}' - model_checkpoint = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=-1, save_last=True) + model = LogInTwoMethods() + model_checkpoint = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=1, save_last=True) trainer = Trainer( default_root_dir=tmpdir, callbacks=[model_checkpoint], - max_epochs=epochs, - limit_train_batches=10, - limit_val_batches=10, - logger=False, - ) + max_epochs=2 + ) trainer.fit(model) - last_filename = model_checkpoint._format_checkpoint_name( - ModelCheckpoint.CHECKPOINT_NAME_LAST, trainer.current_epoch, trainer.global_step, {} - ) - last_filename = last_filename + '.tpkc' - assert str(tmpdir / last_filename) == model_checkpoint.last_model_path - + expected = 'last.tpkc' + assert expected in set(os.listdir(tmpdir)) + # Reset model checkpoint file extension so it does not break other tests ModelCheckpoint.FILE_EXTENSION = '.ckpt' From 947a9d592c902d0faf7585dfdf098c08c9e910ee Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Sat, 5 Dec 2020 15:00:34 +0100 Subject: [PATCH 08/11] Formatting to PEP8 --- tests/checkpointing/test_model_checkpoint.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 4407eb3ef7b2d..51e305438f136 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -265,17 +265,17 @@ def test_model_checkpoint_file_extension(tmpdir): # tests that files get saved with user-defined FILE_EXTENSION ModelCheckpoint.FILE_EXTENSION = '.tpkc' - model = LogInTwoMethods() + model = LogInTwoMethods() model_checkpoint = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=1, save_last=True) trainer = Trainer( default_root_dir=tmpdir, callbacks=[model_checkpoint], max_epochs=2 - ) + ) trainer.fit(model) expected = 'last.tpkc' assert expected in set(os.listdir(tmpdir)) - + # Reset model checkpoint file extension so it does not break other tests ModelCheckpoint.FILE_EXTENSION = '.ckpt' From 410dd286894ae29d30f13501272495edd5457356 Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Sun, 6 Dec 2020 11:13:55 +0100 Subject: [PATCH 09/11] Added dummy class that inherits ModelCheckpoint; run only one batch instead of epoch for integration test --- tests/checkpointing/test_model_checkpoint.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 51e305438f136..78c4a73ff3523 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -261,24 +261,28 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt' +class ModelCheckpointExtensionTest(ModelCheckpoint): + # Helper class for test_model_checkpoint_file_extension + # Needs to be defined outside of function call as local objects cannot be pickled + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def test_model_checkpoint_file_extension(tmpdir): # tests that files get saved with user-defined FILE_EXTENSION - ModelCheckpoint.FILE_EXTENSION = '.tpkc' + ModelCheckpointExtensionTest.FILE_EXTENSION = '.tpkc' model = LogInTwoMethods() - model_checkpoint = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=1, save_last=True) + model_checkpoint = ModelCheckpointExtensionTest(monitor='early_stop_on', dirpath=tmpdir, save_top_k=1, save_last=True) trainer = Trainer( default_root_dir=tmpdir, callbacks=[model_checkpoint], - max_epochs=2 + max_steps=1 ) trainer.fit(model) expected = 'last.tpkc' assert expected in set(os.listdir(tmpdir)) - # Reset model checkpoint file extension so it does not break other tests - ModelCheckpoint.FILE_EXTENSION = '.ckpt' - def test_model_checkpoint_save_last(tmpdir): """Tests that save_last produces only one last checkpoint.""" From e395a37d8ab037f883cc3a272b35325ee12a530f Mon Sep 17 00:00:00 2001 From: Jan-Henrik Lambrechts Date: Sun, 6 Dec 2020 11:17:15 +0100 Subject: [PATCH 10/11] Fixed too much whitespace formatting --- tests/checkpointing/test_model_checkpoint.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 78c4a73ff3523..5cbfcc3e69ffa 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -262,10 +262,10 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): class ModelCheckpointExtensionTest(ModelCheckpoint): - # Helper class for test_model_checkpoint_file_extension - # Needs to be defined outside of function call as local objects cannot be pickled - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + # Helper class for test_model_checkpoint_file_extension + # Needs to be defined outside of function call as local objects cannot be pickled + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def test_model_checkpoint_file_extension(tmpdir): From 047424d0bdb76f69921a3dec3b0ad4045202550b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 6 Dec 2020 18:08:52 +0530 Subject: [PATCH 11/11] some changes --- .../callbacks/model_checkpoint.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 956f29fd10a2f..eb669736ada3a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -90,7 +90,7 @@ class ModelCheckpoint(Callback): Example:: # custom path - # saves a file like: my/path/epoch=0.ckpt + # saves a file like: my/path/epoch=0-step=10.ckpt >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/') By default, dirpath is ``None`` and will be set at runtime to the location diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 5cbfcc3e69ffa..6d1d3edea5be9 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -262,26 +262,26 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): class ModelCheckpointExtensionTest(ModelCheckpoint): - # Helper class for test_model_checkpoint_file_extension - # Needs to be defined outside of function call as local objects cannot be pickled - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + FILE_EXTENSION = '.tpkc' def test_model_checkpoint_file_extension(tmpdir): + """ + Test ModelCheckpoint with different file extension. + """ - # tests that files get saved with user-defined FILE_EXTENSION - ModelCheckpointExtensionTest.FILE_EXTENSION = '.tpkc' model = LogInTwoMethods() model_checkpoint = ModelCheckpointExtensionTest(monitor='early_stop_on', dirpath=tmpdir, save_top_k=1, save_last=True) trainer = Trainer( default_root_dir=tmpdir, callbacks=[model_checkpoint], - max_steps=1 + max_steps=1, + logger=False, ) trainer.fit(model) - expected = 'last.tpkc' - assert expected in set(os.listdir(tmpdir)) + + expected = ['epoch=0-step=0.tpkc', 'last.tpkc'] + assert set(expected) == set(os.listdir(tmpdir)) def test_model_checkpoint_save_last(tmpdir):