@@ -175,7 +175,10 @@ def to_mx(
 
     # For now, calculate the scale in floating point.
     # TODO(future) audit if there is a need to bit shift exponents instead.
-    scale_fp = torch.exp2(scale_e8m0_unbiased).to(torch.float32)
+    scale_fp = torch.pow(
+        torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
+        scale_e8m0_unbiased,
+    )
 
     # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
     # float32 denormal range. For now, manually adjust the fp scale. This is
@@ -230,10 +233,14 @@ def to_mx(
 
 
 def get_fp_scale(scale_e8m0):
+    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
+    # TODO(later): it would be nice if there was a way to do the 2^x operation
+    # in PyTorch without creating a tensor of twos
+    two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
+    # pow(two, s_offset) can be out of range of floating point formats.
     # TODO(later): handle this for float16 if we decide to support float16
     # scales.
-    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
-    s_fp = torch.exp2(s_offset)
+    s_fp = torch.pow(two, s_offset)
 
     # If a block exponent was 255, set values of that block to NaN
     s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
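A self-contained sketch (not the library code) of the rewritten decode path, assuming the usual E8M0 constants of exponent bias 127 and NaN encoding 255 from the MX spec:

```python
import torch

E8M0_EXPONENT_BIAS = 127
E8M0_EXPONENT_NAN_VAL = 255

def get_fp_scale(scale_e8m0: torch.Tensor) -> torch.Tensor:
    # Remove the exponent bias; int16 avoids uint8 wraparound.
    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
    # Elementwise 2**s_offset via pow against a tensor of twos.
    two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
    s_fp = torch.pow(two, s_offset)
    # Blocks whose exponent byte is 255 encode NaN.
    return torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))

# Example: exponent bytes 127, 130 and 255 decode to 1.0, 8.0 and NaN.
print(get_fp_scale(torch.tensor([127, 130, 255], dtype=torch.uint8)))
```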