Skip to content

Commit 22e34bd

Browse files
lint
1 parent 1955a40 commit 22e34bd

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

torchao/prototype/scaled_grouped_mm/kernels/jagged_float8_scales.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,11 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
160160
data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
161161
input_dtype
162162
)
163-
# we need to cast back to input dtype since triton promotes bf16 to fp32:
163+
# we need to cast back to input dtype since triton promotes bf16 to fp32:
164164
# https://github.com/triton-lang/triton/blob/981e987eed9053b952f81153bc0779c99d8c642e/python/triton/language/standard.py#L173
165-
amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=1)).to(input_dtype)
165+
amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=1)).to(
166+
input_dtype
167+
)
166168

167169
# compute rowwise scales for this group. round scales to nearest power of 2.
168170
amax_buffer = amax_buffer.to(tl.float64)
@@ -319,9 +321,11 @@ def _triton_fp8_col_major_jagged_colwise_scales(
319321
data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
320322
input_dtype
321323
)
322-
# we need to cast back to input dtype since triton promotes bf16 to fp32:
324+
# we need to cast back to input dtype since triton promotes bf16 to fp32:
323325
# https://github.com/triton-lang/triton/blob/981e987eed9053b952f81153bc0779c99d8c642e/python/triton/language/standard.py#L173
324-
amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=0)).to(input_dtype)
326+
amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=0)).to(
327+
input_dtype
328+
)
325329

326330
# compute rowwise scales for this group.
327331
amax_buffer = amax_buffer.to(tl.float64)

torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from typing import Optional, Tuple
7+
from typing import Optional
88

99
import torch
1010

0 commit comments

Comments
 (0)