DDP + mixed precision + sharded not working on PL 1.2.1 #6322

@mees

Description

🐛 Bug

After upgrading to pytorch-lightning 1.2.1, training with DDP + 16-bit precision + sharded is broken: the training loss does not decrease and stays around 2.31. Without the sharded plugin, the same setup seems to work.
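
A loss stuck near 2.31 is chance level for MNIST's 10 classes, which suggests the model is not learning at all under this configuration. A quick sanity check (plain Python, independent of the repro script below):

import math

# Cross-entropy of a uniform prediction over 10 classes:
# -log(1/10) = log(10) ≈ 2.3026, matching the stuck ~2.31 training loss.
print(math.log(10))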

To Reproduce

from argparse import ArgumentParser
import torch
from torch.nn import functional as F

import pytorch_lightning as pl
from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule


class LitClassifier(pl.LightningModule):

    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_dim', type=int, default=128)
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--num_workers', type=int, default=4)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser


def cli_main():
    pl.seed_everything(1234)
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = LitClassifier.add_model_specific_args(parser)
    parser = MNISTDataModule.add_argparse_args(parser)
    args = parser.parse_args()

    dm = MNISTDataModule.from_argparse_args(args)
    model = LitClassifier(args.hidden_dim, args.learning_rate)
    # Failing configuration: DDP + 16-bit precision + the 'ddp_sharded' plugin
    trainer = pl.Trainer.from_argparse_args(args, precision=16, gpus=[0, 1], accelerator="ddp", plugins='ddp_sharded')
    trainer.fit(model, datamodule=dm)


if __name__ == '__main__':
    cli_main()
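
For comparison, dropping the sharded plugin while keeping DDP and 16-bit precision seems to train normally. The only change is the Trainer construction (a minimal sketch; everything else stays as above):

# Same script, but without plugins='ddp_sharded'; with this Trainer the
# loss decreases as expected (DDP + 16-bit precision only).
trainer = pl.Trainer.from_argparse_args(args, precision=16, gpus=[0, 1], accelerator="ddp")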

Expected behavior

Training loss starts to decrease.

Environment

  • CUDA:
    - GPU: 4x TITAN X (Pascal)
    - available: True
    - version: 10.2
  • Packages:
    - numpy: 1.20.1
    - pyTorch_debug: True
    - pyTorch_version: 1.7.0
    - pytorch-lightning: 1.2.1
    - tqdm: 4.56.0

Labels

bug (Something isn't working), help wanted (Open to be worked on), priority: 0 (High priority task)
