Commit d252612

Revert "[ROCm] use dataclass for fnuz type setting" (#1148)
Revert "[ROCm] use dataclass for fnuz type setting (#1142)" This reverts commit eb1fb3a.
1 parent eb1fb3a commit d252612
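
Net effect of the revert, sketched below with names taken from the diffs that follow: the auto-detecting `Float8TypeConfig` dataclass is removed from `torchao/float8/config.py`, and the earlier module-level `use_fnuz_dtype` flag comes back, so the float8 dtype pair is chosen by a plain boolean at import time rather than by probing the GPU. A minimal sketch, assuming a torchao checkout that includes this commit:

import torch
import torchao.float8.config as config

# Reverted (#1142): a Float8TypeConfig dataclass probed torch.cuda.get_device_properties(0)
# and switched to the fnuz variants automatically on MI300 (gfx940/941/942).
# Restored here: a module flag, False by default, evaluated once when
# torchao.float8.float8_utils is imported.
e4m3_dtype = torch.float8_e4m3fnuz if config.use_fnuz_dtype else torch.float8_e4m3fn
e5m2_dtype = torch.float8_e5m2fnuz if config.use_fnuz_dtype else torch.float8_e5m2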

File tree (4 files changed, +38 −55 lines):

test/float8/test_base.py
test/float8/test_compile.py
torchao/float8/config.py
torchao/float8/float8_utils.py

test/float8/test_base.py

Lines changed: 14 additions & 14 deletions
@@ -24,8 +24,8 @@

 
 from torchao.float8.config import (
-    CastConfig,
-    Float8LinearConfig,
+    CastConfig,
+    Float8LinearConfig,
     ScalingGranularity,
     ScalingType,
     Float8LinearRecipeName,
@@ -109,15 +109,15 @@ def test_split_cat(self):
 
     def test_index_put(self):
         a = torch.rand(16, dtype=torch.bfloat16)
-        scale_a = tensor_to_scale(a, e4m3_dtype)
-        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)
+        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)
 
         index = torch.randint(0, 15, (16,), dtype=torch.long)
 
         b = torch.rand(16, 16, dtype=torch.bfloat16)
-        scale_b = tensor_to_scale(b, e4m3_dtype)
-        fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, e4m3_dtype)
-        fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, e4m3_dtype)
+        scale_b = tensor_to_scale(b, torch.float8_e4m3fn)
+        fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
+        fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)
 
         with pytest.raises(AssertionError):
             b[index] = fp8_a
@@ -127,8 +127,8 @@ def test_index_put(self):
 
     def test_copy_(self):
         a = torch.rand(16, dtype=torch.bfloat16)
-        scale_a = tensor_to_scale(a, e4m3_dtype)
-        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)
+        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)
 
         b = torch.empty(16, dtype=torch.bfloat16)
         b.copy_(fp8_a)  # Should work
@@ -137,7 +137,7 @@ def test_copy_(self):
         fp8_a.copy_(b)  # Should fail
 
         fp8_b = Float8Tensor(
-            torch.empty(16, dtype=e4m3_dtype),
+            torch.empty(16, dtype=torch.float8_e4m3fn),
             scale_a,
             torch.bfloat16,
             fp8_a._linear_mm_config,
@@ -332,11 +332,11 @@ def _test_linear_impl(
     @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize(
-        "scaling_type_input",
+        "scaling_type_input",
         [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
     )
     @pytest.mark.parametrize(
-        "scaling_type_weight",
+        "scaling_type_weight",
         [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
     )
     @pytest.mark.parametrize(
@@ -377,7 +377,7 @@ def test_linear_from_config_params(
     # to combine with the main testing function.
    # TODO(future PR): make this cleaner.
     @pytest.mark.parametrize(
-        "recipe_name",
+        "recipe_name",
         [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP],
     )
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@@ -610,7 +610,7 @@ def test_different_configs_error(self):
     @pytest.mark.parametrize("use_fast_accum", [True, False])
     def test_pad_inner_dim(self, base_dtype, use_fast_accum):
         torch.manual_seed(42)
-        input_dtype = e4m3_dtype
+        input_dtype = torch.float8_e4m3fn
         compare_type = torch.float32
 
         a = torch.randn(16, 41, device="cuda", dtype=base_dtype)
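
The test edits above are mechanical: every use of the `e4m3_dtype` alias goes back to an explicit `torch.float8_e4m3fn`. A minimal sketch of the cast pattern the tests exercise, assuming the helpers keep the signatures shown in the diff (the import paths are an assumption about the torchao layout, not quoted from this commit):

import torch

# Assumed import locations for the helpers used in the tests.
from torchao.float8.float8_utils import tensor_to_scale
from torchao.float8.float8_tensor import hp_tensor_and_scale_to_float8

a = torch.rand(16, dtype=torch.bfloat16)

# Post-revert pattern: spell out the fn dtype instead of going through e4m3_dtype.
scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)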

test/float8/test_compile.py

Lines changed: 7 additions & 7 deletions
@@ -20,9 +20,9 @@
 import torch
 import torch.nn as nn
 from torchao.float8.config import (
-    CastConfig,
-    Float8LinearConfig,
-    ScalingType,
+    CastConfig,
+    Float8LinearConfig,
+    ScalingType,
     Float8LinearRecipeName,
     recipe_name_to_linear_config,
 )
@@ -77,7 +77,7 @@ def _test_compile_base(
     y_fp8.sum().backward()
     y_ref = m_ref(x_ref)
     y_ref.sum().backward()
-    # TODO(future PR): can also test fp8 eager vs compile here with a tigher
+    # TODO(future PR): can also test fp8 eager vs compile here with a tigher
     # tolerance
     torch.testing.assert_close(y_fp8, y_ref, atol=9.5e-2, rtol=9.5e-2)
     torch.testing.assert_close(
@@ -199,7 +199,7 @@ def test_inductor_from_config_params(
 # to combine with the main testing function.
 # TODO(future PR): make this cleaner.
 @pytest.mark.parametrize(
-    "recipe_name",
+    "recipe_name",
     [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP],
 )
 @unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available")
@@ -412,14 +412,14 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
     )
     float8_eager = hp_tensor_to_float8_dynamic(
         hp_tensor1,
-        e4m3_dtype,
+        torch.float8_e4m3fn,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
     )
     torch._dynamo.reset()
     float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
         hp_tensor2,
-        e4m3_dtype,
+        torch.float8_e4m3fn,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
     )
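
The compile tests keep the same structure; only the dtype literal changes. For orientation, the eager-versus-compiled parity pattern they follow looks roughly like the sketch below, where `cast_to_float8` is a simplified stand-in for `hp_tensor_to_float8_dynamic`, not the real torchao function:

import torch

def cast_to_float8(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    # Simplified stand-in: max-based scale, then cast to the float8 dtype.
    scale = torch.finfo(dtype).max / x.abs().max().clamp(min=1e-12)
    return (x * scale).to(dtype)

hp = torch.randn(16, 16, dtype=torch.bfloat16)
eager = cast_to_float8(hp)
compiled = torch.compile(cast_to_float8)(hp)

# The real test also passes a LinearMMConfig and GemmInputRole and compares the
# resulting Float8Tensor objects; this sketch only compares the raw casts.
torch.testing.assert_close(eager.float(), compiled.float())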

torchao/float8/config.py

Lines changed: 13 additions & 30 deletions
@@ -96,29 +96,6 @@ def __post_init__(self):
         ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
 
 
-@dataclass
-class Float8TypeConfig:
-    """
-    Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
-
-    Currently, ROCm only supports fnuz variants.
-    """
-
-    # The preferred e4m3 type.
-    e4m3_dtype = torch.float8_e4m3fn
-
-    # The preferred e5m2 type.
-    e5m2_dtype = torch.float8_e5m2
-
-    def __post_init__(self):
-        if torch.version.hip:
-            prop = torch.cuda.get_device_properties(0)
-            MI300_ARCH = ("gfx940", "gfx941", "gfx942")
-            if prop.gcnArchName.split(":")[0] in MI300_ARCH:
-                self.e4m3_dtype = torch.float8_e4m3fnuz
-                self.e5m2_dtype = torch.float8_e5m2fnuz
-
-
 @dataclass(frozen=True)
 class Float8GemmConfig:
     """
@@ -141,11 +118,11 @@ class Float8LinearConfig:
     # Per-tensor configuration for casting of `input`, `weight`, `grad_output`
     # for the operands of gemms calculating `output`, `grad_weight`, and `grad_input`.
     #
-    # Note:
-    # 1. if `cast_config_input_for_grad_weight` is None, then
+    # Note:
+    # 1. if `cast_config_input_for_grad_weight` is None, then
     #    `cast_config_input` is used for scaling `input` for both gemms that
-    #    use `input.
-    # 2. if `cast_config_input_for_grad_weight` is specified, then
+    #    use `input.
+    # 2. if `cast_config_input_for_grad_weight` is specified, then
     #    a. `cast_config_input` is used for scaling `input` for the gemm that calculates
     #       `output`
     #    b. `cast_config_input_for_grad_weight` is used for scaling `input` for
@@ -263,6 +240,12 @@ def __post_init__(self):
                 f"incompatible operand precision for {gemm_name}"
 
 
+# If True, use 'fnuz' float8 types for calculations.
+# Currently, ROCm only supports fnuz variants.
+# TODO(future PR): move this to Float8LinearConfig
+use_fnuz_dtype = False
+
+
 # Pre-made recipes for common configurations
 # TODO(future PR): go through a round of design on this, and eventually expose
 # as a top level public API.
@@ -289,7 +272,7 @@ def recipe_name_to_linear_config(
         cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
         cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-
+
         # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only
         # fast with `use_fast_accum=True`. Note that rowwise scaling is more
         # accurate than tensorwise scaling, so the overall impact on accuracy
@@ -317,8 +300,8 @@ def recipe_name_to_linear_config(
         #
         # key characteristics:
         # * increased accuracy for grad_weight
-        # * `input`, `weight` and `grad_output` now only need to be scaled
-        #   axiswise across a single dim compared to vanilla all-axiswise,
+        # * `input`, `weight` and `grad_output` now only need to be scaled
+        #   axiswise across a single dim compared to vanilla all-axiswise,
         #   which is more amenable to fast kernels
 
         # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
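
One consequence of going back to a module-level flag, illustrated below: `float8_utils` binds `e4m3_dtype`/`e5m2_dtype` once at import time, so flipping `use_fnuz_dtype` only matters if it happens before that first import. Hypothetical user code, shown only to make the ordering explicit; it is not an official torchao API or recommendation:

import torchao.float8.config as config

config.use_fnuz_dtype = True  # opt in to the ROCm fnuz variants

from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype

# If float8_utils had not been imported before this point, the aliases now resolve
# to the fnuz variants; if it was already imported (for example via the package
# __init__), flipping the flag here has no effect on them.
print(e4m3_dtype, e5m2_dtype)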

torchao/float8/float8_utils.py

Lines changed: 4 additions & 4 deletions
@@ -9,7 +9,8 @@
 import torch
 import torch.distributed as dist
 
-from torchao.float8.config import Float8TypeConfig, ScalingGranularity
+import torchao.float8.config as config
+from torchao.float8.config import ScalingGranularity
 
 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -28,9 +29,8 @@
 
 
 # User defined type for using the individual F8 type based on config
-type_config = Float8TypeConfig()
-e4m3_dtype = type_config.e4m3_dtype
-e5m2_dtype = type_config.e5m2_dtype
+e4m3_dtype = torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz
+e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz
 
 
 @torch.no_grad()
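
With the dataclass gone and `use_fnuz_dtype` defaulting to `False`, the module-level aliases resolve to the fn variants everywhere, ROCm included, unless the flag is flipped first. A quick sanity check of the restored defaults (a sketch, not part of the test suite):

import torch
from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype

# Default after this revert: the fn variants, regardless of GPU architecture.
assert e4m3_dtype == torch.float8_e4m3fn
assert e5m2_dtype == torch.float8_e5m2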
