Deterministic flags and other globals do not get transferred to spawned processes #12685

@awaelchli

Description

🐛 Bug

When using the ddp_spawn strategy, global settings such as the torch.use_deterministic_algorithms flag set by the Trainer do not get transferred to the worker processes. The reason is that torch gets re-imported in the spawned workers (as expected with the spawn start method), so the globals are back to their defaults.
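
The same behavior can be shown without Lightning. A minimal sketch, assuming the default "spawn" start method (the worker function and its name are only for illustration):

import torch
import torch.multiprocessing as mp


def check_flag(rank):
    # The spawned child re-imports torch, so the global is back to its default (False).
    print(f"worker {rank}: deterministic={torch.are_deterministic_algorithms_enabled()}")


if __name__ == "__main__":
    torch.use_deterministic_algorithms(True)
    print(f"parent: deterministic={torch.are_deterministic_algorithms_enabled()}")  # True
    mp.spawn(check_flag, nprocs=2)  # each worker prints False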

To Reproduce

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
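        # With ddp_spawn this assertion fails in the worker processes, because the
        # flag set by Trainer(deterministic=True) in the main process is not transferred.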
        assert torch.are_deterministic_algorithms_enabled()
        loss = self(batch).sum()
        return {"loss": loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        enable_progress_bar=False,
        strategy="ddp_spawn",
        devices=2,
        accelerator="cpu",
        deterministic=True,
    )
    trainer.fit(model, train_dataloaders=train_data)


if __name__ == "__main__":
    run()

Expected behavior

The assertion in training_step should pass in every worker process: since the Trainer is constructed with deterministic=True, torch.are_deterministic_algorithms_enabled() should return True. With the ddp_spawn strategy, the assertion currently fails.

Environment

PyTorch 1.10 and 1.11
PL 1.6 (master)

Additional context

I dreamed about it last night (no kidding) and checked this morning and the bug turned out to be real.

This applies to all globals from torch, not just deterministic flags. We already handle this problem for the seed in the DDPSpawnStrategy._worker_setup() method.
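
One possible direction, sketched here with hypothetical helper names rather than actual Lightning code, would be to capture the relevant torch globals in the main process before spawning and restore them in each worker, analogous to how the seed is re-applied:

import torch


def capture_torch_globals() -> dict:
    # Snapshot of the global torch state set in the main process.
    return {
        "deterministic": torch.are_deterministic_algorithms_enabled(),
        "cudnn_benchmark": torch.backends.cudnn.benchmark,
    }


def restore_torch_globals(state: dict) -> None:
    # To be called inside each spawned worker, e.g. alongside the seed handling
    # in DDPSpawnStrategy._worker_setup().
    torch.use_deterministic_algorithms(state["deterministic"])
    torch.backends.cudnn.benchmark = state["cudnn_benchmark"]

The snapshot would have to be taken before spawning and passed to the workers (for example through the spawn args), since module-level state is not inherited with the spawn start method.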

cc @tchaton @rohitgr7 @justusschock @kaushikb11 @awaelchli @akihironitta

Labels

bug (Something isn't working), priority: 0 (High priority task), strategy: ddp (DistributedDataParallel)
