Skip to content

Conversation

@akihironitta
Copy link
Contributor

@akihironitta akihironitta commented Aug 23, 2022

What does this PR do?

Tries to fix #13937 by partially reverting #10497 to revive auto_refresh=True.

code for benchmark
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import RichProgressBar

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def main():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    trainer = Trainer(
        max_epochs=100,
        enable_model_summary=False,
        enable_checkpointing=False,
        logger=False,
        benchmark=False,  # True by default in 1.6.{0-3}.
        callbacks=RichProgressBar(),
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)

if __name__ == "__main__":
    main()

Does your PR introduce any breaking changes? If yes, please list them.

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • [n/a] Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Make sure you had fun coding 🙃

@akihironitta akihironitta changed the title Make RichProgressBar use auto_refresh=True [wip] Make RichProgressBar use auto_refresh=True Aug 23, 2022
@github-actions github-actions bot added the pl Generic label for PyTorch Lightning package label Aug 23, 2022
@akihironitta akihironitta force-pushed the bugfix/rich-auto-refresh branch from fcb4727 to 07b171b Compare August 23, 2022 11:32
@akihironitta
Copy link
Contributor Author

akihironitta commented Sep 9, 2022

(Still wip but) Following Textualize/rich#2432, it now uses auto_refresh=True in this PR, and I see a significant improvement locally:

time: 4.771114709   # 4ca7033693f0bef01ac724ca83e4c6c4c9550b42 (bugfix/rich-auto-refresh)
time: 12.920467208  # b84c03f3a6c5327814c6aaab5634e5b5d46919d8 (master)
code
from time import monotonic
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import RichProgressBar

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

def main():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    trainer = Trainer(
        max_epochs=100,
        enable_model_summary=False,
        enable_checkpointing=False,
        logger=False,
        benchmark=False,  # True by default in 1.6.{0-3}.
        callbacks=RichProgressBar(),
    )
    t0 = monotonic()
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    print("time:", monotonic() - t0)

if __name__ == "__main__":
    main()

@akihironitta akihironitta self-assigned this Sep 9, 2022
@akihironitta akihironitta marked this pull request as ready for review September 9, 2022 02:19
@akihironitta
Copy link
Contributor Author

⚠️ Not ready ⚠️ Marked it as ready just to trigger CI.

@akihironitta akihironitta added the priority: 1 Medium priority task label Sep 9, 2022
@justusschock
Copy link
Member

@akihironitta How is it going here? Just a friendly reminder :)

@akihironitta akihironitta deleted the bugfix/rich-auto-refresh branch March 16, 2023 17:29
@PKizzle
Copy link

PKizzle commented Mar 17, 2023

May I ask why this PR has been closed?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

performance pl Generic label for PyTorch Lightning package priority: 1 Medium priority task progress bar: rich

Projects

None yet

Development

Successfully merging this pull request may close these issues.

RichProgressBar in v1.6 is slower than v1.5

3 participants