[HOT-BUG] Checkpoint in callbacks list fails #4275

@tchaton

🐛 Bug

Please reproduce using the BoringModel and post here

ModelCheckpoint is not set up properly when it is passed through the Trainer's list of callbacks: the first checkpoint save fails with "ValueError: .save_function() not set".

def test_checkpoint_within_callbacks_list(tmpdir):
    """
    This test validates that the checkpoint callback works when passed through the Trainer's callbacks list
    """

    os.environ['PL_DEV_DEBUG'] = '1'

    checkpoint_callback = ModelCheckpoint(monitor='val_loss', filepath=osp.join(tmpdir, "{epoch:02d}"))

    class ExtendedBoringModel(BoringModel):

        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            return {"val_loss": loss}

    model = ExtendedBoringModel()
    model.validation_step_end = None
    model.validation_epoch_end = None
    trainer = pl.Trainer(max_epochs=1, 
                         limit_train_batches=2, 
                         limit_val_batches=2, 
                         limit_test_batches=2, 
                         callbacks=[checkpoint_callback])

    trainer.fit(model)
    assert os.listdir(tmpdir) == ['epoch=00.ckpt']

tests/checkpointing/test_model_checkpoint.py:590: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
pytorch_lightning/trainer/trainer.py:439: in fit
    results = self.accelerator_backend.train()
pytorch_lightning/accelerators/cpu_accelerator.py:48: in train
    results = self.train_or_test()
pytorch_lightning/accelerators/accelerator.py:66: in train_or_test
    results = self.trainer.train()
pytorch_lightning/trainer/trainer.py:482: in train
    self.train_loop.run_training_epoch()
pytorch_lightning/trainer/training_loop.py:569: in run_training_epoch
    self.trainer.run_evaluation(test_mode=False)
pytorch_lightning/trainer/trainer.py:609: in run_evaluation
    self.evaluation_loop.on_evaluation_end()
pytorch_lightning/trainer/evaluation_loop.py:109: in on_evaluation_end
    self.trainer.call_hook('on_validation_end', *args, **kwargs)
pytorch_lightning/trainer/trainer.py:822: in call_hook
    trainer_hook(*args, **kwargs)
pytorch_lightning/trainer/callback_hook.py:177: in on_validation_end
    callback.on_validation_end(self, self.get_model())
pytorch_lightning/callbacks/model_checkpoint.py:167: in on_validation_end
    self.save_checkpoint(trainer, pl_module)
pytorch_lightning/callbacks/model_checkpoint.py:213: in save_checkpoint
    self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath)
pytorch_lightning/callbacks/model_checkpoint.py:494: in _save_top_k_checkpoints
    self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
pytorch_lightning/callbacks/model_checkpoint.py:543: in _update_best_and_save
    self._save_model(filepath, trainer, pl_module)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint object at 0x13a775b80>
filepath = '/private/var/folders/q9/1sc07dt91w9581rl1vt6vd4w0000gn/T/pytest-of-thomas/pytest-1058/test_checkpoint_within_callbac0/epoch=00.ckpt'
trainer = <pytorch_lightning.trainer.trainer.Trainer object at 0x13a7756a0>, pl_module = ExtendedBoringModel(
  (layer): Linear(in_features=32, out_features=2, bias=True)
)

    def _save_model(self, filepath: str, trainer, pl_module):
        # in debugging, track when we save checkpoints
        trainer.dev_debugger.track_checkpointing_history(filepath)
    
        # make paths
        if trainer.is_global_zero:
            self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
    
        # delegate the saving to the trainer
        if self.save_function is not None:
            self.save_function(filepath, self.save_weights_only)
        else:
>           raise ValueError(".save_function() not set")
E           ValueError: .save_function() not set

pytorch_lightning/callbacks/model_checkpoint.py:297: ValueError
----------------------------------------------------------------------------------------- Captured log call ------------------------------------------------------------------------------------------
INFO     lightning:distributed.py:49 GPU available: False, used: False
INFO     lightning:distributed.py:49 TPU available: False, using: 0 TPU cores
INFO     lightning:lightning.py:1290 
  | Name  | Type   | Params
---------------------------------
0 | layer | Linear | 66
========================================================================================== warnings summary ==========================================================================================
.venv/lib/python3.8/site-packages/comet_ml/monkey_patching.py:19
  /Users/thomas/Documents/projects/pytorch-lightning/.venv/lib/python3.8/site-packages/comet_ml/monkey_patching.py:19: DeprecationWarning: the imp module is deprecated in favour of importlib; see the module's documentation for alternative uses
    import imp

.venv/lib/python3.8/site-packages/pandas/compat/__init__.py:120
  /Users/thomas/Documents/projects/pytorch-lightning/.venv/lib/python3.8/site-packages/pandas/compat/__init__.py:120: UserWarning: Could not import the lzma module. Your installed Python is incomplete. Attempting to use lzma compression will result in a RuntimeError.
    warnings.warn(msg)

.venv/lib/python3.8/site-packages/wandb/util.py:35
  /Users/thomas/Documents/projects/pytorch-lightning/.venv/lib/python3.8/site-packages/wandb/util.py:35: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.9 it will stop working
    from collections import namedtuple, Mapping, Sequence

.venv/lib/python3.8/site-packages/wandb/vendor/graphql-core-1.1/graphql/type/directives.py:55
  /Users/thomas/Documents/projects/pytorch-lightning/.venv/lib/python3.8/site-packages/wandb/vendor/graphql-core-1.1/graphql/type/directives.py:55: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.9 it will stop working
    assert isinstance(locations, collections.Iterable), 'Must provide locations for directive.'

tests/checkpointing/test_model_checkpoint.py::test_checkpoint_within_callbacks_list
  /Users/thomas/Documents/projects/pytorch-lightning/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
    warnings.warn(*args, **kwargs)

tests/checkpointing/test_model_checkpoint.py::test_checkpoint_within_callbacks_list
  /Users/thomas/Documents/projects/pytorch-lightning/pytorch_lightning/utilities/distributed.py:45: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
    warnings.warn(*args, **kwargs)

-- Docs: https://docs.pytest.org/en/stable/warnings.html
====================================================================================== short test summary info =======================================================================================
FAILED tests/checkpointing/test_model_checkpoint.py::test_checkpoint_within_callbacks_list - ValueError: .save_function() not set
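
The failure mode in the traceback above can be sketched without Lightning. This is only a toy model, not the library's actual wiring: `ToyCheckpoint`, `ToyTrainer`, `wire_callbacks` and `try_save` are hypothetical names introduced for illustration. It only shows that a callback whose `save_function` is never assigned hits the same guard as `_save_model`, while one that the trainer connects saves normally.

```python
class ToyCheckpoint:
    """Hypothetical stand-in for a checkpoint callback (not Lightning's class)."""

    def __init__(self):
        # Starts unset; something has to assign it before the first save.
        self.save_function = None
        self.save_weights_only = False

    def save_model(self, filepath):
        # Same guard as the one shown in the traceback for _save_model.
        if self.save_function is not None:
            self.save_function(filepath, self.save_weights_only)
        else:
            raise ValueError(".save_function() not set")


class ToyTrainer:
    """Hypothetical trainer that may or may not connect its callbacks."""

    def __init__(self, callbacks, wire_callbacks):
        self.callbacks = list(callbacks)
        self.saved = []
        if wire_callbacks:
            # Assumption for illustration: the trainer hands its own save
            # routine to every checkpoint callback found in the list.
            for cb in self.callbacks:
                if isinstance(cb, ToyCheckpoint):
                    cb.save_function = self.save_checkpoint

    def save_checkpoint(self, filepath, weights_only=False):
        # Record the path instead of writing a file.
        self.saved.append(filepath)


def try_save(wire_callbacks):
    """Return the saved paths, or the error message if the save fails."""
    cb = ToyCheckpoint()
    trainer = ToyTrainer([cb], wire_callbacks)
    try:
        cb.save_model("epoch=00.ckpt")
    except ValueError as err:
        return str(err)
    return trainer.saved


print(try_save(wire_callbacks=False))
print(try_save(wire_callbacks=True))
```

In this sketch the unwired case returns the error message seen in the test failure, and the wired case returns the list of saved paths.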

To Reproduce

Expected behavior

Environment

Please copy and paste the output from our environment collection script (or fill out the checklist below manually).

You can get the script and run it with:

wget https://raw.githubusercontent.com/PyTorchLightning/pytorch-lightning/master/tests/collect_env_details.py
# For security purposes, please check the contents of collect_env_details.py before running it.
python collect_env_details.py
  • PyTorch Version (e.g., 1.0):
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context
