Closed
Labels: `3rd party` (Related to a 3rd-party), `bug` (Something isn't working), `help wanted` (Open to be worked on)
Description
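🐛 Bug

Fitting a `LightningModule` that returns two optimizers with `precision=16` and `amp_backend='apex'` fails in the backward pass: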
```
  File "repro apex.py", line 51, in <module>
    trainer.fit(model)
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 481, in fit
    results = self.accelerator_backend.train()
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/accelerators/gpu_accelerator.py", line 67, in train
    results = self.train_or_test()
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/accelerators/accelerator.py", line 68, in train_or_test
    results = self.trainer.train()
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/trainer/trainer.py", line 532, in train
    self.train_loop.run_training_epoch()
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 572, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 729, in run_training_batch
    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 505, in optimizer_step
    model_ref.optimizer_step(
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/core/lightning.py", line 1263, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/core/optimizer.py", line 278, in step
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/core/optimizer.py", line 133, in __optimizer_step
    trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure)
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/plugins/apex.py", line 138, in optimizer_step
    closure()
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 719, in train_step_and_backward_closure
    result = self.training_step_and_backward(
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 827, in training_step_and_backward
    self.backward(result, optimizer, opt_idx)
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/trainer/training_loop.py", line 847, in backward
    result.closure_loss = self.trainer.accelerator_backend.backward(
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/accelerators/accelerator.py", line 97, in backward
    closure_loss = self.trainer.precision_connector.backend.backward(
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/plugins/apex.py", line 53, in backward
    model.backward(closure_loss, optimizer, opt_idx)
  File "/home/aw18f408/repositories/pytorch-lightning/pytorch_lightning/core/lightning.py", line 1155, in backward
    loss.backward(*args, **kwargs)
  File "/home/aw18f408/.conda/envs/lightning/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/aw18f408/.conda/envs/lightning/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
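The final `RuntimeError` comes from autograd itself: it is raised whenever `backward()` is called on a tensor that is not attached to the autograd graph, typically because every tensor feeding the loss had `requires_grad=False` during the forward pass. The minimal plain-PyTorch sketch below (CPU only, no apex or Lightning involved) reproduces the same message and shows the `requires_grad`/`grad_fn` check that tells a detached loss apart from a normal one. It illustrates the error message only; it is not a claim about the root cause in Lightning or apex.

```python
import torch

layer = torch.nn.Linear(32, 2)

# Freeze every parameter, as would happen if something set
# requires_grad=False on the model before the forward pass.
for p in layer.parameters():
    p.requires_grad = False

loss = layer(torch.randn(4, 32)).mean()

# A loss detached from autograd has requires_grad=False and grad_fn=None.
print(loss.requires_grad, loss.grad_fn)

message = None
try:
    loss.backward()
except RuntimeError as err:
    # Keep the text: the `err` name is unbound once the except block ends.
    message = str(err)

print(message)
```

Checking `loss.requires_grad` and `loss.grad_fn` inside `training_step` is a cheap way to see whether the loss is already detached before Lightning hands it to the apex backward.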
To Reproduce
```python
import torch
from torch import optim
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """Fixed-size dataset of random feature vectors."""

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class AMPModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    # optimizer_idx is passed because configure_optimizers returns two optimizers
    def training_step(self, batch, batch_idx, optimizer_idx):
        output = self(batch)
        loss = output.mean()
        return {"loss": loss}

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def configure_optimizers(self):
        # Both optimizers are built over the same model parameters
        optimizer1 = optim.Adam(self.parameters(), lr=0.01)
        optimizer2 = optim.SGD(self.parameters(), lr=0.01)
        return [optimizer1, optimizer2]


if __name__ == "__main__":
    model = AMPModel()
    trainer = Trainer(
        max_epochs=1,
        precision=16,        # 16-bit mixed precision...
        amp_backend='apex',  # ...through NVIDIA apex (apex must be installed)
        gpus=1,
    )
    trainer.fit(model)
```

Expected behavior
No crash: the script should finish its single training epoch without raising an error.
Environment
- CUDA:
- GPU:
- GeForce RTX 2080 Ti
- GeForce RTX 2080 Ti
- available: True
- version: 11.0 - Packages:
- numpy: 1.19.5
- pyTorch_debug: False
- pyTorch_version: 1.7.1
- pytorch-lightning: 1.2.0dev
- tqdm: 4.56.0 - System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.8.3
- version: Proposal for help #1 SMP Thu Apr 9 13:49:54 UTC 2020
Additional context
Discovered in #5507, in the test `tests/models/test_amp::test_amp_with_apex`.
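One plausible mechanism, offered as an assumption and not a confirmed diagnosis: with more than one optimizer, the training loop toggles `requires_grad` so that only the current optimizer's parameters are trainable. If that toggle freezes every model parameter and then re-enables only the tensors stored in `optimizer.param_groups`, it leaves the model frozen whenever the optimizer holds copies instead of the model's own parameter tensors, which, to our understanding, is what apex does when it keeps FP32 master weights. The sketch below uses a hypothetical `toggle_optimizer_sketch` helper (not Lightning's code) and plain FP32 clones as a stand-in for apex master weights, so it runs on CPU without apex.

```python
import torch
from torch import optim

model = torch.nn.Linear(32, 2)

# Stand-in for apex master weights (assumption): the optimizer holds
# detached FP32 copies, not the model's own parameter tensors.
master_params = [p.detach().clone().requires_grad_(True) for p in model.parameters()]
optimizer = optim.SGD(master_params, lr=0.01)


def toggle_optimizer_sketch(module, opt):
    """Hypothetical toggle: freeze all module parameters, then unfreeze
    whatever tensors are listed in the optimizer's param_groups."""
    for param in module.parameters():
        param.requires_grad = False
    for group in opt.param_groups:
        for param in group["params"]:
            param.requires_grad = True


toggle_optimizer_sketch(model, optimizer)

# Only the copies were re-enabled, so the model's own weights stay frozen
# and the loss is detached from autograd.
frozen = [not p.requires_grad for p in model.parameters()]
loss = model(torch.randn(4, 32)).mean()

message = None
try:
    loss.backward()
except RuntimeError as err:
    message = str(err)

print(frozen, message)
```

In a real apex run, one way to test this hypothesis would be to compare `{id(p) for p in model.parameters()}` with the ids of the tensors in each optimizer's `param_groups` after AMP initialization: no overlap would be consistent with the mechanism sketched here.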