
DDP fails with DDPPlugin and num_nodes>1 (with SLURM) #7429

@jopo666

Description

🐛 Bug

DDPPlugin crashes my training scripts when training models on multiple nodes (using SLURM).

When I use multiple GPUs on 1 node with the plugin -> everything works fine.
When I use multiple GPUs on multiple nodes without the plugin -> everything works fine.
When I use multiple GPUs on multiple nodes with the plugin -> it crashes 😢

So when I run...

sbatch submit.sh debug.py --num_nodes 2 --num_gpus 4 --ddp_plugin

Code fails with...

----------------------------------
Total of 8 GPUs over 2 nodes.
Conda environment = DDP_Fail
pytorch-lightning 1.3.0
Running at my_secret_server.fi
Python command:
  python3 ~/debug.py --num_gpus 4 --num_nodes 2 --ddp_plugin
----------------------------------
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/8
initializing ddp: GLOBAL_RANK: 7, MEMBER: 8/8
initializing ddp: GLOBAL_RANK: 6, MEMBER: 7/8
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/8
initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/8
initializing ddp: GLOBAL_RANK: 3, MEMBER: 4/8
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/8
initializing ddp: GLOBAL_RANK: 2, MEMBER: 3/8
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.

Traceback (most recent call last):

...

ValueError: Invalid rank 5, rank should be in the interval [0, 3]
ValueError: Invalid rank 6, rank should be in the interval [0, 3]
ValueError: Invalid rank 7, rank should be in the interval [0, 3]
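
For context, the ValueError points at a world-size mismatch: with 2 nodes x 4 GPUs the process group should span 8 ranks, but ranks 4, 5, 6 and 7 are rejected against the interval [0, 3], i.e. a world size of 4 (one node's worth of tasks). A minimal sketch of the expected arithmetic, assuming the usual DDP convention of world size = num_nodes x GPUs per node (the names below are illustrative, not Lightning internals):

# Sketch: expected rank / world-size arithmetic for this job layout.
num_nodes = 2
gpus_per_node = 4

world_size = num_nodes * gpus_per_node           # expected: 8

def global_rank(node_rank, local_rank):
    return node_rank * gpus_per_node + local_rank

assert world_size == 8
assert global_rank(1, 3) == 7                    # highest rank seen in the log above
# The traceback instead rejects ranks >= 4, which matches world_size == 4,
# i.e. a value computed as if num_nodes were 1.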

To Reproduce

debug.py:

import os
import argparse

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DDPPlugin

class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run(args):
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    
    if args.ddp_plugin:
        plugin = DDPPlugin(find_unused_parameters=False)
    else:
        plugin = None

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
        gpus=args.num_gpus,
        num_nodes=args.num_nodes,
        accelerator='ddp' if args.num_gpus*args.num_nodes > 1 else None,
        plugins=plugin,
    )
    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)
    trainer.test(model, test_dataloaders=test_data)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_nodes', type=int, default=1, metavar='',
                        help='[Default: %(default)s]')
    parser.add_argument('--num_gpus', type=int, default=1, metavar='',
                        help='[Default: %(default)s]')
    parser.add_argument('--ddp_plugin', action='store_true',
                        help='[Default: %(default)s]')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = get_args()
    run(args)
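
A possible workaround, assuming the manually constructed DDPPlugin falls back to its num_nodes=1 default instead of picking up the Trainer's num_nodes (an assumption based on the error above, not verified against the Lightning source): forward num_nodes to the plugin in run():

# Hypothetical tweak to run() in debug.py; assumes DDPPlugin's num_nodes
# default of 1 is what shrinks the world size to 4.
if args.ddp_plugin:
    plugin = DDPPlugin(num_nodes=args.num_nodes, find_unused_parameters=False)
else:
    plugin = None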

submit.sh

#!/bin/bash

#SBATCH --nodes=2
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=14
#SBATCH -o ~/logs/ddp_fail_%j.txt

echo "----------------------------------"
echo "Total of 8 GPUs over 2 nodes."
echo Conda environment = $CONDA_DEFAULT_ENV
echo $(pip list | grep lightning)

echo "Running at $(hostname)"
echo "Python command:"
CMD="    python3 ~/$@"
echo $CMD
echo "----------------------------------"
srun $CMD
echo "Done!"

Commands

# These are okay.
sbatch submit.sh debug.py --num_nodes 1 --num_gpus 4 
sbatch submit.sh debug.py --num_nodes 2 --num_gpus 4 
sbatch submit.sh debug.py --num_nodes 1 --num_gpus 4 --ddp_plugin

# This fails.
sbatch submit.sh debug.py --num_nodes 2 --num_gpus 4 --ddp_plugin

Expected behaviour

The code shouldn't crash when DDPPlugin is combined with num_nodes > 1..? :D

Environment

  • PyTorch Version: 1.8.0
  • PyTorch Lightning Version: 1.3.0
  • OS: Ubuntu
  • How you installed PyTorch: conda
  • Python version: 3.8
