Skip to content

Commit bb5ddb3

Browse files
reorganize
1 parent 22e34bd commit bb5ddb3

File tree

7 files changed

+150
-151
lines changed

7 files changed

+150
-151
lines changed

torchao/prototype/scaled_grouped_mm/kernels/benchmark.py renamed to torchao/prototype/scaled_grouped_mm/benchmarks/benchmark_kernels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
triton_fp8_col_major_jagged_colwise_scales,
1919
triton_fp8_row_major_jagged_rowwise_scales,
2020
)
21-
from torchao.prototype.scaled_grouped_mm.test.utils import (
21+
from torchao.prototype.scaled_grouped_mm.utils import (
2222
_to_2d_jagged_float8_tensor_colwise,
2323
_to_2d_jagged_float8_tensor_rowwise,
2424
)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
2+
triton_fp8_col_major_jagged_colwise_scales as triton_fp8_col_major_jagged_colwise_scales,
3+
)
4+
from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
5+
triton_fp8_row_major_jagged_rowwise_scales as triton_fp8_row_major_jagged_rowwise_scales,
6+
)

torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from torchao.float8.config import ScalingGranularity
1212
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
13-
from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
13+
from torchao.prototype.scaled_grouped_mm.kernels import (
1414
triton_fp8_col_major_jagged_colwise_scales,
1515
triton_fp8_row_major_jagged_rowwise_scales,
1616
)
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
triton_fp8_col_major_jagged_colwise_scales,
1212
triton_fp8_row_major_jagged_rowwise_scales,
1313
)
14-
from torchao.prototype.scaled_grouped_mm.test.utils import (
14+
from torchao.prototype.scaled_grouped_mm.utils import (
15+
_is_column_major,
1516
_to_2d_jagged_float8_tensor_colwise,
1617
_to_2d_jagged_float8_tensor_rowwise,
1718
)
18-
from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
1919

2020

2121
@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])

torchao/prototype/scaled_grouped_mm/test/utils.py

Lines changed: 0 additions & 142 deletions
This file was deleted.

torchao/prototype/scaled_grouped_mm/utils.py

Lines changed: 140 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,146 @@
1-
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# All rights reserved.
3-
#
4-
# This source code is licensed under the BSD 3-Clause license found in the
5-
# LICENSE file in the root directory of this source tree.
1+
from typing import Tuple
62

73
import torch
84

5+
from torchao.float8.config import ScalingGranularity
6+
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
7+
8+
9+
def _to_2d_jagged_float8_tensor_colwise(
    A_col_major: torch.Tensor,
    offs: torch.Tensor,
    target_dtype: torch.dtype = torch.float8_e4m3fn,
    round_scales_to_power_of_2: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert a 2D tensor to a jagged float8 tensor with per-group columnwise scales.

    Groups are row-ranges of ``A_col_major`` delimited by ``offs``. For the right
    operand of a normal scaled GEMM, rowwise scales are taken along logical
    columns (a (K, N) tensor has scales of shape (1, N)). For a 2D right operand
    of a grouped GEMM, each logical column spans multiple distinct
    groups/subtensors, so one set of column scales cannot be applied to the
    whole tensor; instead scales are computed per group, giving a flat scale
    buffer of N * num_groups entries for a (K, N) input.

    Args:
        A_col_major: 2D input tensor; rows are partitioned into groups by ``offs``.
        offs: 1D tensor of cumulative row end-offsets, one per group.
        target_dtype: float8 dtype to convert each group to.
        round_scales_to_power_of_2: forwarded to ``tensor_to_scale``.

    Returns:
        Tuple of (float8 tensor with the same shape as the input, flat float32
        scale tensor of length ``A_col_major.size(1) * offs.numel()``).
    """
    assert A_col_major.ndim == 2, "A must be 2D"

    group_count = offs.numel()
    out_fp8 = torch.empty_like(A_col_major, dtype=target_dtype)
    out_scales = torch.empty(
        out_fp8.size(1) * group_count,
        dtype=torch.float32,
        device=out_fp8.device,
    )

    row_start = 0
    scale_start = 0
    for row_end in offs.tolist():
        # Next group: a contiguous band of rows, all columns.
        group = A_col_major[row_start:row_end, :]

        # Per-group scales along dim 0, i.e. one scale per logical column of
        # this group (the right-operand convention).
        group_scales = tensor_to_scale(
            group,
            target_dtype,
            scaling_granularity=ScalingGranularity.AXISWISE,
            axiswise_dim=0,
            round_scales_to_power_of_2=round_scales_to_power_of_2,
        )

        # Scale in float32, then saturate-cast to the float8 target dtype.
        out_fp8[row_start:row_end, :] = to_fp8_saturated(
            group.to(torch.float32) * group_scales, target_dtype
        )

        # Append this group's scales into the flat scale buffer.
        n_scales = group_scales.numel()
        out_scales[scale_start : scale_start + n_scales] = group_scales.squeeze()

        row_start = row_end
        scale_start += n_scales

    return out_fp8, out_scales
76+
77+
78+
def _to_2d_jagged_float8_tensor_rowwise(
    x: torch.Tensor,
    offs: torch.Tensor,
    target_dtype: torch.dtype,
    round_scales_to_power_of_2: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Convert a 2D tensor to a jagged float8 tensor with per-group rowwise scales.

    Groups are column-ranges of ``x`` delimited by ``offs``. For a 2D *left*
    operand of a normal scaled GEMM, rowwise scales are taken along logical
    rows (an (M, K) tensor has scales of shape (M, 1)). For a 2D left operand
    of a grouped GEMM, each logical row spans multiple distinct
    groups/subtensors, so one set of row scales cannot be applied to the whole
    tensor; instead scales are computed per group, giving a flat scale buffer
    of M * num_groups entries for an (M, K) input.

    Args:
        x: 2D input tensor; columns are partitioned into groups by ``offs``.
        offs: 1D tensor of cumulative column end-offsets, one per group.
        target_dtype: float8 dtype to convert each group to.
        round_scales_to_power_of_2: forwarded to ``tensor_to_scale``.

    Returns:
        Tuple of (float8 tensor with the same shape as the input, flat float32
        scale tensor of length ``x.size(0) * offs.numel()``).
    """
    assert x.ndim == 2, "input tensor must be 2D"

    group_count = offs.numel()
    out_fp8 = torch.empty_like(x, dtype=target_dtype)
    out_scales = torch.empty(
        out_fp8.size(0) * group_count, dtype=torch.float32, device=out_fp8.device
    )

    col_start = 0
    scale_start = 0
    for col_end in offs.tolist():
        # Next group: all rows, a contiguous band of columns.
        group = x[:, col_start:col_end]

        # Per-group scales along the last dim, i.e. one scale per logical row
        # of this group (the left-operand convention).
        group_scales = tensor_to_scale(
            group,
            target_dtype,
            scaling_granularity=ScalingGranularity.AXISWISE,
            axiswise_dim=-1,
            round_scales_to_power_of_2=round_scales_to_power_of_2,
        )

        # Scale in float32, then saturate-cast to the float8 target dtype.
        out_fp8[:, col_start:col_end] = to_fp8_saturated(
            group.to(torch.float32) * group_scales, target_dtype
        )

        # Append this group's scales into the flat scale buffer.
        n_scales = group_scales.numel()
        out_scales[scale_start : scale_start + n_scales] = group_scales.squeeze()

        col_start = col_end
        scale_start += n_scales

    return out_fp8, out_scales
143+
9144

10145
def _is_column_major(x: torch.Tensor) -> bool:
11146
"""

0 commit comments

Comments
 (0)