Bug description
I've tried the new Adam(fused=True) implementation and noticed a massive performance regression in my project when using AMP: just switching the flag from True to False reduces my loss by over 200% (in a mipnerf variant implemented in PyTorch Lightning). I wouldn't expect the fused kernel to perform drastically worse, so it could be a bug in the implementation or in its interaction with AMP. My current theory is that the PyTorch Lightning AMP scaler does not interact nicely with fused Adam (which handles the scaler logic directly on the GPU).
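To isolate the interaction from Lightning, here is a plain-PyTorch sketch of the same setup (my own sketch, not part of the colab below) that should exercise the same scaler/optimizer path; toggle use_amp and fused to mirror the four combinations:

import torch
import torch.nn as nn

torch.manual_seed(2809)
device = "cuda"
use_amp = True   # mirrors precision=16
fused = True     # toggle to compare against the non-fused path

model = nn.Linear(10, 10).to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1., fused=fused)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

data = torch.randn(10, 10, device=device)
target = torch.randn(10, 10, device=device)

for step in range(6):
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = loss_fn(model(data), target)
    scaler.scale(loss).backward()
    # as I understand it, for fused Adam the scaler hands the unscaling over to the kernel here
    scaler.step(optimizer)
    scaler.update()
    print(f"step {step}, loss {loss.item():.3f}")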
How to reproduce the bug
(Precision = 32, Fused = True) == (Precision = 16, Fused = False) == (Precision = 32, Fused = False)
but for (Precision = 16, Fused = True) the loss deviates.
pytorch-lightning colab
Relevant code:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer

torch.manual_seed(2809)

# declare fused & precision here
precision = 32
fused = True

target = torch.randn(10, 10, device="cuda")
data = torch.randn(10, 10, device="cuda")


class TestSystem(LightningModule):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(10, 10)
        self.loss_fn = nn.MSELoss()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1., fused=fused)
        return [optimizer], []

    # dummy dataloader, just so the Trainer has something to iterate over
    def train_dataloader(self):
        return DataLoader(data)

    def forward(self, batch):
        # the batch is ignored on purpose; the fixed `data` tensor is used every step
        return self.model(data)

    def training_step(self, batch, batch_idx):
        output = self(batch)
        loss = self.loss_fn(output, target)
        return loss

    def training_epoch_end(self, outputs) -> None:
        for i, loss in enumerate(outputs):
            print("epoch {}, loss {:.3f}".format(i, loss["loss"]))


torch.manual_seed(2809)
system = TestSystem()
trainer = Trainer(max_steps=10, accelerator="gpu", precision=precision,
                  enable_model_summary=False, log_every_n_steps=-1)
trainer.fit(system)
#Fused True, Precision 32
#epoch 0, loss 1.225
#epoch 1, loss 12.567
#epoch 2, loss 1.926
#epoch 3, loss 3.361
#epoch 4, loss 6.156
#epoch 5, loss 3.749
#Fused False, Precision 16
#epoch 0, loss 1.225
#epoch 1, loss 12.567
#epoch 2, loss 1.926
#epoch 3, loss 3.361
#epoch 4, loss 6.156
#epoch 5, loss 3.749
#Fused False, Precision 32
#epoch 0, loss 1.225
#epoch 1, loss 12.567
#epoch 2, loss 1.926
#epoch 3, loss 3.361
#epoch 4, loss 6.156
#epoch 5, loss 3.749
#Fused True, Precision 16
#epoch 0, loss 1.225
#epoch 1, loss 11.941
#epoch 2, loss 1.866
#epoch 3, loss 3.285
#epoch 4, loss 5.882
#epoch 5, loss 3.499
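To narrow down whether the fused kernel itself produces a different update under the GradScaler, a single-step comparison in plain PyTorch (again my own sketch, not from the colab) could look like this: it clones one model, steps each copy once with fused=True vs fused=False, and prints the largest parameter difference.

import copy

import torch
import torch.nn as nn

torch.manual_seed(2809)
device = "cuda"
data = torch.randn(10, 10, device=device)
target = torch.randn(10, 10, device=device)

model_a = nn.Linear(10, 10).to(device)
model_b = copy.deepcopy(model_a)  # identical initial weights

opt_a = torch.optim.Adam(model_a.parameters(), lr=1., fused=True)
opt_b = torch.optim.Adam(model_b.parameters(), lr=1., fused=False)
loss_fn = nn.MSELoss()

for model, opt in ((model_a, opt_a), (model_b, opt_b)):
    scaler = torch.cuda.amp.GradScaler()
    opt.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(data), target)
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()

# after one identical step the two parameter sets should agree up to fp32 rounding
for (name, p_a), (_, p_b) in zip(model_a.named_parameters(), model_b.named_parameters()):
    print(name, (p_a - p_b).abs().max().item())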
Error messages and logs
Environment
Versions
Collecting environment information...
PyTorch version: 1.13.0
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.4 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.9.12 (main, Jun 1 2022, 11:38:51) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.15.0-43-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2070
Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.23.4
[pip3] pytorch-lightning==1.7.7
[pip3] torch==1.13.0
[pip3] torch-tb-profiler==0.4.0
[pip3] torchaudio==0.13.0
[pip3] torchmetrics==0.10.0
[pip3] torchvision==0.14.0
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39hd3c417c_0
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] numpy 1.23.4 pypi_0 pypi
[conda] numpy-base 1.23.3 py39h31eccc5_0
[conda] pytorch 1.13.0 py3.9_cuda11.7_cudnn8.5.0_0 pytorch
[conda] pytorch-cuda 11.7 h67b0de4_0 pytorch
[conda] pytorch-lightning 1.7.7 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch 1.12.1+cu113 pypi_0 pypi
[conda] torch-tb-profiler 0.4.0 pypi_0 pypi
[conda] torchaudio 0.13.0 py39_cu117 pytorch
[conda] torchmetrics 0.10.0 pypi_0 pypi
[conda] torchvision 0.13.1+cu113 pypi_0 pypi
More info
I initially thought this was a PyTorch issue, but I now think it is on the PyTorch Lightning side. I first mentioned it here: Pytorch Issue
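As far as I can tell from the PyTorch 1.13 source (the attribute name below is my reading of it, so treat it as an assumption), the fused optimizer opts out of the scaler's explicit unscale step via a private flag that torch.cuda.amp.GradScaler.step checks, which is where Lightning's precision plugin and the fused path could disagree:

import torch

model = torch.nn.Linear(10, 10).cuda()
opt_fused = torch.optim.Adam(model.parameters(), lr=1., fused=True)
opt_plain = torch.optim.Adam(model.parameters(), lr=1., fused=False)

# private flag GradScaler.step looks for before deciding whether to unscale the
# grads itself or pass the scale down to the optimizer/kernel
print(getattr(opt_fused, "_step_supports_amp_scaling", False))  # expected: True
print(getattr(opt_plain, "_step_supports_amp_scaling", False))  # expected: False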
cc @carmocca @justusschock @awaelchli @akihironitta @rohitgr7 @Borda