🐛 Bug
The progress bar is printed repeatedly (once per process) when using SLURM for multi-node / multi-GPU training.
To Reproduce
Using the pytorch_lightning GAN example. To run this template just do:
python generative_adversarial_net.py
After a few epochs, launch TensorBoard to see the images being generated at every batch:
tensorboard --logdir default
import os
from argparse import ArgumentParser, Namespace
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from pytorch_lightning.core import LightningModule
from pytorch_lightning.trainer import Trainer
class Generator(nn.Module):
    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *self.img_shape)
        return img
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity
class GAN(LightningModule):
    def __init__(self,
                 latent_dim: int = 100,
                 lr: float = 0.0002,
                 b1: float = 0.5,
                 b2: float = 0.999,
                 batch_size: int = 64, **kwargs):
        super().__init__()
        self.latent_dim = latent_dim
        self.lr = lr
        self.b1 = b1
        self.b2 = b2
        self.batch_size = batch_size

        # networks
        mnist_shape = (1, 28, 28)
        self.generator = Generator(latent_dim=self.latent_dim, img_shape=mnist_shape)
        self.discriminator = Discriminator(img_shape=mnist_shape)

        self.validation_z = torch.randn(8, self.latent_dim)
        self.example_input_array = torch.zeros(2, self.latent_dim)

    def forward(self, z):
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)
    def training_step(self, batch, batch_idx, optimizer_idx):
        imgs, _ = batch

        # sample noise
        z = torch.randn(imgs.shape[0], self.latent_dim)
        z = z.type_as(imgs)

        # train generator
        if optimizer_idx == 0:
            # generate images
            self.generated_imgs = self(z)

            # log sampled images
            sample_imgs = self.generated_imgs[:6]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.logger.experiment.add_image('generated_images', grid, 0)

            # ground truth result (ie: all fake)
            # put on GPU because we created this tensor inside training_loop
            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)

            # adversarial loss is binary cross-entropy
            g_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
            tqdm_dict = {'g_loss': g_loss}
            return {'loss': g_loss}

        # train discriminator
        if optimizer_idx == 1:
            # Measure discriminator's ability to classify real from generated samples

            # how well can it label as real?
            valid = torch.ones(imgs.size(0), 1)
            valid = valid.type_as(imgs)
            real_loss = self.adversarial_loss(self.discriminator(imgs), valid)

            # how well can it label as fake?
            fake = torch.zeros(imgs.size(0), 1)
            fake = fake.type_as(imgs)
            fake_loss = self.adversarial_loss(
                self.discriminator(self(z).detach()), fake)

            # discriminator loss is the average of these
            d_loss = (real_loss + fake_loss) / 2
            tqdm_dict = {'d_loss': d_loss}
            return {'loss': d_loss}
    def configure_optimizers(self):
        lr = self.lr
        b1 = self.b1
        b2 = self.b2

        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
        return [opt_g, opt_d], []

    def train_dataloader(self):
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize([0.5], [0.5])])
        dataset = MNIST(os.getcwd(), train=True, download=True, transform=transform)
        return DataLoader(dataset, batch_size=self.batch_size)

    def on_epoch_end(self):
        pass
def main(args: Namespace) -> None:
    # ------------------------
    # 1 INIT LIGHTNING MODEL
    # ------------------------
    model = GAN(**vars(args))

    # ------------------------
    # 2 INIT TRAINER
    # ------------------------
    # For distributed training, PyTorch recommends DistributedDataParallel.
    # See: https://pytorch.org/docs/stable/nn.html#torch.nn.DataParallel
    trainer = Trainer(gpus=8, num_nodes=1, distributed_backend='ddp', profiler=True, max_epochs=10)

    # ------------------------
    # 3 START TRAINING
    # ------------------------
    trainer.fit(model)
if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5,
                        help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999,
                        help="adam: decay of second order momentum of gradient")
    parser.add_argument("--latent_dim", type=int, default=100,
                        help="dimensionality of the latent space")

    hparams = parser.parse_args()
    main(hparams)
Submit script
#!/bin/bash
#SBATCH --gres=gpu:8
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
conda activate your_env
srun python gan.py
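For reference, every task launched by srun is a separate Python process, and Lightning derives its ranks from SLURM's environment variables. The following diagnostic sketch is not part of the report (the variable names are the standard SLURM ones); dropping it into gan.py makes it easy to see which task is emitting which progress bar:

import os

# Each srun task runs this independently; SLURM_PROCID is the global rank,
# SLURM_LOCALID the per-node rank, SLURM_NODEID the node index.
for var in ("SLURM_NODEID", "SLURM_PROCID", "SLURM_LOCALID", "SLURM_NTASKS"):
    print(var, "=", os.environ.get(var))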
Expected behavior
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 6 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
GPU available: True, used: True
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 4 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 6, MEMBER: 7/8
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 2, MEMBER: 3/8
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/8
Multi-processing is handled by Slurm.
LOCAL_RANK: 7 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 3, MEMBER: 4/8
initializing ddp: GLOBAL_RANK: 7, MEMBER: 8/8
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/8
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 5 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/8
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/8
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
| Name | Type | Params | In sizes | Out sizes
----------------------------------------------------------------------------
0 | generator | Generator | 1 M | [2, 100] | [2, 1, 28, 28]
1 | discriminator | Discriminator | 533 K | ? | ?
Training: 0it [00:00, ?it/s]
Training: 0%| | 0/118 [00:00<?, ?it/s]
Epoch 0: 1%| | 1/118 [00:00<01:10, 1.65it/s, loss=0.710, v_num=194638]
Epoch 0: 2%|▏ | 2/118 [00:00<00:36, 3.17it/s, loss=0.686, v_num=194638]
Epoch 0: 3%|▎ | 3/118 [00:00<00:24, 4.61it/s, loss=0.664, v_num=194638]
Epoch 0: 3%|▎ | 4/118 [00:00<00:19, 5.97it/s, loss=0.646, v_num=194638]
Epoch 0: 4%|▍ | 5/118 [00:00<00:15, 7.25it/s, loss=0.630, v_num=194638]
Epoch 0: 5%|▌ | 6/118 [00:00<00:13, 8.47it/s, loss=0.630, v_num=194638]
Epoch 0: 6%|▌ | 7/118 [00:00<00:11, 9.62it/s, loss=0.606, v_num=194638]
Epoch 0: 7%|▋ | 8/118 [00:00<00:10, 10.71it/s, loss=0.597, v_num=194638]
Epoch 0: 8%|▊ | 9/118 [00:00<00:09, 11.74it/s, loss=0.589, v_n
...
...
Actual behavior
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 6 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
GPU available: True, used: True
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 4 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 6, MEMBER: 7/8
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 2, MEMBER: 3/8
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
initializing ddp: GLOBAL_RANK: 4, MEMBER: 5/8
Multi-processing is handled by Slurm.
LOCAL_RANK: 7 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 3, MEMBER: 4/8
initializing ddp: GLOBAL_RANK: 7, MEMBER: 8/8
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/8
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Multi-processing is handled by Slurm.
LOCAL_RANK: 5 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/8
initializing ddp: GLOBAL_RANK: 5, MEMBER: 6/8
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
Set SLURM handle signals.
| Name | Type | Params | In sizes | Out sizes
----------------------------------------------------------------------------
0 | generator | Generator | 1 M | [2, 100] | [2, 1, 28, 28]
1 | discriminator | Discriminator | 533 K | ? | ?
Training: 0it [00:00, ?it/s]
Training: 0%| | 0/118 [00:00<?, ?it/s]
Epoch 0: 0%| | 0/118 [00:00<?, ?it/s]
Epoch 0: 1%| | 1/118 [00:00<01:10, 1.65it/s]
Epoch 0: 1%| | 1/118 [00:00<01:10, 1.65it/s, loss=0.710, v_num=194638]
Epoch 0: 2%|▏ | 2/118 [00:00<00:36, 3.17it/s, loss=0.686, v_num=194638]
Epoch 0: 3%|▎ | 3/118 [00:00<00:24, 4.61it/s, loss=0.664, v_num=194638]
Epoch 0: 3%|▎ | 4/118 [00:00<00:19, 5.97it/s, loss=0.646, v_num=194638]
Epoch 0: 4%|▍ | 5/118 [00:00<00:15, 7.25it/s, loss=0.630, v_num=194638]
Epoch 0: 5%|▌ | 6/118 [00:00<00:13, 8.47it/s, loss=0.630, v_num=194638]
Epoch 0: 5%|▌ | 6/118 [00:00<00:13, 8.47it/s, loss=0.617, v_num=194638]
Epoch 0: 6%|▌ | 7/118 [00:00<00:11, 9.62it/s, loss=0.606, v_num=194638]
Epoch 0: 7%|▋ | 8/118 [00:00<00:10, 10.71it/s, loss=0.597, v_num=194638]
Epoch 0: 8%|▊ | 9/118 [00:00<00:09, 11.74it/s, loss=0.589, v_n
Training: 0it [00:00, ?it/s]
Training: 0%| | 0/118 [00:00<?, ?it/s]
Epoch 0: 0%| | 0/118 [00:00<?, ?it/s]
Epoch 0: 1%| | 1/118 [00:00<01:10, 1.65it/s]
Epoch 0: 1%| | 1/118 [00:00<01:11, 1.65it/s, loss=0.710, v_num=194638]
Epoch 0: 2%|▏ | 2/118 [00:00<00:36, 3.17it/s, loss=0.686, v_num=194638]
Epoch 0: 3%|▎ | 3/118 [00:00<00:24, 4.61it/s, loss=0.665, v_num=194638]
Epoch 0: 3%|▎ | 4/118 [00:00<00:19, 5.97it/s, loss=0.647, v_num=194638]
Epoch 0: 4%|▍ | 5/118 [00:00<00:15, 7.25it/s, loss=0.631, v_num=194638]
Epoch 0: 5%|▌ | 6/118 [00:00<00:13, 8.47it/s, loss=0.631, v_num=194638]
Epoch 0: 5%|▌ | 6/118 [00:00<00:13, 8.46it/s, loss=0.618, v_num=194638]
Epoch 0: 6%|▌ | 7/118 [00:00<00:11, 9.61it/s, loss=0.606, v_num=194638]
Epoch 0: 7%|▋ | 8/118 [00:00<00:10, 10.70it/s, loss=0.597, v_num=194638]
Epoch 0: 8%|▊ | 9/118 [00:00<00:09, 11.74it/s, loss=0.590, v_n
Training: 0it [00:00, ?it/s]
Training: 0%| | 0/118 [00:00<?, ?it/s]
Epoch 0: 0%| | 0/118 [00:00<?, ?it/s]
Epoch 0: 1%| | 1/118 [00:00<01:10, 1.65it/s]
Epoch 0: 1%| | 1/118 [00:00<01:10, 1.65it/s, loss=0.711, v_num=194638]
Epoch 0: 2%|▏ | 2/118 [00:00<00:36, 3.17it/s, loss=0.686, v_num=194638]
Environment
- OS: CentOS Linux
- Python version: 3.8
Installed packages:
- pytorch-lightning 1.0.1
- torch 1.6.0
- torchvision 0.7.0
How to fix (my understanding)
In ddp_slurm_accelerator.py, change

# toggle prog bar
if self.trainer.global_rank == 0 and self.trainer.progress_bar_callback is not None:
    self.trainer.progress_bar_callback.disable()

to

# toggle prog bar
if self.trainer.global_rank != 0 and self.trainer.progress_bar_callback is not None:
    self.trainer.progress_bar_callback.disable()
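Until the accelerator is patched, a user-side workaround seems possible. This is a minimal sketch, not part of the original report, assuming that srun exports SLURM_PROCID (the global rank) and that the Trainer in pytorch-lightning 1.0.x still accepts the progress_bar_refresh_rate argument: only the process with global rank 0 keeps a live bar.

import os
from pytorch_lightning import Trainer

# Hypothetical workaround: read the global rank from SLURM and disable the
# progress bar on every rank except 0 by setting its refresh rate to zero.
global_rank = int(os.environ.get("SLURM_PROCID", "0"))

trainer = Trainer(
    gpus=8,
    num_nodes=1,
    distributed_backend='ddp',
    max_epochs=10,
    progress_bar_refresh_rate=20 if global_rank == 0 else 0,
)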