Skip to content

Latest Lightning does not support multiple callbacks that stop #6194

@jlperla

Description

@jlperla

🐛 Bug

In the latest version of lightning, you do not seem to be able to have multiple callbacks which can stop.

Please reproduce using the BoringModel

  1. If you have multiple callbacks which can do early stopping, only the last one can be active.
  2. Create a callback with early stopping, MyStoppingCallback(). Add it, then EarlyStoppingCallback() to the callbacks argument of the trainer, e.g. callbacks = [MyStoppingCallback(), EarlyStoppingCallback('val_loss')]
  • The callback is triggered and calculates that it needs to stop, but training continues.
  • On the other hand, if you change the order (e.g. callbacks = [EarlyStoppingCallback('val_loss'), MyStoppingCallback()]), it will stop with MyStoppingCallback but probably doesn't trigger the EarlyStoppingCallback.
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# --------------------------------------------
# --------------------------------------------
# --------------------------------------------
# USE THIS MODEL TO REPRODUCE A BUG YOU REPORT
# --------------------------------------------
# --------------------------------------------
# --------------------------------------------
import os

import torch
from torch.utils.data import Dataset

from pl_examples import cli_lightning_logo
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks import Callback

class RandomDataset(Dataset):
    """
    Dataset of ``length`` random rows with ``size`` features each.

    The rows are sampled once with ``torch.randn`` when the dataset is
    constructed and are returned unchanged on every lookup.

    >>> RandomDataset(size=10, length=20)  # doctest: +ELLIPSIS
    <...bug_report_model.RandomDataset object at ...>
    """

    def __init__(self, size, length):
        # Tensor of shape (length, size): one row per sample.
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self):
        """Return the number of samples given at construction."""
        return self.len

    def __getitem__(self, index):
        """Return the row of ``self.data`` stored at ``index``."""
        sample = self.data[index]
        return sample


class BoringModel(LightningModule):
    """
    Minimal single-layer module used to reproduce bugs.

    >>> BoringModel()  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
    BoringModel(
      (layer): Linear(...)
    )
    """

    def __init__(self):
        """
        Build the module with one ``Linear(32, 2)`` layer.

        Typical usage is to subclass and override the hook under test::

            class TestModel(BoringModel):
                def training_step(...):
                    # do your own thing

        or to disable a hook on an instance::

            model = BoringModel()
            model.training_epoch_end = None
        """
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        """Apply the linear layer to ``x``."""
        return self.layer(x)

    def loss(self, batch, prediction):
        """Return the MSE between ``prediction`` and an all-ones tensor.

        ``batch`` is accepted but not used; the target is arbitrary and only
        exists so that the weights receive a gradient during ``Trainer.fit``.
        """
        target = torch.ones_like(prediction)
        return torch.nn.functional.mse_loss(prediction, target)

    def step(self, x):
        """Run ``x`` through the layer and return its MSE against ones."""
        prediction = self.layer(x)
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch, returned under ``"loss"``."""
        prediction = self.layer(batch)
        return {"loss": self.loss(batch, prediction)}

    def training_step_end(self, training_step_outputs):
        """Pass the step outputs through unchanged."""
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        """Average the per-step ``"loss"`` values; the mean is not stored."""
        step_losses = [step_out["loss"] for step_out in outputs]
        torch.stack(step_losses).mean()

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss, log it as ``val_loss`` and return it under ``"x"``."""
        prediction = self.layer(batch)
        val_loss = self.loss(batch, prediction)
        self.log('val_loss', val_loss)
        return {"x": val_loss}

    def validation_epoch_end(self, outputs) -> None:
        """Average the per-step ``"x"`` values; the mean is not stored."""
        step_losses = [step_out['x'] for step_out in outputs]
        torch.stack(step_losses).mean()

    def test_step(self, batch, batch_idx):
        """Compute the test loss for one batch, returned under ``"y"``."""
        prediction = self.layer(batch)
        return {"y": self.loss(batch, prediction)}

    def test_epoch_end(self, outputs) -> None:
        """Average the per-step ``"y"`` values; the mean is not stored."""
        step_losses = [step_out["y"] for step_out in outputs]
        torch.stack(step_losses).mean()

    def configure_optimizers(self):
        """Return SGD (lr=0.1) over the layer's parameters with a ``StepLR(step_size=1)`` scheduler."""
        sgd = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(sgd, step_size=1)
        return [sgd], [scheduler]


#  NOTE: If you are using a cmd line to run your script,
#  provide the cmd line as below.
#  opt = "--max_epochs 1 --limit_train_batches 1".split(" ")
#  parser = ArgumentParser()
#  args = parser.parse_args(opt)

class EarlyStoppingExample(Callback):
    """Toy stopping callback: requests a stop once ``trainer.current_epoch`` exceeds 5.

    The decision is merged into ``trainer.should_stop`` with a logical OR so
    that this callback never cancels a stop that another callback has already
    requested. Assigning the flag unconditionally would reset it to ``False``
    on every epoch where this callback itself does not want to stop.
    """

    def __init__(self):
        super().__init__()
        # Epoch at which this callback asked for a stop; ``None`` until it does.
        self.stopped_epoch = None

    def on_validation_end(self, trainer, pl_module):
        """Decide whether to stop and fold the decision into ``trainer.should_stop``.

        Args:
            trainer: the running trainer; ``current_epoch`` is read and
                ``should_stop`` is updated.
            pl_module: the module being trained (unused).
        """
        should_stop = trainer.current_epoch > 5

        if should_stop:
            print("\nSTOPPING!!!!!!!!!!!!!!!!!!!!\n")
            self.stopped_epoch = trainer.current_epoch

        # stop every ddp process if any world process decides to stop
        should_stop = trainer.training_type_plugin.reduce_early_stopping_decision(should_stop)
        # Only ever raise the flag: keep a stop requested by another callback.
        trainer.should_stop = trainer.should_stop or should_stop

def test_run():
    """Fit and then test a ``BoringModel`` subclass with two stopping callbacks.

    The callbacks are registered in the order ``EarlyStoppingExample`` then
    ``EarlyStopping('val_loss', patience=50)``, which is the ordering the
    report is about.
    """

    class TestModel(BoringModel):

        def on_train_epoch_start(self) -> None:
            """No-op hook override."""

    def _random_loader():
        # Fake data: 64 random samples with 32 features each.
        return torch.utils.data.DataLoader(RandomDataset(32, 64))

    train_data = _random_loader()
    val_data = _random_loader()
    test_data = _random_loader()

    stopping_callbacks = [
        EarlyStoppingExample(),
        EarlyStopping('val_loss', patience=50),
    ]

    model = TestModel()
    trainer_options = dict(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=100,
        weights_summary=None,
        callbacks=stopping_callbacks,
    )
    trainer = Trainer(**trainer_options)

    trainer.fit(model, train_data, val_data)
    trainer.test(test_dataloaders=test_data)


if __name__ == '__main__':
    # Script entry point: the logo banner call is left commented out, so only
    # the reproduction run executes.
    #cli_lightning_logo()
    test_run()

To Reproduce

Use following BoringModel and post here

Expected behavior

  • PyTorch Version (e.g., 1.0): 1.7
  • OS (e.g., Linux): Windows
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source):
  • Python version: 3.7.4
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Metadata

Metadata

Assignees

Labels

bugSomething isn't workinghelp wantedOpen to be worked onpriority: 0High priority task

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions