
Commit 6238167

Use exp2 for mx scaling
stack-info: PR: #1530, branch: drisspg/stack/26
1 parent 8259a38 commit 6238167

2 files changed (+6, -20 lines)

torchao/prototype/mx_formats/custom_cast.py

Lines changed: 3 additions & 10 deletions
@@ -12,20 +12,13 @@
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
-
-# TODO(future): if needed, make the below work on previous PyTorch versions,
-# just need to hunt down the previous location of `libdevice`. An assert
-# at the callsite prevents usage of this on unsupported versions.
-if TORCH_VERSION_AT_LEAST_2_4 and has_triton():
-    from torch._inductor.runtime.triton_helpers import libdevice
-
 from torchao.prototype.mx_formats.constants import (
     E8M0_EXPONENT_BIAS,
     E8M0_EXPONENT_NAN_VAL,
     F4_E2M1_EXP_BIAS,
     F32_EXP_BIAS,
 )
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
 
 
 def get_bits(x: torch.Tensor) -> str:
@@ -294,8 +287,8 @@ def triton_f4_to_scaled_bf16_kernel(
     s = tl.load(s_ptr + offsets_s, mask=mask_s)
 
     # create the scale in bf16
-    s_offset = s.to(tl.int16) - e8m0_exponent_bias
-    s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16)
+    # S is already biased by 127, so we just have to shift it to align w/ bf16
+    s_fp = (s.to(tl.uint16) << 7).to(tl.bfloat16, bitcast=True)
     s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan"))
 
     # multiply output by scale
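
The kernel change exploits the fact that the e8m0 scale is stored with the same exponent bias (127) that bf16 uses, so shifting the biased byte left by 7 bits drops it straight into the bf16 exponent field, replacing the libdevice.pow(2.0, ...) call. Below is a minimal eager-mode PyTorch sketch, not part of the diff, that checks this equivalence for every non-special e8m0 code (0 is excluded because it maps to a bf16 subnormal the bitcast cannot produce, and 255 encodes NaN, which the kernel handles with tl.where):

import torch

E8M0_EXPONENT_BIAS = 127  # e8m0 shares the bf16/fp32 exponent bias

# every e8m0 code except 0 (2**-127) and 255 (NaN)
e8m0 = torch.arange(1, 255, dtype=torch.int16)

# bf16 layout is 1 sign bit, 8 exponent bits, 7 mantissa bits, so shifting the
# biased exponent left by 7 places it exactly in the bf16 exponent field
via_bitcast = (e8m0 << 7).view(torch.bfloat16)

# reference path: 2 ** (e - 127), computed in fp32 and rounded to bf16
via_exp2 = torch.exp2((e8m0 - E8M0_EXPONENT_BIAS).to(torch.float32)).to(torch.bfloat16)

assert torch.equal(via_bitcast, via_exp2)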

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 3 additions & 10 deletions
@@ -127,10 +127,7 @@ def to_mx(
 
     # For now, calculate the scale in floating point.
     # TODO(future) audit if there is a need to bit shift exponents instead.
-    scale_fp = torch.pow(
-        torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
-        scale_e8m0_unbiased,
-    )
+    scale_fp = torch.exp2(scale_e8m0_unbiased).to(torch.float32)
 
     # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
     # float32 denormal range. For now, manually adjust the fp scale. This is
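
For the to_mx scale itself, torch.exp2 produces the same powers of two as the removed torch.pow over a tensor filled with 2.0, without materializing that extra tensor. A tiny eager-mode check with made-up exponents (the names below are illustrative, not from the source):

import torch

# stand-in for scale_e8m0_unbiased: small integer-valued exponents
exponents = torch.randint(-20, 20, (8,)).to(torch.float32)

old = torch.pow(torch.full(exponents.size(), 2.0), exponents)
new = torch.exp2(exponents).to(torch.float32)

assert torch.allclose(old, new)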
@@ -176,14 +173,10 @@ def to_mx(
 
 
 def get_fp_scale(scale_e8m0):
-    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
-    # TODO(later): it would be nice if there was a way to do the 2^x operation
-    # in PyTorch without creating a tensor of twos
-    two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
-    # pow(two, s_offset) can be out of range of floating point formats.
     # TODO(later): handle this for float16 if we decide to support float16
     # scales.
-    s_fp = torch.pow(two, s_offset)
+    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
+    s_fp = torch.exp2(s_offset)
 
     # If a block exponent was 255, set values of that block to NaN
     s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
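
Putting the second hunk together, the new get_fp_scale path is: subtract the e8m0 bias, raise 2 to that offset with torch.exp2, then map the reserved exponent 255 to NaN. The sketch below is a reconstruction from the diff with the constants inlined so it runs on its own:

import torch

E8M0_EXPONENT_BIAS = 127
E8M0_EXPONENT_NAN_VAL = 255

def get_fp_scale_sketch(scale_e8m0: torch.Tensor) -> torch.Tensor:
    # biased e8m0 exponent -> signed integer offset
    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
    # 2 ** offset; torch.exp2 promotes the integer input to the default float dtype
    s_fp = torch.exp2(s_offset)
    # the all-ones exponent encodes NaN in e8m0
    return torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))

scales = torch.tensor([0, 1, 127, 128, 254, 255], dtype=torch.uint8)
print(get_fp_scale_sketch(scales))
# roughly: [2**-127, 2**-126, 1.0, 2.0, 2**127, nan]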
