@@ -160,9 +160,11 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
160160 data = tl .load (input_ptr + block_offs , mask = block_mask , other = 0.0 ).to (
161161 input_dtype
162162 )
163- # we need to cast back to input dtype since triton promotes bf16 to fp32:
163+ # we need to cast back to input dtype since triton promotes bf16 to fp32:
164164 # https://github.com/triton-lang/triton/blob/981e987eed9053b952f81153bc0779c99d8c642e/python/triton/language/standard.py#L173
165- amax_buffer = tl .maximum (amax_buffer , tl .max (tl .abs (data ), axis = 1 )).to (input_dtype )
165+ amax_buffer = tl .maximum (amax_buffer , tl .max (tl .abs (data ), axis = 1 )).to (
166+ input_dtype
167+ )
166168
167169 # compute rowwise scales for this group. round scales to nearest power of 2.
168170 amax_buffer = amax_buffer .to (tl .float64 )
@@ -319,9 +321,11 @@ def _triton_fp8_col_major_jagged_colwise_scales(
319321 data = tl .load (input_ptr + block_offs , mask = block_mask , other = 0.0 ).to (
320322 input_dtype
321323 )
322- # we need to cast back to input dtype since triton promotes bf16 to fp32:
324+ # we need to cast back to input dtype since triton promotes bf16 to fp32:
323325 # https://github.com/triton-lang/triton/blob/981e987eed9053b952f81153bc0779c99d8c642e/python/triton/language/standard.py#L173
324- amax_buffer = tl .maximum (amax_buffer , tl .max (tl .abs (data ), axis = 0 )).to (input_dtype )
326+ amax_buffer = tl .maximum (amax_buffer , tl .max (tl .abs (data ), axis = 0 )).to (
327+ input_dtype
328+ )
325329
326330 # compute colwise scales for this group.
327331 amax_buffer = amax_buffer .to (tl .float64 )
0 commit comments