Closed
Labels
bug (Something isn't working), priority: 0 (High priority task), strategy: ddp (DistributedDataParallel)
Description
🐛 Bug
When using the ddp_spawn strategy, global torch state set by the Trainer, such as the torch.use_deterministic_algorithms flag, does not get transferred to the worker processes. The reason is that torch gets re-imported when the workers are spawned (naturally), so the globals are back to their defaults.
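The mechanism is independent of Lightning; a minimal sketch with plain torch.multiprocessing (assuming the default spawn start method, where each child process re-imports torch) shows the flag getting reset:

import torch
import torch.multiprocessing as mp

def check(rank):
    # The spawned child re-imports torch, so the global flag is back to its default (False).
    print(f"worker {rank}: deterministic={torch.are_deterministic_algorithms_enabled()}")

if __name__ == "__main__":
    torch.use_deterministic_algorithms(True)
    print(f"parent: deterministic={torch.are_deterministic_algorithms_enabled()}")  # True
    mp.spawn(check, nprocs=2)  # each worker prints False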
To Reproduce
import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        # The Trainer was created with deterministic=True in the main process,
        # but the flag is back to its default in the spawned worker, so this fails.
        assert torch.are_deterministic_algorithms_enabled()
        loss = self(batch).sum()
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        enable_progress_bar=False,
        strategy="ddp_spawn",
        devices=2,
        accelerator="cpu",
        deterministic=True,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()

Expected behavior
The assertion in training_step should pass in every spawned worker process.
Environment
PyTorch 1.10 and 1.11
PL 1.6 (master)
Additional context
I dreamed about this last night (no kidding), checked it this morning, and the bug turned out to be real.
This applies to all torch globals, not just the deterministic flag. We already handle this problem for the seed in the DDPSpawnStrategy._worker_setup() method.
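Until this is fixed, a user-side workaround sketch (not the actual fix; names are illustrative) is to re-apply the flag from a hook that runs inside each spawned worker, e.g. setup():

import torch
from pytorch_lightning import LightningModule

class WorkaroundModel(LightningModule):  # hypothetical module, for illustration only
    def setup(self, stage=None):
        # Re-apply the global flag inside the spawned worker, where the Trainer's
        # setting was lost because torch was re-imported with default globals.
        torch.use_deterministic_algorithms(True)

A proper fix would presumably mirror the seed handling, re-applying the captured torch globals in DDPSpawnStrategy._worker_setup().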
cc @tchaton @rohitgr7 @justusschock @kaushikb11 @awaelchli @akihironitta