Description
🐛 Bug
DDPPlugin crashes my training scripts when training models on multiple nodes (using SLURM).
When I use multiple GPUs on 1 node with the plugin -> everything works fine.
When I use multiple GPUs on multiple nodes without the plugin -> everything works fine.
When I use multiple GPUs on multiple nodes with the plugin -> it crashes 😢
So when I run...
sbatch submit.sh debug.py --num_nodes 2 --num_gpus 4 --ddp_plugin
Code fails with...
----------------------------------
Total of 8 GPUs over 2 nodes.
Conda environment = DDP_Fail
pytorch-lightning 1.3.0
Running at my_secret_server.fi
Python command:
python3 ~/debug.py --num_gpus 4 --num_nodes 2 --ddp_plugin
----------------------------------
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/8
initializing ddp: GLOBAL_RANK: 7, MEMBER: 8/8
initializing ddp: GLOBAL_RANK: 6, MEMBER: 7/8
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/8
initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/8
initializing ddp: GLOBAL_RANK: 3, MEMBER: 4/8
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/8
initializing ddp: GLOBAL_RANK: 2, MEMBER: 3/8
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Traceback (most recent call last):
...
ValueError: Invalid rank 5, rank should be in the interval [0, 3]
ValueError: Invalid rank 6, rank should be in the interval [0, 3]
ValueError: Invalid rank 7, rank should be in the interval [0, 3]
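The numbers look self-consistent on the SLURM side: eight processes come up and report GLOBAL_RANK 0-7 (MEMBER: x/8), but the process group then rejects any rank above 3. An interval of [0, 3] is what you would get with a world size of 4, i.e. the GPUs of a single node, as if the num_nodes=2 I pass to the Trainer never reaches the plugin. A rough sanity check of that arithmetic, using only the numbers from the log above (my reading, not verified against the Lightning internals):

# Sketch of the mismatch implied by the log above.
num_nodes = 2
gpus_per_node = 4

expected_world_size = num_nodes * gpus_per_node  # 8 -> matches "MEMBER: x/8"
implied_world_size = 3 + 1                       # the error only allows ranks [0, 3]

assert expected_world_size == 8
assert implied_world_size == gpus_per_node       # looks like num_nodes was treated as 1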
To Reproduce
debug.py:
import os
import argparse
import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DDPPlugin


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run(args):
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    if args.ddp_plugin:
        plugin = DDPPlugin(find_unused_parameters=False)
    else:
        plugin = None

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        weights_summary=None,
        gpus=args.num_gpus,
        num_nodes=args.num_nodes,
        accelerator='ddp' if args.num_gpus * args.num_nodes > 1 else None,
        plugins=plugin,
    )
    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)
    trainer.test(model, test_dataloaders=test_data)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_nodes', type=int, default=1, metavar='',
                        help='[Default: %(default)s]')
    parser.add_argument('--num_gpus', type=int, default=1, metavar='',
                        help='[Default: %(default)s]')
    parser.add_argument('--ddp_plugin', action='store_true',
                        help='[Default: %(default)s]')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = get_args()
    run(args)

submit.sh:
#!/bin/bash
#SBATCH --nodes=2
#SBATCH --gres=gpu:4
#SBATCH --ntasks-per-node=4
#SBATCH --cpus-per-task=14
#SBATCH -o ~/logs/ddp_fail_%j.txt
echo "----------------------------------"
echo "Total of 8 GPUs over 2 nodes."
echo Conda environment = $CONDA_DEFAULT_ENV
echo $(pip list | grep lightning)
echo "Running at $(hostname)"
echo "Python command:"
CMD=" python3 ~/$@"
echo $CMD
echo "----------------------------------"
srun $CMD
echo "Done!"Commands
Commands
# These are okay.
sbatch submit.sh debug.py --num_nodes 1 --num_gpus 4
sbatch submit.sh debug.py --num_nodes 2 --num_gpus 4
sbatch submit.sh debug.py --num_nodes 1 --num_gpus 4 --ddp_plugin
# This fails.
sbatch submit.sh debug.py --num_nodes 2 --num_gpus 4 --ddp_plugin
Expected behaviour
Code shouldn't fail..? :D
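If it helps with triage: since the plugin path is the only one that breaks, my guess is that the manually constructed DDPPlugin never learns about the second node and builds the process group with a world size of 4. One thing I could try as a workaround, assuming DDPPlugin in this version accepts a num_nodes argument (I have not checked the signature):

# Untested workaround idea, not a confirmed fix: pass the node count to the
# plugin explicitly instead of relying on the Trainer to propagate it.
# (Assumes DDPPlugin(num_nodes=...) exists in pytorch-lightning 1.3.0.)
plugin = DDPPlugin(num_nodes=args.num_nodes, find_unused_parameters=False)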
Environment
- PyTorch Version: 1.8.0
- OS: Ubuntu
- How you installed PyTorch: conda
- Python version: 3.8