
LightningModule.print() fails when it is called before training #7673

@ryanking13

Description


🐛 Bug

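LightningModule.print() fails when it is called before training has ever run.

For example, calling LightningModule.print() inside validation_step(), test_step(), or predict_step() fails if Trainer.fit(model) has not been called first: trainer.validate(model) and trainer.test(model) raise an error, and trainer.predict(model) runs but prints nothing.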
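The failure mode can be illustrated without Lightning. The sketch below is a simplified, hypothetical stand-in (`FakeBar`, `ToyProgressBar` and `print_during_validate_only` are illustrative names, not Lightning's source). The assumption is that each progress bar starts as `None` and is only created by the stage that uses it. A `print()` that reads `main_progress_bar.disable` unconditionally then raises the same `AttributeError` when only validation has run.

```python
class FakeBar:
    """Minimal object with the two members the print path touches."""

    def __init__(self, disable=False):
        self.disable = disable
        self.lines = []

    def write(self, message):
        self.lines.append(message)


class ToyProgressBar:
    """Hypothetical stand-in for a progress-bar callback (not Lightning's source)."""

    def __init__(self):
        # Assumption: bars start as None and are created lazily per stage.
        self.main_progress_bar = None
        self.val_progress_bar = None

    def on_train_start(self):
        # Only a training run creates the main bar.
        self.main_progress_bar = FakeBar()

    def on_validation_start(self):
        self.val_progress_bar = FakeBar()

    def print(self, message):
        # Same shape as the failing line in the traceback: the main bar is
        # dereferenced before anything checks that it exists.
        if not self.main_progress_bar.disable:
            self.main_progress_bar.write(message)
        elif not self.val_progress_bar.disable:
            self.val_progress_bar.write(message)


def print_during_validate_only():
    """Simulate trainer.validate(model) with no prior trainer.fit(model)."""
    bar = ToyProgressBar()
    bar.on_validation_start()
    try:
        bar.print("[VALIDATION] PRINT ME!!")
    except AttributeError as err:
        return str(err)
    return "printed"


print(print_during_validate_only())
```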

Reproduced with the BoringModel:

https://colab.research.google.com/drive/1om5k36kY2UDZzvl2MqcIrDJBn7XKhnzw?usp=sharing

--> 474         if not self.main_progress_bar.disable:
    475             active_progress_bar = self.main_progress_bar
    476         elif not self.val_progress_bar.disable:

AttributeError: 'NoneType' object has no attribute 'disable'
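One way to avoid the crash is to check that a bar exists before reading its `disable` flag. The sketch below is hypothetical (`Bar`, `active_bar` and `safe_print` are illustrative names, not Lightning's API). The candidate order mirrors the `main_progress_bar` then `val_progress_bar` branches visible in the traceback, and when no bar is usable the message falls back to the builtin `print`.

```python
from typing import List, Optional, Sequence


class Bar:
    """Hypothetical minimal bar: only `disable` and `write` matter here."""

    def __init__(self, disable: bool = False):
        self.disable = disable
        self.written: List[str] = []

    def write(self, message: str) -> None:
        self.written.append(message)


def active_bar(candidates: Sequence[Optional[Bar]]) -> Optional[Bar]:
    """Return the first bar that exists and is enabled, or None.

    Testing `bar is not None` before reading `bar.disable` is the guard
    the line in the traceback lacks.
    """
    for bar in candidates:
        if bar is not None and not bar.disable:
            return bar
    return None


def safe_print(candidates: Sequence[Optional[Bar]], message: str) -> str:
    """Write through the active bar, else fall back to the builtin print.

    Returns "bar" or "stdout" to report which path was taken.
    """
    bar = active_bar(candidates)
    if bar is None:
        print(message)
        return "stdout"
    bar.write(message)
    return "bar"


# Validation-only run: the training bar was never created, the validation bar was.
val_bar = Bar()
safe_print([None, val_bar], "[VALIDATION] PRINT ME!!")
```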

To Reproduce

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import Trainer, LightningModule


class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self, **kwargs):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def step(self, x):
        x = self(x)
        out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
        return out

    def training_step(self, batch, batch_idx):
        output = self(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def validation_step(self, *args):
        self.print("[VALIDATION] PRINT ME!!")
    
    def test_step(self, *args):
        self.print("[TEST] PRINT ME!!")
    
    def predict_step(self, *args):
        self.print("[PREDICT] PRINT ME!!")

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def test_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def predict_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

if __name__ == '__main__':
    model = BoringModel()
    trainer = Trainer(max_epochs=1)

    # Run each call separately: the first exception stops the script.
    trainer.validate(model)  # raises an error
    trainer.test(model)  # raises an error
    trainer.predict(model)  # runs, but nothing is printed
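
Until this is fixed, a user-side workaround is to fall back to Python's builtin `print` when `LightningModule.print()` raises. The mixin below is a hypothetical sketch (`FallbackPrintMixin` is not part of Lightning). It assumes the failure surfaces as an `AttributeError`, as in the traceback. It does not help with the predict case, where nothing is raised.

```python
import builtins


class FallbackPrintMixin:
    """Hypothetical workaround mixin, not part of PyTorch Lightning.

    List it before LightningModule in the base classes so that this
    `print` is found first in the method resolution order.
    """

    def print(self, *args, **kwargs):
        try:
            # Normal path: defer to the next class in the MRO
            # (LightningModule.print in real use).
            super().print(*args, **kwargs)
        except AttributeError:
            # Assumed failure mode from the traceback: a progress bar that
            # was never created is None. Write straight to stdout instead.
            builtins.print(*args, **kwargs)
```

Usage would look like `class BoringModel(FallbackPrintMixin, LightningModule): ...`, with the mixin listed first so that its `print` takes precedence. Catching `AttributeError` this broadly can also hide unrelated bugs raised inside the parent `print`, so the workaround is best dropped once the progress bar handles missing bars.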

Expected behavior

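LightningModule.print() should print its message wherever it is called, including from validation_step(), test_step(), and predict_step() when Trainer.fit() has not been called.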

Environment

  • CUDA:
    • GPU:
    • available: False
    • version: 10.1
  • Packages:
    • numpy: 1.19.5
    • pyTorch_debug: False
    • pyTorch_version: 1.8.1+cu101
    • pytorch-lightning: 1.3.2
    • tqdm: 4.41.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.7.10
    • version: #1 SMP Tue Apr 20 19:55:43 PDT 2021

Additional context

Labels: bug (Something isn't working), help wanted (Open to be worked on), priority: 2 (Low priority task)