
AMP scaler always causes backwards pass to overflow #9799

@cowwoc

Description

Looking at the testcase below and the discussion at https://discuss.pytorch.org/t/layernorms-grads-become-nan-after-first-epoch/133292/9, it sounds increasingly likely that there is a bug in Lightning 1.4.9's AMP implementation.

My goal is to find out why the layer immediately preceding the loss function ends up with an Inf gradient. The loss function returns a value of 1.4138, so I was surprised to see an overflow.

If you walk through the code between the end of forward() and the beginning of backward(), you will see that grad_scaler.py:165 gets invoked: return outputs * self._scale.to(device=outputs.device, non_blocking=True), where self._scale multiplies the loss value by 65536. When using float16 tensors, I believe this operation is responsible for the overflow, but maybe you know something I don't...
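
For reference, here is the arithmetic behind my suspicion (a minimal sketch, assuming the scaling multiplication really happens in float16 rather than float32):

import torch

# float16's largest finite value is 65504, so multiplying any loss >= ~1.0
# by the default scale factor of 65536 lands outside the representable range.
scale = 65536.0
loss_fp16 = torch.tensor(1.4138, dtype=torch.float16)

print(torch.finfo(torch.float16).max)  # 65504.0
print(loss_fp16 * scale)               # tensor(inf, dtype=torch.float16)
print(loss_fp16.float() * scale)       # finite in float32, roughly 9.27e4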

Testcase:

import math
import os
from typing import Optional, Tuple

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Subset, Dataset

# Bug report: https://github.com/pytorch/pytorch/issues/65301
DETERMINISTIC = True
DETERMINISTIC_SEED = 41
# Debugging "one of the variables needed for gradient computation has been modified by an inplace operation"
torch.autograd.set_detect_anomaly(True)


class OutdoorTemperatureDataset(Dataset):
    def __init__(self, batch_size: int):
        self.batch_size = batch_size
        self.input_horizon = 1
        self.output_horizon = 2
        self.total_horizon = self.input_horizon + self.output_horizon
        self.outdoor_temperature = torch.tensor([1.0]).repeat(2, self.total_horizon)

    def __getitem__(self, index) -> Tuple[Tensor, Tensor]:
        samples = torch.stack([self.outdoor_temperature[index]])
        # Convert [features, samples] to [samples, features]
        samples = samples.permute(1, 0)
        x = samples[:self.input_horizon, :]
        y = samples[self.input_horizon:, 0]
        return x, y

    def __len__(self):
        return self.outdoor_temperature.shape[0]


class ProcessContext:
    def __init__(self, dataset: OutdoorTemperatureDataset):
        self.input_horizon = dataset.input_horizon
        self.output_horizon = dataset.output_horizon
        train_size = max(1,
                         min(len(dataset) - 1,
                             math.ceil(len(dataset) * 0.9)))
        val_size = len(dataset) - train_size
        assert train_size > 0
        assert val_size > 0
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            Subset(dataset, range(0, (train_size + val_size))),
            [train_size, val_size])

    def get_train_dataset(self):
        return self.train_dataset

    def get_validation_dataset(self):
        return self.val_dataset

    def get_model(self, learning_rate: float, max_epochs: int, hidden_layer_size: int, batch_size: int):
        return Predictor(self.train_dataset, self.val_dataset, self.input_horizon, self.output_horizon,
                         learning_rate=learning_rate, max_epochs=max_epochs,
                         hidden_layer_size=hidden_layer_size, batch_size=batch_size)


class Predictor(LightningModule):
    def __init__(self, train_dataset: Dataset, val_dataset: Dataset, input_horizon: int, output_horizon: int,
                 learning_rate: float, max_epochs: int, hidden_layer_size: int, batch_size: int):
        super(Predictor, self).__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.input_horizon = input_horizon
        self.output_horizon = output_horizon
        self.total_horizon = self.input_horizon + self.output_horizon
        self.max_epochs = max_epochs
        self.learning_rate = learning_rate
        self.hidden_layer_size = hidden_layer_size

        self.input_norm = nn.LayerNorm(1)
        self.layer_norm = nn.LayerNorm(self.hidden_layer_size)
        self.lstm = nn.LSTM(1, self.hidden_layer_size, 1)

        self.linear_layer = nn.Linear(self.hidden_layer_size, self.output_horizon)
        self.loss_function = F.mse_loss
        self.batch_size = batch_size
        self.limits = torch.finfo(torch.float16)
        # Print the gradients flowing through every module so we can see where
        # the Inf/NaN first appears during the backward pass.
        for module in self.modules():
            module.register_full_backward_hook(self._output_grads)

    def _output_grads(self, module, grad_input, grad_output):
        print(f"\nmodule={module}")
        print(f"grad_input={grad_input}")
        print(f"grad_output={grad_output}")

    def train_dataloader(self):
        return DataLoader(dataset=self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(dataset=self.val_dataset, batch_size=self.batch_size, pin_memory=True)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)

    def forward(self, input):
        output = self.input_norm(input)
        # Input shape is [batch, sequence, feature] but lstm/gru expects [sequence, batch, feature]
        output = output.permute(1, 0, 2)
        output, _ = self.lstm(output)
        # Extract the hidden layer of the last element of the sequence
        output = output[-1, :, :]
        output = F.relu(output)
        output = self.layer_norm(output)
        output = self.linear_layer(output)
        return output

    # No-op override that delegates to the default implementation; kept as a
    # convenient place to set a breakpoint between forward() and backward().
    def backward(self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args,
                 **kwargs) -> None:
        super().backward(loss, optimizer, optimizer_idx, *args, **kwargs)

    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
        input, expected = batch

        actual = self(input)
        loss = self.loss_function(actual, expected)
        # Clamp the loss to the representable float16 range to rule out an
        # overflow in the loss value itself.
        return torch.clamp(loss, self.limits.min, self.limits.max)

    def validation_step(self, batch, batch_index) -> Optional[STEP_OUTPUT]:
        input, expected = batch

        actual = self(input)
        loss = self.loss_function(actual, expected)
        loss = torch.clamp(loss, self.limits.min, self.limits.max)
        # Log the validation loss so trainer.logged_metrics["val_loss"] (read in
        # train() below) is actually populated.
        self.log("val_loss", loss)
        return loss


def train(dataset: OutdoorTemperatureDataset, learning_rate: float, max_epochs: int,
          hidden_layer_size: int) -> float:
    process_context = ProcessContext(dataset)
    model = process_context.get_model(learning_rate, max_epochs, hidden_layer_size, dataset.batch_size)
    model.learning_rate = learning_rate

    trainer = Trainer(gpus=-1, benchmark=not DETERMINISTIC, precision=16, weights_summary=None,
                      max_epochs=max_epochs, deterministic=DETERMINISTIC, num_sanity_val_steps=0,
                      gradient_clip_val=1.0)
    trainer.fit(model)
    return trainer.logged_metrics["val_loss"]


def main():
    if DETERMINISTIC:
        # https://pytorch.org/docs/stable/notes/randomness.html
        pl.seed_everything(DETERMINISTIC_SEED, workers=True)
        torch.use_deterministic_algorithms(True)
        # https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM
        os.putenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    LEARNING_RATE = 0.7320207424172239
    MAX_EPOCHS = 1000
    HIDDEN_LAYER_SIZE = 2
    batch_size = 1
    dataset = OutdoorTemperatureDataset(batch_size)
    while True:
        try:
            train(dataset, LEARNING_RATE, MAX_EPOCHS, HIDDEN_LAYER_SIZE)
        except RuntimeError as e:
            message = repr(e)
            if "CUDNN_STATUS_EXECUTION_FAILED" in message or "CUDA out of memory" in message:
                print(message)
                batch_size = batch_size // 2
                if batch_size <= 0:
                    raise e
                print(f"Reducing batch_size to {batch_size}")
            else:
                raise e


if __name__ == "__main__":
    main()

This issue might be related to #9694 but I'm not sure.
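
For context, my understanding (from the torch.cuda.amp docs) is that the native recipe Lightning's AMP plugin wraps looks roughly like the sketch below; if Inf gradients from the scaled backward pass are supposed to be caught and skipped by scaler.step() / scaler.update(), then the question becomes why that recovery doesn't kick in here. The model, data, and hyperparameters in the sketch are made up for illustration, and it requires a CUDA device:

import torch
import torch.nn.functional as F

device = "cuda"
model = torch.nn.Linear(4, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler()  # default init_scale is 65536.0

for _ in range(3):
    x = torch.randn(8, 4, device=device)
    y = torch.randn(8, 1, device=device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = F.mse_loss(model(x), y)
    # scale() multiplies the loss by the current scale factor, so the scaled
    # backward pass can legitimately produce Inf gradients in float16.
    scaler.scale(loss).backward()
    # step() unscales the gradients and skips optimizer.step() if any of them
    # are Inf/NaN; update() then lowers the scale factor for the next iteration.
    scaler.step(optimizer)
    scaler.update()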

Labels: bug (Something isn't working), help wanted (Open to be worked on)