Description
Looking at the testcase below and the discussion at https://discuss.pytorch.org/t/layernorms-grads-become-nan-after-first-epoch/133292/9, it seems increasingly likely that there is a bug in Lightning 1.4.9's AMP implementation.
My goal is to find out why the layer immediately preceding the loss function ends up with an Inf gradient. I see the loss function returning a value of 1.4138, so I was surprised to see an overflow.
If you walk through the code between the end of forward() and the beginning of backward(), you will see that grad_scaler.py:165 gets invoked: `return outputs * self._scale.to(device=outputs.device, non_blocking=True)`, where self._scale multiplies the loss value by 65536. When using float16 tensors, I believe this operation is responsible for the overflow, but maybe you know something I don't...
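For reference, the default GradScaler scale is 2**16 = 65536 while float16 tops out at 65504, so scaling a loss of roughly 1.41 already produces a value that half precision cannot represent. Here is a minimal standalone sketch of that arithmetic (plain PyTorch only, not the actual Lightning/grad_scaler.py code path):

```python
import torch

# float16 cannot represent values above 65504
fp16 = torch.finfo(torch.float16)
print(fp16.max)  # 65504.0

# Roughly the loss value observed above, scaled by the default
# GradScaler factor of 2**16 = 65536.
loss = torch.tensor(1.4138)
scaled = loss * 65536.0
print(scaled)  # about 92655, larger than the float16 maximum

# Casting the scaled value down to half precision overflows to inf,
# which is what a float16 gradient of the same magnitude would do.
print(scaled.to(torch.float16))  # inf
```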
Testcase:

```python
import math
import os
from typing import Optional, Tuple

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Subset, Dataset

# Bug report: https://github.com/pytorch/pytorch/issues/65301
DETERMINISTIC = True
DETERMINISTIC_SEED = 41

# Debugging "one of the variables needed for gradient computation has been modified by an inplace operation"
torch.autograd.set_detect_anomaly(True)


class OutdoorTemperatureDataset(Dataset):
    def __init__(self, batch_size: int):
        self.batch_size = batch_size
        self.input_horizon = 1
        self.output_horizon = 2
        self.total_horizon = self.input_horizon + self.output_horizon
        self.outdoor_temperature = torch.tensor([1.0]).repeat(2, self.total_horizon)

    def __getitem__(self, index) -> Tuple[Tensor, Tensor]:
        samples = torch.stack([self.outdoor_temperature[index]])
        # Convert [features, samples] to [samples, features]
        samples = samples.permute(1, 0)
        x = samples[:self.input_horizon, :]
        y = samples[self.input_horizon:, 0]
        return x, y

    def __len__(self):
        return self.outdoor_temperature.shape[0]


class ProcessContext:
    def __init__(self, dataset: OutdoorTemperatureDataset):
        self.input_horizon = dataset.input_horizon
        self.output_horizon = dataset.output_horizon
        train_size = max(1,
                         min(len(dataset) - 1,
                             math.ceil(len(dataset) * 0.9)))
        val_size = len(dataset) - train_size
        assert train_size > 0
        assert val_size > 0
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            Subset(dataset, range(0, (train_size + val_size))),
            [train_size, val_size])

    def get_train_dataset(self):
        return self.train_dataset

    def get_validation_dataset(self):
        return self.val_dataset

    def get_model(self, learning_rate: float, max_epochs: int, hidden_layer_size: int, batch_size: int):
        return Predictor(self.train_dataset, self.val_dataset, self.input_horizon, self.output_horizon,
                         learning_rate=learning_rate, max_epochs=max_epochs,
                         hidden_layer_size=hidden_layer_size, batch_size=batch_size)


class Predictor(LightningModule):
    def __init__(self, train_dataset: Dataset, val_dataset: Dataset, input_horizon: int, output_horizon: int,
                 learning_rate: float, max_epochs: int, hidden_layer_size: int, batch_size: int):
        super(Predictor, self).__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.input_horizon = input_horizon
        self.output_horizon = output_horizon
        self.total_horizon = self.input_horizon + self.output_horizon
        self.max_epochs = max_epochs
        self.learning_rate = learning_rate
        self.hidden_layer_size = hidden_layer_size
        self.input_norm = nn.LayerNorm(1)
        self.layer_norm = nn.LayerNorm(self.hidden_layer_size)
        self.lstm = nn.LSTM(1, self.hidden_layer_size, 1)
        self.linear_layer = nn.Linear(self.hidden_layer_size, self.output_horizon)
        self.loss_function = F.mse_loss
        self.batch_size = batch_size
        self.limits = torch.finfo(torch.float16)
        for module in self.modules():
            module.register_full_backward_hook(self._output_grads)

    def _output_grads(self, module, grad_input, grad_output):
        print(f"\nmodule={module}")
        print(f"grad_input={grad_input}")
        print(f"grad_output={grad_output}")

    def train_dataloader(self):
        return DataLoader(dataset=self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(dataset=self.val_dataset, batch_size=self.batch_size, pin_memory=True)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)

    def forward(self, input):
        output = self.input_norm(input)
        # Input shape is [batch, sequence, feature] but lstm/gru expects [sequence, batch, feature]
        output = output.permute(1, 0, 2)
        output, _ = self.lstm(output)
        # Extract the hidden layer of the last element of the sequence
        output = output[-1, :, :]
        output = F.relu(output)
        output = self.layer_norm(output)
        output = self.linear_layer(output)
        return output

    def backward(self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args,
                 **kwargs) -> None:
        super().backward(loss, optimizer, optimizer_idx, *args, **kwargs)

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        input, expected = batch
        actual = self(input)
        loss = self.loss_function(actual, expected)
        return torch.clamp(loss, self.limits.min, self.limits.max)

    def validation_step(self, batch, batch_index) -> Optional[STEP_OUTPUT]:
        input, expected = batch
        actual = self(input)
        loss = self.loss_function(actual, expected)
        return torch.clamp(loss, self.limits.min, self.limits.max)


def train(dataset: OutdoorTemperatureDataset, learning_rate: float, max_epochs: int,
          hidden_layer_size: int) -> float:
    process_context = ProcessContext(dataset)
    model = process_context.get_model(learning_rate, max_epochs, hidden_layer_size, dataset.batch_size)
    model.learning_rate = learning_rate
    trainer = Trainer(gpus=-1, benchmark=not DETERMINISTIC, precision=16, weights_summary=None,
                      max_epochs=max_epochs, deterministic=DETERMINISTIC, num_sanity_val_steps=0,
                      gradient_clip_val=1.0)
    trainer.fit(model)
    return trainer.logged_metrics["val_loss"]


def main():
    if DETERMINISTIC:
        # https://pytorch.org/docs/stable/notes/randomness.html
        pl.seed_everything(DETERMINISTIC_SEED, workers=True)
        torch.use_deterministic_algorithms(True)
        # https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM
        os.putenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    LEARNING_RATE = 0.7320207424172239
    MAX_EPOCHS = 1000
    HIDDEN_LAYER_SIZE = 2
    batch_size = 1
    dataset = OutdoorTemperatureDataset(batch_size)
    while True:
        try:
            train(dataset, LEARNING_RATE, MAX_EPOCHS, HIDDEN_LAYER_SIZE)
        except RuntimeError as e:
            message = repr(e)
            if "CUDNN_STATUS_EXECUTION_FAILED" in message or "CUDA out of memory" in message:
                print(message)
                batch_size = batch_size // 2
                if batch_size <= 0:
                    raise e
                print(f"Reducing batch_size to {batch_size}")
            else:
                raise e


if __name__ == "__main__":
    main()
```

This issue might be related to #9694 but I'm not sure.