
Float8Linear does not support autocast #568

@vkuzo


from @yitzhaklevi

The issue is caused by Float8Linear capturing the dtype of its input (see https://github.com/pytorch-labs/float8_experimental/blob/main/float8_experimental/float8_linear.py#L303). Later, during sync_float8_amax_and_scale_history, an assert (https://github.com/pytorch-labs/float8_experimental/blob/main/float8_experimental/float8_linear_utils.py#L247) fails because the Float8Linear layers in the model did not all capture the same input dtype: under autocast, some layers receive bfloat16 inputs while others (for example, one that follows a LayerNorm) receive float32.
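To make the failure mechanism concrete, here is a toy, pure-Python model of the dtype bookkeeping (no PyTorch or GPU needed). Dtypes are plain strings, the per-layer output dtypes mirror the comments in the script in this issue, and `check_single_input_dtype` is a hypothetical stand-in for the assert in `sync_float8_amax_and_scale_history`, assuming that assert requires every Float8Linear to have captured the same input dtype.

```python
# Toy model of SimpleModel from this issue: (layer name, layer kind).
SIMPLE_MODEL = [
    ("l1", "fp8_linear"),
    ("l2", "fp8_linear"),
    ("norm_layer", "layer_norm"),
    ("l3", "fp8_linear"),
]

# Output dtype of each layer kind under bfloat16 autocast, taken from the
# comments in the issue's script (an assumption of this sketch).
AUTOCAST_OUTPUT_DTYPE = {"fp8_linear": "bfloat16", "layer_norm": "float32"}


def run_forward(layers, input_dtype):
    """Propagate a dtype through the layers and return the input dtype
    each fp8 linear layer captured (as Float8Linear does)."""
    captured = {}
    dtype = input_dtype
    for name, kind in layers:
        if kind == "fp8_linear":
            captured[name] = dtype  # recorded before any casting happens
        dtype = AUTOCAST_OUTPUT_DTYPE[kind]
    return captured


def check_single_input_dtype(captured):
    """Hypothetical stand-in for the assert in sync_float8_amax_and_scale_history."""
    dtypes = set(captured.values())
    if len(dtypes) != 1:
        raise ValueError(f"layers captured different input dtypes: {sorted(dtypes)}")
    return dtypes.pop()


captured = run_forward(SIMPLE_MODEL, "float32")
print(captured)  # {'l1': 'float32', 'l2': 'bfloat16', 'l3': 'float32'}
try:
    check_single_input_dtype(captured)
except ValueError as err:
    print("sync fails:", err)
```

l1 sees the float32 model input and l3 sees the float32 LayerNorm output, while l2 sees bfloat16, so the single-dtype check cannot pass.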

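One simple workaround is to decorate Float8Linear.forward with torch.cuda.amp.custom_fwd (https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_fwd) and pass the autocast dtype as cast_inputs. Inside an autocast-enabled region, custom_fwd then casts floating-point CUDA tensor inputs to that dtype and runs forward with autocast disabled, so every Float8Linear captures the same input dtype. Note that cast_inputs is evaluated once, when the decorator is applied: torch.get_autocast_gpu_dtype() called outside an autocast region returns the default, torch.float16, so it is safer to name the dtype (here torch.bfloat16) explicitly.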
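The workaround, and the reason to pass an explicit dtype, can be sketched with a toy model in pure Python. `cast_inputs` below is a simplified, hypothetical stand-in for `custom_fwd(cast_inputs=...)`: inside an autocast region it replaces the input dtype with the target dtype before the layer records it (the real decorator also disables autocast inside `forward`, which is not modeled here). The module-level state imitates torch's autocast dtype, which reads as the default `float16` outside any autocast region.

```python
import contextlib

# Toy autocast state; outside any autocast region the dtype reads as the
# default "float16", like torch.get_autocast_gpu_dtype().
_autocast = {"enabled": False, "dtype": "float16"}


def get_autocast_dtype():
    return _autocast["dtype"]


@contextlib.contextmanager
def autocast(dtype):
    previous = dict(_autocast)
    _autocast.update(enabled=True, dtype=dtype)
    try:
        yield
    finally:
        _autocast.update(previous)


def cast_inputs(target_dtype):
    """Simplified, hypothetical stand-in for custom_fwd(cast_inputs=target_dtype)."""
    def decorator(forward):
        def wrapper(self, x_dtype):
            if _autocast["enabled"]:
                x_dtype = target_dtype  # cast before the layer sees the input
            return forward(self, x_dtype)
        return wrapper
    return decorator


class RecordingLayer:
    """Records the input dtype it receives, like Float8Linear."""
    def forward(self, x_dtype):
        self.last_seen_input_dtype = x_dtype
        return x_dtype


class LookupAtDefinition(RecordingLayer):
    # get_autocast_dtype() runs once, here, while no autocast region is active.
    @cast_inputs(get_autocast_dtype())
    def forward(self, x_dtype):
        return super().forward(x_dtype)


class ExplicitDtype(RecordingLayer):
    @cast_inputs("bfloat16")
    def forward(self, x_dtype):
        return super().forward(x_dtype)


with autocast("bfloat16"):
    print(LookupAtDefinition().forward("float32"))  # float16: stale default
    print(ExplicitDtype().forward("float32"))       # bfloat16
    print(ExplicitDtype().forward("bfloat16"))      # bfloat16
```

With the explicit dtype, every wrapped layer records bfloat16 whether its input came from a linear layer or a LayerNorm, so a single-dtype check passes. The lookup-at-definition variant also records a single dtype, but it is float16 rather than the bfloat16 that autocast is using.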

The following script reproduces the failure when run without arguments, and demonstrates the workaround when run with --wrap_linear_layer:

```python
import argparse

import torch
from torch.cuda.amp import custom_fwd

from float8_experimental.float8_linear import Float8Linear as BaseFloat8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

# Autocast dtype used throughout. custom_fwd's cast_inputs is evaluated once,
# when the decorator runs at class-definition time, so pass the dtype
# explicitly: torch.get_autocast_gpu_dtype() called outside an autocast
# region returns the default (torch.float16), not torch.bfloat16.
AUTOCAST_DTYPE = torch.bfloat16


def get_args():
    p = argparse.ArgumentParser()
    p.add_argument('--wrap_linear_layer', action='store_true')
    return p.parse_args()


class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(16, 16)
        self.l2 = torch.nn.Linear(16, 32)
        # The norm layer is just an example; any layer that outputs float32
        # regardless of autocast settings triggers the same problem.
        self.norm_layer = torch.nn.LayerNorm(32)
        self.l3 = torch.nn.Linear(32, 16)

    def forward(self, x):
        # x is still float32
        x = self.l1(x)
        # x is now bfloat16
        x = self.l2(x)
        # x is still bfloat16
        x = self.norm_layer(x)
        # x is now float32 (LayerNorm outputs float32 regardless of autocast settings)
        x = self.l3(x)
        # x is now bfloat16
        return x


if __name__ == '__main__':
    args = get_args()
    m = SimpleModel().to('cuda')
    if args.wrap_linear_layer:
        class Float8Linear(BaseFloat8Linear):
            # Under autocast, cast floating-point inputs to AUTOCAST_DTYPE so
            # that every Float8Linear captures the same input dtype.
            @custom_fwd(cast_inputs=AUTOCAST_DTYPE)
            def forward(self, *args, **kwargs):
                return super().forward(*args, **kwargs)
    else:
        Float8Linear = BaseFloat8Linear

    swap_linear_with_float8_linear(m, Float8Linear)
    b = torch.rand(17, 16, device='cuda')
    with torch.amp.autocast(enabled=True, device_type='cuda', dtype=AUTOCAST_DTYPE):
        out = m(b)
        sync_float8_amax_and_scale_history(m)
    print('Done!')
```

copied from meta-pytorch/float8_experimental#257
