Closed
Labels
bug (Something isn't working) · help wanted (Open to be worked on) · priority: 0 (High priority task)
Description
🐛 Bug
After upgrading to pytorch-lightning 1.2.1, training with DDP + 16-bit precision + sharded is broken: the training loss doesn't go down and stays around 2.31 (roughly ln 10, i.e. chance level for a 10-class classifier). Without the sharded plugin it seems to work.
To Reproduce
from argparse import ArgumentParser

import torch
from torch.nn import functional as F

import pytorch_lightning as pl
from pl_examples.basic_examples.mnist_datamodule import MNISTDataModule


class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_dim', type=int, default=128)
        parser.add_argument('--batch_size', type=int, default=32)
        parser.add_argument('--num_workers', type=int, default=4)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser


def cli_main():
    pl.seed_everything(1234)

    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = LitClassifier.add_model_specific_args(parser)
    parser = MNISTDataModule.add_argparse_args(parser)
    args = parser.parse_args()

    dm = MNISTDataModule.from_argparse_args(args)
    model = LitClassifier(args.hidden_dim, args.learning_rate)

    # The combination that reproduces the issue: DDP + 16-bit precision + the sharded plugin
    trainer = pl.Trainer.from_argparse_args(args, precision=16, gpus=[0, 1], accelerator="ddp", plugins='ddp_sharded')
    trainer.fit(model, datamodule=dm)


if __name__ == '__main__':
    cli_main()
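For comparison, the description above notes that training behaves as expected without the sharded option. A minimal sketch of that baseline, assuming only the Trainer construction inside cli_main() changes (this plugin-free variant is not part of the original script):

# Hypothetical baseline: same DDP + 16-bit setup, but without the 'ddp_sharded' plugin.
# Per the description above, the training loss decreases in this configuration.
trainer = pl.Trainer.from_argparse_args(args, precision=16, gpus=[0, 1], accelerator="ddp")
trainer.fit(model, datamodule=dm)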
Expected behavior
Training loss starts to decrease.
Environment
- CUDA:
  - GPU: 4x TITAN X (Pascal)
  - available: True
  - version: 10.2
- Packages:
  - numpy: 1.20.1
  - pyTorch_debug: True
  - pyTorch_version: 1.7.0
  - pytorch-lightning: 1.2.1
  - tqdm: 4.56.0