Commit 914de78

Revert "Use exp2 for mx scaling" (#1813)
Revert "Use exp2 for mx scaling (#1530)" This reverts commit 890e0ac.
1 parent 55600a1 · commit 914de78

2 files changed: +20 -6 lines changed


torchao/prototype/mx_formats/custom_cast.py

Lines changed: 10 additions & 3 deletions
@@ -12,13 +12,20 @@
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
 )
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+
+# TODO(future): if needed, make the below work on previous PyTorch versions,
+# just need to hunt down the previous location of `libdevice`. An assert
+# at the callsite prevents usage of this on unsupported versions.
+if TORCH_VERSION_AT_LEAST_2_4 and has_triton():
+    from torch._inductor.runtime.triton_helpers import libdevice
+
 from torchao.prototype.mx_formats.constants import (
     E8M0_EXPONENT_BIAS,
     E8M0_EXPONENT_NAN_VAL,
     F4_E2M1_EXP_BIAS,
     F32_EXP_BIAS,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
 
 
 def get_bits(x: torch.Tensor) -> str:
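
As a note on the guarded import in this hunk: below is a minimal, hypothetical sketch of the pattern the new comment describes, with the version/triton check at import time and an assert at a made-up callsite. `has_triton` is assumed to be the helper from `torch.utils._triton`; the callsite function is illustrative only and does not exist in the real file.

```python
# Hypothetical sketch of the gating pattern described in the comment above:
# import libdevice only on supported versions, and assert at the callsite so
# unsupported versions never reach the libdevice-dependent path.
from torch.utils._triton import has_triton  # assumed import location

from torchao.utils import TORCH_VERSION_AT_LEAST_2_4

if TORCH_VERSION_AT_LEAST_2_4 and has_triton():
    from torch._inductor.runtime.triton_helpers import libdevice  # noqa: F401


def some_libdevice_callsite():  # illustrative name, not in the real file
    assert (
        TORCH_VERSION_AT_LEAST_2_4 and has_triton()
    ), "requires PyTorch >= 2.4 with triton available"
    # ... launch the triton kernel that uses libdevice.pow ...
```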
@@ -287,8 +294,8 @@ def triton_f4_to_scaled_bf16_kernel(
     s = tl.load(s_ptr + offsets_s, mask=mask_s)
 
     # create the scale in bf16
-    # S is already biased by 127, so we just have to shift it to align w/ bf16
-    s_fp = (s.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True)
+    s_offset = s.to(tl.int16) - e8m0_exponent_bias
+    s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16)
     s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan"))
 
     # multiply output by scale
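
For context on the kernel change above, here is a small eager-PyTorch sketch (not triton code) comparing the two scale decodes this hunk swaps between: the reverted bit-shift of the biased E8M0 exponent into the bf16 exponent field, and the restored `2 ** (s - 127)` computed via `pow`. The tensor `s` and the inlined constant are stand-ins; exponent 0 is excluded because the shift maps it to zero while `pow` yields the 2**-127 subnormal, and 255 is the NaN encoding handled separately by `tl.where` in the kernel.

```python
import torch

E8M0_EXPONENT_BIAS = 127  # inlined for the sketch

# biased E8M0 exponents, excluding 0 (see note above) and 255 (NaN encoding)
s = torch.arange(1, 255, dtype=torch.uint8)

# (a) reverted approach: shift the biased exponent into bf16's exponent field
#     (bf16 has a 7-bit mantissa, so the exponent starts at bit 7)
shifted = (s.to(torch.int16) << 7).view(torch.bfloat16)

# (b) restored approach: unbias, then compute 2 ** offset
offset = (s.to(torch.int16) - E8M0_EXPONENT_BIAS).to(torch.float32)
powed = torch.pow(2.0, offset).to(torch.bfloat16)

assert torch.equal(shifted, powed)
```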

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 10 additions & 3 deletions
@@ -175,7 +175,10 @@ def to_mx(
 
     # For now, calculate the scale in floating point.
     # TODO(future) audit if there is a need to bit shift exponents instead.
-    scale_fp = torch.exp2(scale_e8m0_unbiased).to(torch.float32)
+    scale_fp = torch.pow(
+        torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
+        scale_e8m0_unbiased,
+    )
 
     # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
     # float32 denormal range. For now, manually adjust the fp scale. This is
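
A standalone sketch of the two formulations in the hunk above, using a made-up `scale_e8m0_unbiased` tensor in place of the values computed inside `to_mx` (and that tensor's own size/device instead of `max_abs` / `scale_e8m0_biased`): the reverted `torch.exp2` call and the `torch.pow`-with-a-tensor-of-twos form this commit restores compute the same scales in eager mode.

```python
import torch

scale_e8m0_unbiased = torch.tensor([-20.0, -1.0, 0.0, 3.0, 15.0])  # example values

via_exp2 = torch.exp2(scale_e8m0_unbiased).to(torch.float32)
via_pow = torch.pow(
    torch.full(scale_e8m0_unbiased.size(), 2.0, device=scale_e8m0_unbiased.device),
    scale_e8m0_unbiased,
)

assert torch.equal(via_exp2, via_pow)  # both are 2 ** scale_e8m0_unbiased
```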
@@ -230,10 +233,14 @@ def to_mx(
 
 
 def get_fp_scale(scale_e8m0):
+    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
+    # TODO(later): it would be nice if there was a way to do the 2^x operation
+    # in PyTorch without creating a tensor of twos
+    two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
+    # pow(two, s_offset) can be out of range of floating point formats.
     # TODO(later): handle this for float16 if we decide to support float16
     # scales.
-    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
-    s_fp = torch.exp2(s_offset)
+    s_fp = torch.pow(two, s_offset)
 
     # If a block exponent was 255, set values of that block to NaN
     s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
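
Putting the restored `get_fp_scale` together, here is a self-contained sketch with the constants from `torchao.prototype.mx_formats.constants` inlined and a small usage example; the function name is suffixed to make clear it is a copy for illustration, not the module's own definition.

```python
import torch

E8M0_EXPONENT_BIAS = 127
E8M0_EXPONENT_NAN_VAL = 255


def get_fp_scale_sketch(scale_e8m0: torch.Tensor) -> torch.Tensor:
    # unbias the stored exponent
    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
    # torch.pow wants a base alongside the exponents, hence the tensor of twos
    two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
    s_fp = torch.pow(two, s_offset)
    # a block exponent of 255 means the whole block is NaN
    return torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))


scales = get_fp_scale_sketch(torch.tensor([126, 127, 128, 255], dtype=torch.uint8))
# eager-mode result: [0.5, 1.0, 2.0, nan]
```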
