Bounded memory leak caused by trainer.evaluation_loop.outputs #5735

@roytseng-tw

Description

🐛 Bug

trainer.evaluation_loop.outputs caches the output of every validation step inside the trainer's def run_evaluation(self, max_batches=None):
https://github.com/PyTorchLightning/pytorch-lightning/blob/d71659b42a13946b854d49f5bb1bf6e2bcd5b9b2/pytorch_lightning/trainer/trainer.py#L659

It's not reset until the start of the next validation epoch:

https://github.com/PyTorchLightning/pytorch-lightning/blob/d71659b42a13946b854d49f5bb1bf6e2bcd5b9b2/pytorch_lightning/trainer/trainer.py#L621

https://github.com/PyTorchLightning/pytorch-lightning/blob/d71659b42a13946b854d49f5bb1bf6e2bcd5b9b2/pytorch_lightning/trainer/evaluation_loop.py#L124-L128
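
Until this is fixed, one possible stopgap (a sketch only, since trainer.evaluation_loop is not a documented public API) is to clear that list from a callback before the next training epoch starts, the same way the reproduction script below deletes the cached entries:

import pytorch_lightning as pl

class ClearEvalOutputs(pl.Callback):
    # Workaround sketch: drop the cached per-step validation outputs so the
    # GPU tensors they reference can be garbage collected before the next
    # training epoch.
    def on_train_epoch_start(self, trainer, pl_module):
        # Replacing the list removes the trainer's references to the dicts
        # returned by `validation_step`.
        trainer.evaluation_loop.outputs = []

Passing it as pl.Trainer(callbacks=[ClearEvalOutputs()], ...) frees the retained tensors before the next training epoch allocates its activations.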

Please reproduce using the BoringModel

To Reproduce

Sorry, my working environment forbids me from using Google Drive, so here is the reproduction inline.

import torch
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader

class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
    
    def on_train_epoch_start(self):
        print('Before delete:', torch.cuda.memory_allocated())
        for out in self.trainer.evaluation_loop.outputs[0]:
            if 'x' in out:
                del out['x']
        print('After delete:', torch.cuda.memory_allocated())

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def validation_epoch_end(self, outputs) -> None:
        torch.stack([x['x'] for x in outputs]).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


def test_x(tmpdir=None):  # tmpdir is unused; kept to mirror the BoringModel test signature
    # init model
    model = BoringModel()

    # Stand-in random dataloaders (the original report omits the BoringModel data;
    # any float inputs with 32 features reproduce the behaviour)
    train = DataLoader(torch.randn(64, 32), batch_size=2)
    val = DataLoader(torch.randn(64, 32), batch_size=2)

    # Initialize a trainer
    trainer = pl.Trainer(
        max_epochs=1,
        progress_bar_refresh_rate=20,
        gpus=[0],
    )

    # Train the model ⚡
    trainer.fit(model, train, val)

Execution

test_x()

Output

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | layer | Linear | 66    
---------------------------------
66        Trainable params
0         Non-trainable params
66        Total params
Validation sanity check: 0% 0/2 [00:00<?, ?it/s]
Epoch 0: 100% 626/626 [00:00<00:00, 653.87it/s, loss=2.5e-14, v_num=6]
Before delete: 2048
After delete: 1024

Expected behavior

Such cached tensors shouldn't be kept around. They can cause an OOM in cases where the OOM could otherwise be avoided. For example:

  • On the first training epoch, a model that just barely fits in GPU memory runs fine without OOM.
  • After the first validation epoch, some GPU tensors are retained and occupy a portion of the memory.
  • On the second training epoch, the same model hits an OOM error.

Clear all references to those validation output tensors at the end of the validation epoch, maybe, more specifically, here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/d71659b42a13946b854d49f5bb1bf6e2bcd5b9b2/pytorch_lightning/trainer/evaluation_loop.py#L224-L229
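
For reference, a minimal sketch of the kind of change suggested here, assuming it lands in the evaluation loop's end-of-epoch handling (the method name below is illustrative, not the actual code at the link above):

# pytorch_lightning/trainer/evaluation_loop.py (sketch only)
def on_evaluation_epoch_end(self, *args, **kwargs):
    # ... existing end-of-epoch hooks ...
    # Proposed: drop the per-step outputs once `validation_epoch_end` has
    # consumed them, instead of keeping them alive until the next validation
    # epoch re-initializes `self.outputs`.
    self.outputs = []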

Environment

  • CUDA:
    • GPU:
      • Tesla T4
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.19.5
    • pyTorch_debug: True
    • pyTorch_version: 1.7.0+cu101
    • pytorch-lightning: 1.1.6
    • tqdm: 4.41.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.6.9
    • version: #1 SMP Thu Jul 23 08:00:38 PDT 2020

Additional context

Labels

bug (Something isn't working) · help wanted (Open to be worked on) · priority: 0 (High priority task)
