Skip to content

show progressbar only on progress_rank 0 on ddp_slurm #4364

@cool425589

Description

@cool425589

🐛 Bug

The progress bar is printed repeatedly, by several processes at once, when training with DDP under SLURM (multi-GPU / multi-node).

To Reproduce

Using pytorch_lightning

To run this template just do:
python generative_adversarial_net.py

After a few epochs, launch TensorBoard to see the images being generated at every batch:

tensorboard --logdir default

import os
from argparse import ArgumentParser, Namespace
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

from pytorch_lightning.core import LightningModule
from pytorch_lightning.trainer import Trainer


class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to images.

    Args:
        latent_dim: number of features in each input noise vector.
        img_shape: shape of a single generated image; the product of its
            entries is the width of the final linear layer.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        # (in_features, out_features, use_batch_norm) for every hidden stage
        stages = (
            (latent_dim, 128, False),
            (128, 256, True),
            (256, 512, True),
            (512, 1024, True),
        )

        layers = []
        for n_in, n_out, use_norm in stages:
            layers.append(nn.Linear(n_in, n_out))
            if use_norm:
                # the second positional argument of BatchNorm1d is ``eps``
                layers.append(nn.BatchNorm1d(n_out, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # output head: one value per pixel, squashed into [-1, 1] by tanh
        layers.append(nn.Linear(1024, int(np.prod(img_shape))))
        layers.append(nn.Tanh())

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Run ``z`` through the network and reshape to ``(batch, *img_shape)``."""
        flat = self.model(z)
        return flat.view(flat.size(0), *self.img_shape)


class Discriminator(nn.Module):
    """Fully connected discriminator that scores images with a value in (0, 1).

    Args:
        img_shape: shape of a single input image; the product of its entries
            is the number of input features of the first linear layer.
    """

    def __init__(self, img_shape):
        super().__init__()

        n_pixels = int(np.prod(img_shape))
        layers = [
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Flatten each image in the batch and return one score per image."""
        return self.model(img.view(img.size(0), -1))


class GAN(LightningModule):
    """LightningModule that trains a Generator/Discriminator pair on MNIST.

    Two optimizers are configured (generator first, discriminator second) and
    ``training_step`` dispatches on ``optimizer_idx`` accordingly.

    Args:
        latent_dim: size of the noise vector fed to the generator.
        lr: learning rate used by both Adam optimizers.
        b1: first Adam beta coefficient.
        b2: second Adam beta coefficient.
        batch_size: batch size of the training dataloader.
        **kwargs: accepted and ignored, so the whole parsed argument
            namespace can be passed as keywords.
    """

    def __init__(self,
                 latent_dim: int = 100,
                 lr: float = 0.0002,
                 b1: float = 0.5,
                 b2: float = 0.999,
                 batch_size: int = 64, **kwargs):
        super().__init__()

        self.latent_dim = latent_dim
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.batch_size = batch_size

        # networks
        mnist_shape = (1, 28, 28)
        self.generator = Generator(latent_dim=self.latent_dim, img_shape=mnist_shape)
        self.discriminator = Discriminator(img_shape=mnist_shape)

        self.validation_z = torch.randn(8, self.latent_dim)

        # Use the instance's own latent_dim: the module-level ``hparams`` only
        # exists when this file is executed as a script, so reading it here
        # raised NameError whenever GAN was constructed from anywhere else.
        self.example_input_array = torch.zeros(2, self.latent_dim)

    def forward(self, z):
        """Generate images from the latent batch ``z``."""
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between predictions ``y_hat`` and targets ``y``."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the optimizer selected by ``optimizer_idx``.

        ``optimizer_idx == 0`` trains the generator, ``optimizer_idx == 1``
        trains the discriminator. Returns a dict with a single ``'loss'`` key.
        """
        imgs, _ = batch

        # sample noise with the same dtype/device as the real images
        z = torch.randn(imgs.shape[0], self.latent_dim)
        z = z.type_as(imgs)

        # train generator
        if optimizer_idx == 0:

            # generate images
            self.generated_imgs = self(z)

            # log sampled images
            sample_imgs = self.generated_imgs[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image('generated_images', grid, 0)

            # the generator wants its fakes to be labelled as real (all ones);
            # type_as moves the freshly created tensor next to the images
            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)

            # adversarial loss is binary cross-entropy
            g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)

            return {'loss': g_loss}

        # train discriminator
        if optimizer_idx == 1:
            # Measure discriminator's ability to classify real from generated samples

            # how well can it label as real?
            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)

            real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

            # how well can it label as fake?
            fake = torch.zeros(imgs.size(0), 1)
            fake = fake.type_as(imgs)

            # detach so this loss does not back-propagate into the generator
            fake_loss = self.adversarial_loss(
                self.discriminator(self(z).detach()), fake)

            # discriminator loss is the average of these
            d_loss = (real_loss + fake_loss) / 2

            return {'loss': d_loss}

    def configure_optimizers(self):
        """Return one Adam optimizer per network (generator, discriminator) and no schedulers."""
        lr = self.lr
        b1 = self.b1
        b2 = self.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

    def train_dataloader(self):
        """Build the MNIST training dataloader (downloaded into the current directory)."""
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize([0.5], [0.5])])
        dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform)
        return DataLoader(dataset, batch_size=self.batch_size)

    def on_epoch_end(self):
        """Epoch-end hook; intentionally does nothing."""
        pass


def main(args: Namespace) -> None:
    """Build the GAN from the parsed arguments and train it.

    Args:
        args: parsed command-line namespace; every attribute is forwarded to
            ``GAN`` as a keyword argument.
    """
    # Step 1: create the LightningModule from the CLI arguments
    gan = GAN(**vars(args))

    # Step 2: configure the trainer -- 8 GPUs on a single node with the 'ddp'
    # backend (DistributedDataParallel), profiling enabled, 10 epochs at most
    trainer_options = dict(
        gpus=8,
        num_nodes=1,
        distributed_backend='ddp',
        profiler=True,
        max_epochs=10,
    )
    trainer = Trainer(**trainer_options)

    # Step 3: run training
    trainer.fit(gan)


if __name__ == '__main__':
    # Command-line entry point: parse the hyper-parameters and start training.
    parser = ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5,
                        help="adam: decay of first order momentum of gradient")
    # b2 is Adam's second beta; the help text used to repeat the --b1 wording
    parser.add_argument("--b2", type=float, default=0.999,
                        help="adam: decay of second order momentum of gradient")
    parser.add_argument("--latent_dim", type=int, default=100,
                        help="dimensionality of the latent space")

    # kept as a module-level name called ``hparams``
    hparams = parser.parse_args()

    main(hparams)

Submit

#!/bin/bash
# Slurm batch script for the reproduction: requests 8 GPUs on one node and
# 8 tasks on that node (the same counts as gpus=8 / num_nodes=1 in the Trainer).
#SBATCH --gres=gpu:8
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
# Activate the Python environment (replace "your_env" with the real name).
conda activate your_env
# srun launches the training script once per task.
srun python gan.py

Expected behavior

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 6 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
GPU available: True, used: True
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 4 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 6, MEMBER: 7/8
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 2, MEMBER: 3/8
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/8
Multi-processing is handled by Slurm.
LOCAL_RANK: 7 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 3, MEMBER: 4/8
initializing ddp: GLOBAL_RANK: 7, MEMBER: 8/8
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/8
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 5 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/8
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/8
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.

  | Name          | Type          | Params | In sizes | Out sizes     
----------------------------------------------------------------------------
0 | generator     | Generator     | 1 M    | [2, 100] | [2, 1, 28, 28]
1 | discriminator | Discriminator | 533 K  | ?        | ?             

Training: 0it [00:00, ?it/s]
Training:   0%|          | 0/118 [00:00<?, ?it/s]
Epoch 0:   1%|          | 1/118 [00:00<01:10,  1.65it/s, loss=0.710, v_num=194638]
Epoch 0:   2%|▏         | 2/118 [00:00<00:36,  3.17it/s, loss=0.686, v_num=194638]
Epoch 0:   3%|▎         | 3/118 [00:00<00:24,  4.61it/s, loss=0.664, v_num=194638]
Epoch 0:   3%|▎         | 4/118 [00:00<00:19,  5.97it/s, loss=0.646, v_num=194638]
Epoch 0:   4%|▍         | 5/118 [00:00<00:15,  7.25it/s, loss=0.630, v_num=194638]
Epoch 0:   5%|▌         | 6/118 [00:00<00:13,  8.47it/s, loss=0.630, v_num=194638]
Epoch 0:   6%|▌         | 7/118 [00:00<00:11,  9.62it/s, loss=0.606, v_num=194638]
Epoch 0:   7%|▋         | 8/118 [00:00<00:10, 10.71it/s, loss=0.597, v_num=194638]
Epoch 0:   8%|▊         | 9/118 [00:00<00:09, 11.74it/s, loss=0.589, v_n
...
...

Actual behavior

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 6 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
GPU available: True, used: True
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 4 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 6, MEMBER: 7/8
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 2, MEMBER: 3/8
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/8
Multi-processing is handled by Slurm.
LOCAL_RANK: 7 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 3, MEMBER: 4/8
initializing ddp: GLOBAL_RANK: 7, MEMBER: 8/8
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/8
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 5 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/8
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/8
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.

  | Name          | Type          | Params | In sizes | Out sizes     
----------------------------------------------------------------------------
0 | generator     | Generator     | 1 M    | [2, 100] | [2, 1, 28, 28]
1 | discriminator | Discriminator | 533 K  | ?        | ?             

Training: 0it [00:00, ?it/s]
Training:   0%|          | 0/118 [00:00<?, ?it/s]
Epoch 0:   0%|          | 0/118 [00:00<?, ?it/s] 
Epoch 0:   1%|          | 1/118 [00:00<01:10,  1.65it/s]
Epoch 0:   1%|          | 1/118 [00:00<01:10,  1.65it/s, loss=0.710, v_num=194638]
Epoch 0:   2%|▏         | 2/118 [00:00<00:36,  3.17it/s, loss=0.686, v_num=194638]
Epoch 0:   3%|▎         | 3/118 [00:00<00:24,  4.61it/s, loss=0.664, v_num=194638]
Epoch 0:   3%|▎         | 4/118 [00:00<00:19,  5.97it/s, loss=0.646, v_num=194638]
Epoch 0:   4%|▍         | 5/118 [00:00<00:15,  7.25it/s, loss=0.630, v_num=194638]
Epoch 0:   5%|▌         | 6/118 [00:00<00:13,  8.47it/s, loss=0.630, v_num=194638]
Epoch 0:   5%|▌         | 6/118 [00:00<00:13,  8.47it/s, loss=0.617, v_num=194638]
Epoch 0:   6%|▌         | 7/118 [00:00<00:11,  9.62it/s, loss=0.606, v_num=194638]
Epoch 0:   7%|▋         | 8/118 [00:00<00:10, 10.71it/s, loss=0.597, v_num=194638]
Epoch 0:   8%|▊         | 9/118 [00:00<00:09, 11.74it/s, loss=0.589, v_n
Training: 0it [00:00, ?it/s]
Training:   0%|          | 0/118 [00:00<?, ?it/s]
Epoch 0:   0%|          | 0/118 [00:00<?, ?it/s] 
Epoch 0:   1%|          | 1/118 [00:00<01:10,  1.65it/s]
Epoch 0:   1%|          | 1/118 [00:00<01:11,  1.65it/s, loss=0.710, v_num=194638]
Epoch 0:   2%|▏         | 2/118 [00:00<00:36,  3.17it/s, loss=0.686, v_num=194638]
Epoch 0:   3%|▎         | 3/118 [00:00<00:24,  4.61it/s, loss=0.665, v_num=194638]
Epoch 0:   3%|▎         | 4/118 [00:00<00:19,  5.97it/s, loss=0.647, v_num=194638]
Epoch 0:   4%|▍         | 5/118 [00:00<00:15,  7.25it/s, loss=0.631, v_num=194638]
Epoch 0:   5%|▌         | 6/118 [00:00<00:13,  8.47it/s, loss=0.631, v_num=194638]
Epoch 0:   5%|▌         | 6/118 [00:00<00:13,  8.46it/s, loss=0.618, v_num=194638]
Epoch 0:   6%|▌         | 7/118 [00:00<00:11,  9.61it/s, loss=0.606, v_num=194638]
Epoch 0:   7%|▋         | 8/118 [00:00<00:10, 10.70it/s, loss=0.597, v_num=194638]
Epoch 0:   8%|▊         | 9/118 [00:00<00:09, 11.74it/s, loss=0.590, v_n
Training: 0it [00:00, ?it/s]
Training:   0%|          | 0/118 [00:00<?, ?it/s]
Epoch 0:   0%|          | 0/118 [00:00<?, ?it/s] 
Epoch 0:   1%|          | 1/118 [00:00<01:10,  1.65it/s]
Epoch 0:   1%|          | 1/118 [00:00<01:10,  1.65it/s, loss=0.711, v_num=194638]
Epoch 0:   2%|▏         | 2/118 [00:00<00:36,  3.17it/s, loss=0.686, v_num=194638]

Environment

environment
OS:CentOS Linux

  • Python version: 3.8
    Installed packages:
    pytorch-lightning 1.0.1
    torch 1.6.0
    torchvision 0.7.0

How to fix (my understanding)

ddp_slurm_accelerator.py

        # toggle prog bar
        
        if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

To

        # toggle prog bar
        
        if self.trainer.global_rank != 0 and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

Metadata

Metadata

Assignees

No one assigned

    Labels

    feature (Is an improvement or enhancement), help wanted (Open to be worked on)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions