Trainer.predict should accumulate results on CPU not GPU #10156

@jzazo

Description

🐛 Bug

As discussed here and addressed by this PR, Trainer.predict should accumulate predicted results on CPU. However, this does not appear to be working: the returned predictions are still on the GPU.

To Reproduce

import os

import torch
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def run():
    predict_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=2,
        max_epochs=1,
        gpus=1,
    )
    results = trainer.predict(model, dataloaders=predict_data)
    print(results[0].device)

if __name__ == "__main__":
    run()

Expected behavior

I expected the script to print cpu, but it prints cuda:0. The same happens if forward returns a dictionary instead; in that case the tensors inside the dictionary should be moved to CPU as well.
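For anyone hitting the same thing, here is a rough sketch of the behavior I have in mind, written as a standalone helper. The name to_cpu is hypothetical and not part of PyTorch Lightning; it assumes any object with a callable cpu() method (as torch.Tensor has) is a tensor, and it only walks dicts, lists, tuples, and namedtuples.

```python
def to_cpu(obj):
    """Recursively move tensor-like leaves of a nested structure to the CPU.

    Hypothetical helper, not a PyTorch Lightning API. A leaf is treated as
    tensor-like if it exposes a callable ``cpu`` method, as torch.Tensor does.
    """
    if callable(getattr(obj, "cpu", None)):
        return obj.cpu()
    if isinstance(obj, dict):
        return {key: to_cpu(value) for key, value in obj.items()}
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # namedtuple: rebuild with positional arguments
        return type(obj)(*(to_cpu(item) for item in obj))
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_cpu(item) for item in obj)
    # Anything else (ints, strings, None, ...) is returned unchanged.
    return obj
```

As a possible workaround until this is fixed, the model could override predict_step and return to_cpu(self(batch)), so each batch's output is already on the CPU before the trainer collects it. I have not checked how this interacts with distributed prediction.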

Environment

  • PyTorch Lightning Version (e.g., 1.3.0): 1.4.9
  • PyTorch Version (e.g., 1.8): 1.10.0
  • Python version: 3.8.5
  • OS (e.g., Linux): Ubuntu 20.04
  • How you installed PyTorch (conda, pip, source): pipenv

Additional context

Tagging @EricWiener from the original discussion.
