@@ -127,10 +127,7 @@ def to_mx(
 
     # For now, calculate the scale in floating point.
     # TODO(future) audit if there is a need to bit shift exponents instead.
-    scale_fp = torch.pow(
-        torch.full(max_abs.size(), 2.0, device=scale_e8m0_biased.device),
-        scale_e8m0_unbiased,
-    )
+    scale_fp = torch.exp2(scale_e8m0_unbiased).to(torch.float32)
 
     # Today, 2**-127 returns 0 in compile+inductor+triton because it is in the
     # float32 denormal range. For now, manually adjust the fp scale. This is
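
For reference, a minimal standalone sketch (not part of the patch) showing that torch.exp2 matches the old approach of raising a tensor full of 2.0 to the unbiased exponent; the example values are hypothetical and the denormal note mirrors the comment above:

import torch

# Illustration only: torch.exp2 computes 2**x elementwise, matching the old
# pow(tensor_of_twos, x) approach without materializing the extra tensor.
scale_e8m0_unbiased = torch.tensor([-127.0, -3.0, 0.0, 5.0])
old = torch.pow(
    torch.full(scale_e8m0_unbiased.size(), 2.0),
    scale_e8m0_unbiased,
)
new = torch.exp2(scale_e8m0_unbiased).to(torch.float32)
assert torch.allclose(old, new)
# In eager mode 2**-127 is a float32 denormal (~5.9e-39); under
# compile+inductor+triton it flushes to zero, hence the manual adjustment
# of the fp scale in the surrounding code.
print(new)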
@@ -177,13 +174,10 @@ def to_mx(
 
 def get_fp_scale(scale_e8m0):
     s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
-    # TODO(later): it would be nice if there was a way to do the 2^x operation
-    # in PyTorch without creating a tensor of twos
-    two = torch.full(s_offset.size(), 2.0, device=scale_e8m0.device)
-    # pow(two, s_offset) can be out of range of floating point formats.
+
     # TODO(later): handle this for float16 if we decide to support float16
     # scales.
-    s_fp = torch.pow(two, s_offset)
+    s_fp = torch.exp2(s_offset)
 
     # If a block exponent was 255, set values of that block to NaN
     s_fp = torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))
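
Below is a self-contained sketch of the updated get_fp_scale path. The e8m0 constants (bias 127, biased exponent 0xFF reserved for NaN) are assumed from the surrounding module; the function mirrors the diff for illustration and is not the exact module code:

import torch

E8M0_EXPONENT_BIAS = 127      # assumed e8m0 bias, as in the module constants
E8M0_EXPONENT_NAN_VAL = 255   # biased exponent 0xFF encodes NaN

def get_fp_scale(scale_e8m0):
    # Remove the bias to recover the unbiased power-of-two exponent.
    s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS
    # exp2 replaces pow(tensor_of_twos, s_offset); the integer input is
    # promoted to the default float dtype. float16 scales remain a TODO.
    s_fp = torch.exp2(s_offset)
    # If a block exponent was 255, set values of that block to NaN.
    return torch.where(scale_e8m0 != E8M0_EXPONENT_NAN_VAL, s_fp, float("nan"))

scales = torch.tensor([127, 130, 255], dtype=torch.uint8)
print(get_fp_scale(scales))  # tensor([1., 8., nan])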