@@ -175,10 +175,7 @@ def to_mx(
 
     # For now, calculate the scale in floating point.
     # TODO(future) audit if there is a need to bit shift exponents instead.
-    scale_fp = torch.pow(
-        torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
-        scale_e8m0_unbiased,
-    )
+    scale_fp = torch.exp2(scale_e8m0_unbiased).to(torch.float32)
 
     # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
     # float32 denormal range. For now, manually adjust the fp scale. This is
@@ -233,14 +230,10 @@ def to_mx(
 
 
 def get_fp_scale(scale_e8m0):
-    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
-    # TODO(later): it would be nice if there was a way to do the 2^x operation
-    # in PyTorch without creating a tensor of twos
-    two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
-    # pow(two, s_offset) can be out of range of floating point formats.
     # TODO(later): handle this for float16 if we decide to support float16
     # scales.
-    s_fp = torch.pow(two, s_offset)
+    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
+    s_fp = torch.exp2(s_offset)
 
     # If a block exponent was 255, set values of that block to NaN
     s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
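A quick sanity check for the refactor (not part of the commit): the sketch below compares the old pow-based path against the new exp2-based path for get_fp_scale. It assumes E8M0_EXPONENT_BIAS = 127, matching the constant used in the surrounding module; the helper names here are hypothetical.

    import torch

    E8M0_EXPONENT_BIAS = 127  # assumed value of the module-level constant

    def get_fp_scale_pow(scale_e8m0):
        # old path: materialize a tensor of twos and call pow
        s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
        two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
        return torch.pow(two, s_offset)

    def get_fp_scale_exp2(scale_e8m0):
        # new path: exp2 computes 2**x directly, no tensor of twos needed
        s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
        return torch.exp2(s_offset)

    # biased exponents 1..254 map to the normal float32 range 2**-126 .. 2**127
    scale_e8m0 = torch.randint(1, 255, (16,), dtype=torch.uint8)
    torch.testing.assert_close(
        get_fp_scale_pow(scale_e8m0), get_fp_scale_exp2(scale_e8m0)
    )

Since both paths produce exact powers of two within float32 range, the results should match bit-for-bit; the exp2 form simply avoids allocating the intermediate tensor of twos.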