🐛 Bug
When running with distributed_backend="ddp" and automatic_optimization=False, an error is thrown when training_step returns a detached tensor. More details below. I'm using Lightning version 1.0.4; I saw the same behavior in 1.0.3. First, here's the code:
# ddp_bug.py
import os
import torch
from torch.utils.data import Dataset
from pytorch_lightning import Trainer, LightningModule


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self, detach=False):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.detach = detach

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def training_step(self, batch, batch_idx):
        # Manual optimization: backward/step/zero_grad all happen here
        opt = self.optimizers()
        output = self.layer(batch)
        loss1 = self.loss(batch, output)
        self.manual_backward(loss1, opt)
        opt.step()
        opt.zero_grad()
        if self.detach:
            # Returning the detached loss is what triggers the error under ddp
            loss1 = loss1.detach()
        return loss1

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


def run_test(args):
    class TestModel(BoringModel):
        pass

    # fake data
    train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    val_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
    test_data = torch.utils.data.DataLoader(RandomDataset(32, 64))

    # model
    model = TestModel(args.detach)
    trainer = Trainer.from_argparse_args(
        args,
        default_root_dir=os.getcwd(),
        limit_train_batches=5,
        limit_val_batches=5,
        max_epochs=1,
        weights_summary=None,
        automatic_optimization=False,
    )
    trainer.fit(model, train_data, val_data)
    trainer.test(test_dataloaders=test_data)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--detach', action='store_true')
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args()
    run_test(args)

To Reproduce
If you save the script above as ddp_bug.py and run the following command:
python ddp_bug.py --gpus=2 --distributed_backend=ddp --detach
the following error is thrown:
self.trainer.accelerator_backend.backward(result, optimizer, opt_idx, *args, **kwargs)
File "/home/catalys1/pylt/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 97, in backward
model.backward(closure_loss, optimizer, opt_idx, *args, **kwargs)
File "/home/catalys1/pylt/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1103, in backward
loss.backward(*args, **kwargs)
File "/home/catalys1/pylt/lib/python3.8/site-packages/torch/tensor.py", line 185, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/home/catalys1/pylt/lib/python3.8/site-packages/torch/autograd/__init__.py", line 125, in backward
Variable._execution_engine.run_backward(
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons:
1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across
multiple concurrent forward-backward passes
2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to
wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward
passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases yet.
The following commands do not throw an error:
python ddp_bug.py --gpus=2 --distributed_backend=ddp
python ddp_bug.py --gpus=2 --distributed_backend=dp --detach
python ddp_bug.py --gpus=1 --detach
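For context, as far as I understand the error message, DDP's reducer raises this whenever the same parameters take part in more than one backward pass for a single forward. Here is a minimal pure-PyTorch sketch (no Lightning) that I believe hits the same RuntimeError; the file name, port, and gloo/CPU setup are illustrative:

# ddp_mark_ready.py -- illustrative sketch, not part of my repro script
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = DDP(torch.nn.Linear(32, 2))
    loss = model(torch.randn(4, 32)).sum()
    loss.backward(retain_graph=True)  # first pass: each parameter is marked ready
    loss.backward()  # second pass through the same graph -> "marked ready only once"
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)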
Expected behavior
My understanding is that setting automatic_optimization=False prevents any backward or optimizer steps from happening outside of the training_step function. It looks like this may not be honored: if it were, there should be no difference between returning a detached tensor and a non-detached one.
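Put differently, I'd expect the manual-optimization loop to reduce to roughly the following (a sketch of my mental model, not Lightning's actual internals):

# Illustrative pseudocode of what I expect with automatic_optimization=False
for batch_idx, batch in enumerate(train_data):
    result = model.training_step(batch, batch_idx)  # backward/step/zero_grad run inside
    # result should only be collected (e.g., for logging); no extra
    # result.backward() or optimizer.step() should happen out here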
Environment
- PyTorch Version (e.g., 1.0): 1.6
- OS (e.g., Linux): Linux
- How you installed PyTorch (conda, pip, source): pip
- Python version: 3.8
- CUDA/cuDNN version: 10.2
- GPU models and configuration: 1080 Ti
Additional context
I encountered this bug because I'm working on a model that requires multiple backward passes per batch.
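For concreteness, the pattern I need looks roughly like this (a sketch; the second loss term is illustrative):

def training_step(self, batch, batch_idx):
    opt = self.optimizers()
    output = self.layer(batch)
    loss_a = self.loss(batch, output)
    # first backward pass; keep the graph alive for the second pass
    self.manual_backward(loss_a, opt, retain_graph=True)
    # a second, illustrative loss through the same parameters
    loss_b = torch.nn.functional.l1_loss(output, torch.zeros_like(output))
    self.manual_backward(loss_b, opt)
    opt.step()
    opt.zero_grad()
    return loss_a.detach()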