From @yitzhaklevi:
The issue is caused by the fact that Float8Linear captures the input dtype during the forward pass (https://github.com/pytorch-labs/float8_experimental/blob/main/float8_experimental/float8_linear.py#L303). A later assert during sync_float8_amax_and_scale_history (https://github.com/pytorch-labs/float8_experimental/blob/main/float8_experimental/float8_linear_utils.py#L247) then causes the failure when the Float8Linear layers in the model have recorded different input dtypes.
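To see why the input dtype can change mid-network under autocast, here is a minimal sketch that is independent of float8_experimental (layer names and shapes are arbitrary): under bfloat16 autocast, Linear produces bfloat16 activations while LayerNorm is kept in float32, so two linear layers in the same model can see different input dtypes.

import torch

lin = torch.nn.Linear(8, 8).to('cuda')
norm = torch.nn.LayerNorm(8).to('cuda')
x = torch.rand(4, 8, device='cuda')
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
    y = lin(x)
    print(y.dtype)   # torch.bfloat16 - Linear runs in the autocast dtype
    z = norm(y)
    print(z.dtype)   # torch.float32 - autocast keeps LayerNorm output in float32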
One trivial solution would be to wrap the forward with https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.custom_fwd using cast_inputs=torch.get_autocast_gpu_dtype(), so that inputs are cast to the autocast dtype before Float8Linear records them.
The following script reproduces the issue (run it without arguments) and demonstrates the trivial solution (run it with --wrap_linear_layer):
import argparse

import torch
from torch import get_autocast_gpu_dtype
from torch.cuda.amp import custom_fwd

from float8_experimental.float8_linear import Float8Linear as BaseFloat8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)


def get_args():
    p = argparse.ArgumentParser()
    p.add_argument('--wrap_linear_layer', dest="wrap_linear_layer", action="store_true")
    return p.parse_args()


class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(16, 16)
        self.l2 = torch.nn.Linear(16, 32)
        # the norm layer is just an example - it can be any layer that outputs
        # float32 regardless of the autocast settings
        self.norm_layer = torch.nn.LayerNorm(32)
        self.l3 = torch.nn.Linear(32, 16)

    def forward(self, x):
        # x is still float32
        x = self.l1(x)
        # x is now bfloat16
        x = self.l2(x)
        # x is still bfloat16
        x = self.norm_layer(x)
        # x is now float32 (the output of the norm layer is float32 regardless of autocast settings)
        x = self.l3(x)
        # x is now bfloat16
        return x


if __name__ == '__main__':
    args = get_args()
    m = SimpleModel().to('cuda')
    if args.wrap_linear_layer:
        # cast all inputs to the autocast dtype before Float8Linear sees them
        class Float8Linear(BaseFloat8Linear):
            @custom_fwd(cast_inputs=get_autocast_gpu_dtype())
            def forward(self, *args, **kwargs):
                return super().forward(*args, **kwargs)
    else:
        Float8Linear = BaseFloat8Linear
    swap_linear_with_float8_linear(m, Float8Linear)
    b = torch.rand([17, 16]).to('cuda')
    with torch.amp.autocast(enabled=True, device_type='cuda', dtype=torch.bfloat16):
        out = m(b)
    sync_float8_amax_and_scale_history(m)
    print('Done !')
Copied from meta-pytorch/float8_experimental#257