[DDP] Metrics .compute() / sync_ddp fails with move_metrics_to_cpu #10379

@Quintulius

🐛 Bug

A metric's compute() call fails inside sync_ddp when training with DDP and move_metrics_to_cpu=True:

  File "pytorch1.10/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 385, in reduce
    tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
  File "pytorch1.10/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py", line 158, in sync_ddp_if_available
    return sync_ddp(result, group=group, reduce_op=reduce_op)
  File "pytorch1.10/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py", line 193, in sync_ddp
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
  File "pytorch1.10/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1287, in all_reduce
    work = group.allreduce([tensor], opts)
RuntimeError: Tensors must be CUDA and dense
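This message comes from PyTorch's NCCL process group, which only accepts CUDA tensors in all_reduce. The tensor reaching sync_ddp here is presumably on the CPU because move_metrics_to_cpu=True moved it there. A minimal sketch of a device-aware reduction is below. The helper name all_reduce_on_backend_device is hypothetical and is not part of Lightning's API. It assumes the default process group, or a group whose backend get_backend reports, and it copies CPU tensors to the current CUDA device only when the backend is NCCL.

```python
import torch
import torch.distributed as dist


def all_reduce_on_backend_device(tensor, op=None, group=None):
    """Hypothetical helper (not part of Lightning): all-reduce a possibly-CPU tensor.

    NCCL only accepts CUDA tensors, so a CPU tensor is copied to the current
    CUDA device, reduced there, and copied back to its original device.
    Other backends (e.g. gloo) reduce a copy on the tensor's own device.
    Without an initialized process group, the tensor is returned unchanged.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    if op is None:
        op = dist.ReduceOp.SUM
    original_device = tensor.device
    if dist.get_backend(group) == "nccl" and not tensor.is_cuda:
        # Avoids "RuntimeError: Tensors must be CUDA and dense" from NCCL.
        work = tensor.to(torch.device("cuda", torch.cuda.current_device()))
    else:
        # Copy so the caller's tensor is not modified by the in-place all_reduce.
        work = tensor.clone()
    dist.all_reduce(work, op=op, group=group)
    # Hand the result back on the device the caller used (CPU in this issue).
    return work.to(original_device)
```

If a fix took this shape, it would presumably sit inside sync_ddp right before the torch.distributed.all_reduce call shown in the traceback. The cost is one extra host-to-device and device-to-host copy per synced value, which is the trade-off for keeping logged metrics on the CPU.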

To Reproduce

import os

import torch
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.types import EPOCH_OUTPUT


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)
        self.labels = torch.randint(0, 2, (length, 2))

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.metric = Accuracy()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = preds.sum()  # reuse the forward pass instead of calling self(x) a second time
        self.log("train_loss", loss)
        self.log("accuracy", self.metric(preds, y))
        return {"loss": loss}

    def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        self.log("end_accuracy", self.metric.compute())

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        gpus=1,
        accelerator="ddp",
        move_metrics_to_cpu=True,  # the option under test in this report
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()

Expected behavior

Metrics are computed and logged without errors when move_metrics_to_cpu=True is set.

Environment

PyTorch version: 1.10.0+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-6ubuntu2) 7.5.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.12 (default, Oct 12 2021, 13:49:34)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.11.0-38-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 970
Nvidia driver version: 495.29.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] pytorch-lightning==1.4.9
[pip3] torch==1.10.0+cu113
[pip3] torchaudio==0.10.0+cu113
[pip3] torchmetrics==0.5.1
[pip3] torchvision==0.11.1+cu113
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.3.1               h2bc3f7f_2  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] mkl                       2021.3.0           h06a4308_520  
[conda] mkl-service               2.4.0            py38h7f8727e_0  
[conda] mkl_fft                   1.3.1            py38hd3c417c_0  
[conda] mkl_random                1.2.2            py38h51133e4_0  
[conda] numpy                     1.21.2           py38h20f2e39_0  
[conda] numpy-base                1.21.2           py38h79a1101_0  
[conda] pytorch-lightning         1.4.9                    pypi_0    pypi
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torch                     1.10.0+cu113             pypi_0    pypi
[conda] torchaudio                0.10.0+cu113             pypi_0    pypi
[conda] torchmetrics              0.5.1                    pypi_0    pypi
[conda] torchvision               0.11.1+cu113             pypi_0    pypi

Additional context

cc @tchaton @rohitgr7

Metadata

Assignees

No one assigned

Labels

bug (Something isn't working), help wanted (Open to be worked on), priority: 1 (Medium priority task)
