Closed as not planned
Labels
bug (Something isn't working), help wanted (Open to be worked on), priority: 1 (Medium priority task)
Description
🐛 Bug
A metric's compute() method fails inside sync_ddp when training with DDP and move_metrics_to_cpu=True:
  File "pytorch1.10/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 385, in reduce
    tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
  File "pytorch1.10/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py", line 158, in sync_ddp_if_available
    return sync_ddp(result, group=group, reduce_op=reduce_op)
  File "pytorch1.10/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py", line 193, in sync_ddp
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
  File "pytorch1.10/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1287, in all_reduce
    work = group.allreduce([tensor], opts)
RuntimeError: Tensors must be CUDA and dense
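The error says the tensor handed to all_reduce is not on a CUDA device, which the NCCL backend requires. One possible workaround, sketched below, is to stage a CPU tensor on the current GPU for the reduction and move the result back afterwards. The helper name `all_reduce_keep_device` and the `demo` function are hypothetical and exist only for illustration; this is not the library's code or a confirmed fix. The demo assumes a single-process gloo group so that it also runs on a CPU-only machine.

```python
import tempfile

import torch
import torch.distributed as dist


def all_reduce_keep_device(tensor: torch.Tensor, group=None) -> torch.Tensor:
    """Sum-reduce `tensor` across processes and return it on its original device.

    Hypothetical helper, not part of pytorch_lightning or torchmetrics.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return tensor  # no process group, so there is nothing to synchronize

    original_device = tensor.device
    work = tensor.clone()
    # NCCL can only reduce CUDA tensors, so stage CPU tensors on this rank's GPU.
    if dist.get_backend(group) == "nccl" and not work.is_cuda:
        work = work.to(torch.device("cuda", torch.cuda.current_device()))
    dist.all_reduce(work, op=dist.ReduceOp.SUM, group=group)
    # Hand the result back on whatever device the caller's tensor lived on.
    return work.to(original_device)


def demo() -> torch.Tensor:
    # Assumption: a one-process gloo group stands in for a real DDP run.
    store_file = tempfile.NamedTemporaryFile(delete=False)
    store_file.close()
    dist.init_process_group(
        backend="gloo",
        init_method=f"file://{store_file.name}",
        rank=0,
        world_size=1,
    )
    try:
        return all_reduce_keep_device(torch.tensor([1.0, 2.0]))
    finally:
        dist.destroy_process_group()
```

Under NCCL the extra host-to-device copy costs some time per sync, so this trades a little speed for keeping the metric state on the CPU between syncs.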
To Reproduce
import os

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
from torch.utils.data import DataLoader, Dataset
from torchmetrics import Accuracy


class RandomDataset(Dataset):
    """Random features paired with random binary labels."""

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)
        self.labels = torch.randint(0, 2, (length, 2))

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.metric = Accuracy()

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = self(x).sum()
        self.log("train_loss", loss)
        self.log("accuracy", self.metric(preds, y))
        return {"loss": loss}

    def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        # Log the accuracy accumulated over the epoch.
        self.log("end_accuracy", self.metric.compute())

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    # DDP combined with move_metrics_to_cpu=True is the setup that fails.
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        gpus=1,
        accelerator="ddp",
        move_metrics_to_cpu=True,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()

Expected behavior
Metrics are computed without errors.
Environment
PyTorch version: 1.10.0+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-6ubuntu2) 7.5.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.11.0-38-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 970
Nvidia driver version: 495.29.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] pytorch-lightning==1.4.9
[pip3] torch==1.10.0+cu113
[pip3] torchaudio==0.10.0+cu113
[pip3] torchmetrics==0.5.1
[pip3] torchvision==0.11.1+cu113
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.3.0 h06a4308_520
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.21.2 py38h20f2e39_0
[conda] numpy-base 1.21.2 py38h79a1101_0
[conda] pytorch-lightning 1.4.9 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch 1.10.0+cu113 pypi_0 pypi
[conda] torchaudio 0.10.0+cu113 pypi_0 pypi
[conda] torchmetrics 0.5.1 pypi_0 pypi
[conda] torchvision 0.11.1+cu113 pypi_0 pypi
Additional context