Commit 1955a40

add triton kernels for float8 quantization with jagged rowwise scales
lint
update docstrings
add bench script
add bench script
bench against compile
comment clean up
fix masks
lint
integrate triton kernels into scaled grouped mm
lint
1 parent 9af2a45 commit 1955a40

File tree

8 files changed: +303 additions, -144 deletions
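For orientation: "jagged rowwise scales" means that a 2D tensor whose columns are split into groups by an offsets tensor gets a separate set of rowwise float8 scales per group, rather than one set of scales for the whole tensor. The snippet below is a simplified pure-PyTorch sketch written for this summary; it is not the Triton kernels added by this commit, and the helper name and the clamp epsilon are made up. The actual reference helpers are the _to_2d_jagged_float8_tensor_* functions this commit moves to test/utils.py.

import torch

def jagged_rowwise_float8_sketch(x, offs, float8_dtype=torch.float8_e4m3fn):
    # x: (M, K) high-precision tensor; offs: 1D int tensor of group end columns.
    # For each column group, compute a per-row amax and scale, then cast to float8.
    assert x.ndim == 2
    fp8_max = torch.finfo(float8_dtype).max
    x_fp8 = torch.empty_like(x, dtype=float8_dtype)
    scales = []
    start = 0
    for end in offs.tolist():
        group = x[:, start:end]                                # (M, group_size)
        amax = group.abs().amax(dim=-1, keepdim=True).float()  # rowwise amax within the group
        scale = fp8_max / torch.clamp(amax, min=1e-12)         # one scale per row, per group
        x_fp8[:, start:end] = (group.float() * scale).clamp(-fp8_max, fp8_max).to(float8_dtype)
        scales.append(scale)
        start = end
    # scales: (M * num_groups, 1), one rowwise scale block per group
    return x_fp8, torch.cat(scales, dim=0)

x = torch.randn(8, 12, dtype=torch.bfloat16)
offs = torch.tensor([4, 8, 12], dtype=torch.int32)
x_fp8, scales = jagged_rowwise_float8_sketch(x, offs)
print(x_fp8.shape, scales.shape)  # torch.Size([8, 12]) torch.Size([24, 1])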
Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

import itertools
import time
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm

from torchao.prototype.scaled_grouped_mm import _scaled_grouped_mm

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    high_precision_dtype: torch.dtype
    A_shape: tuple[int]
    B_shape: tuple[int]


@dataclass(frozen=True)
class ExperimentResult:
    time_us: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    A_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)]
    B_shapes = [(4, 4096, 4096), (8, 4096, 4096), (16, 4096, 4096)]
    high_precision_dtypes = [torch.bfloat16]
    configs = []
    for A_shape, B_shape, high_precision_dtype in itertools.product(
        A_shapes, B_shapes, high_precision_dtypes
    ):
        configs.append(
            ExperimentConfig(
                A_shape=A_shape,
                B_shape=B_shape,
                high_precision_dtype=high_precision_dtype,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    # define test inputs
    A = torch.randn(
        *config.A_shape,
        dtype=config.high_precision_dtype,
        device=device,
        requires_grad=True,
    )
    B_t = torch.randn(
        *config.B_shape,
        dtype=config.high_precision_dtype,
        device=device,
        requires_grad=True,
    ).transpose(-2, -1)

    # - configure input to be row-major with groups divided along the column dimension,
    #   representing the left operand of grad_weight = grad_output_t @ input
    #   that occurs in the backward pass of the differentiable scaled grouped mm.
    # - the transposed tensor in col-major format with groups along the row dimension,
    #   which represents the right operand.
    n_groups = config.B_shape[0]
    group_size = A.shape[0] // n_groups
    offs = torch.arange(
        group_size,
        group_size * n_groups + 1,
        group_size,
        device=device,
        dtype=torch.int32,
    )

    def warmup(func, *args, **kwargs):
        for _ in range(10):
            func(*args, **kwargs)

    def forward_backward(A, B_t, offs):
        out = _scaled_grouped_mm(A, B_t, offs=offs, out_dtype=torch.bfloat16)
        out.sum().backward()

    # bench triton
    warmup(forward_backward, A, B_t, offs)
    start_time_ns = time.perf_counter_ns()
    forward_backward(A, B_t, offs)
    time_ns = time.perf_counter_ns() - start_time_ns
    time_us = time_ns / 1e3

    return ExperimentResult(time_us=time_us)


def print_results(experiments: List[Experiment]):
    headers = [
        "A_shape",
        "B_shape",
        "high_precision_dtype",
        "time_us",
    ]
    rows = []
    for experiment in experiments:
        A_shape = f"({experiment.config.A_shape[0]}, {experiment.config.A_shape[1]})"
        B_shape = f"({experiment.config.B_shape[0]}, {experiment.config.B_shape[1]}, {experiment.config.B_shape[2]})"
        rows.append(
            [
                A_shape,
                B_shape,
                experiment.config.high_precision_dtype,
                experiment.result.time_us,
            ]
        )
    print(tabulate(rows, headers=headers))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
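As a quick illustration of the group layout the new script constructs (this snippet is mine, not part of the commit): for the smallest config, A has shape (2**8, 4096) and B has 4 groups, so each group owns 64 consecutive rows of A and offs stores each group's end row.

import torch

n_groups, total_rows = 4, 2**8          # B_shape[0] and A_shape[0] from the smallest config
group_size = total_rows // n_groups     # 64 rows per group
offs = torch.arange(group_size, group_size * n_groups + 1, group_size, dtype=torch.int32)
print(offs.tolist())                    # [64, 128, 192, 256]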

torchao/prototype/scaled_grouped_mm/kernels/benchmark.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
+from torchao.prototype.scaled_grouped_mm.test.utils import (
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,
 )

torchao/prototype/scaled_grouped_mm/kernels/jagged_float8_scales.py

Lines changed: 6 additions & 2 deletions
@@ -160,7 +160,9 @@ def _triton_fp8_row_major_jagged_rowwise_scales(
         data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
             input_dtype
         )
-        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=1))
+        # we need to cast back to input dtype since triton promotes bf16 to fp32:
+        # https://github.com/triton-lang/triton/blob/981e987eed9053b952f81153bc0779c99d8c642e/python/triton/language/standard.py#L173
+        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=1)).to(input_dtype)
 
     # compute rowwise scales for this group. round scales to nearest power of 2.
     amax_buffer = amax_buffer.to(tl.float64)
@@ -317,7 +319,9 @@ def _triton_fp8_col_major_jagged_colwise_scales(
         data = tl.load(input_ptr + block_offs, mask=block_mask, other=0.0).to(
             input_dtype
         )
-        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=0))
+        # we need to cast back to input dtype since triton promotes bf16 to fp32:
+        # https://github.com/triton-lang/triton/blob/981e987eed9053b952f81153bc0779c99d8c642e/python/triton/language/standard.py#L173
+        amax_buffer = tl.maximum(amax_buffer, tl.max(tl.abs(data), axis=0)).to(input_dtype)
 
     # compute rowwise scales for this group.
     amax_buffer = amax_buffer.to(tl.float64)
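The new comment points at Triton's reduction helpers promoting bfloat16 to float32; casting the result back keeps the loop-carried amax accumulator at a stable dtype across iterations (my reading of the change). Below is a small standalone Triton sketch, written for this summary rather than taken from the repo, that shows the same promote-then-cast-back pattern.

import torch
import triton
import triton.language as tl

@triton.jit
def amax_promotion_demo(x_ptr, out_ptr, N: tl.constexpr):
    offs = tl.arange(0, N)
    x = tl.load(x_ptr + offs)                 # bfloat16 block
    amax = tl.max(tl.abs(x), axis=0)          # the reduction result is promoted to float32
    tl.store(out_ptr, amax.to(tl.bfloat16))   # cast back to the input dtype, as the kernels now do

x = torch.randn(128, dtype=torch.bfloat16, device="cuda")
out = torch.empty(1, dtype=torch.bfloat16, device="cuda")
amax_promotion_demo[(1,)](x, out, N=128)
print(out.item(), x.abs().max().item())  # both report the same amax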

torchao/prototype/scaled_grouped_mm/kernels/test_jagged_float8_scales.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
     triton_fp8_col_major_jagged_colwise_scales,
     triton_fp8_row_major_jagged_rowwise_scales,
 )
-from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
+from torchao.prototype.scaled_grouped_mm.test.utils import (
     _to_2d_jagged_float8_tensor_colwise,
     _to_2d_jagged_float8_tensor_rowwise,
 )

torchao/prototype/scaled_grouped_mm/scaled_grouped_mm.py

Lines changed: 9 additions & 140 deletions
@@ -10,6 +10,10 @@
 
 from torchao.float8.config import ScalingGranularity
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
+from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
+    triton_fp8_col_major_jagged_colwise_scales,
+    triton_fp8_row_major_jagged_rowwise_scales,
+)
 from torchao.prototype.scaled_grouped_mm.utils import _is_column_major
 
 
@@ -189,17 +193,18 @@ def backward(ctx, grad_output: torch.Tensor):
         # grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups."
         # Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups.
         grad_output_t_fp8_row_major, grad_output_t_scales = (
-            _to_2d_jagged_float8_tensor_rowwise(
+            triton_fp8_row_major_jagged_rowwise_scales(
                 grad_output_t_row_major,
                 offs,
-                target_dtype=torch.float8_e4m3fn,
+                output_dtype=torch.float8_e4m3fn,
                 round_scales_to_power_of_2=True,
             )
         )
-        A_fp8_col_major, A_scales = _to_2d_jagged_float8_tensor_colwise(
+
+        A_fp8_col_major, A_scales = triton_fp8_col_major_jagged_colwise_scales(
             A_col_major,
             offs,
-            target_dtype=torch.float8_e4m3fn,
+            output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
 
@@ -216,139 +221,3 @@ def backward(ctx, grad_output: torch.Tensor):
             use_fast_accum=True,
         )
         return grad_A, grad_B.transpose(-2, -1), None, None, None, None
-
-
-def _to_2d_jagged_float8_tensor_colwise(
-    A_col_major: torch.Tensor,
-    offs: torch.Tensor,
-    target_dtype: torch.dtype = torch.float8_e4m3fn,
-    round_scales_to_power_of_2: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    This function converts the 2D input tensor A to a jagged float8 tensor,
-    with scales computed along *logical columns* for each group individually,
-    where groups are determined based on the offsets.
-
-    For the right operand of a normal scaled GEMM, the rowwise scales are computed over logical columns.
-    (i.e., a tensor of (K,N) will have scales of shape (1,N).
-
-    However, for a 2D right operand of a grouped GEMM, these logical columns go through multiple distinct
-    groups/subtensors, for which we want to compute scales individually. So we cannot take one set of scales
-    along the logical columns and apply it to the entire tensor.
-
-    Instead, we need to compute scales for each subtensor individually. For a tensor of shape (K,N) this results
-    in scales of shape (1,N * num_groups).
-
-    Args:
-        A (torch.Tensor): The input tensor to be converted to a jagged float8 tensor.
-
-    Returns:
-        A tuple containing the jagged float8 tensor and the scales used for the conversion.
-    """
-    assert A_col_major.ndim == 2, "A must be 2D"
-
-    num_groups = offs.numel()
-    A_fp8_col_major = torch.empty_like(A_col_major, dtype=target_dtype)
-    A_scales = torch.empty(
-        A_fp8_col_major.size(1) * num_groups,
-        dtype=torch.float32,
-        device=A_fp8_col_major.device,
-    )
-
-    start_idx = 0
-    next_scale_idx = 0
-    for end_idx in offs.tolist():
-        # Get the subtensor of A for this group, fetching the next group of rows, with all columns for each.
-        subtensor = A_col_major[start_idx:end_idx, :]  # (local_group_size, K)
-
-        # Compute local rowwise scales for this subtensor, which are along logical columns for the right operand.
-        subtensor_scales = tensor_to_scale(
-            subtensor,
-            target_dtype,
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=0,
-            round_scales_to_power_of_2=round_scales_to_power_of_2,
-        )
-
-        # Apply scales to subtensor and convert to float8.
-        tensor_scaled = subtensor.to(torch.float32) * subtensor_scales
-        float8_subtensor = to_fp8_saturated(tensor_scaled, target_dtype)
-
-        # Store this portion of the resulting float8 tensor and scales.
-        A_fp8_col_major[start_idx:end_idx, :] = float8_subtensor
-        A_scales[next_scale_idx : next_scale_idx + subtensor_scales.numel()] = (
-            subtensor_scales.squeeze()
-        )
-
-        # Update start index for next group.
-        start_idx = end_idx
-        next_scale_idx += subtensor_scales.numel()
-
-    return A_fp8_col_major, A_scales
-
-
-def _to_2d_jagged_float8_tensor_rowwise(
-    x: torch.Tensor,
-    offs: torch.Tensor,
-    target_dtype: torch.dtype,
-    round_scales_to_power_of_2: bool = False,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    This function converts the 2D input tensor to a jagged float8 tensor,
-    with scales computed along *logical rows* for each group individually,
-    where groups are determined based on the offsets.
-
-    For a 2D *left* operand of a normal scaled GEMM, the rowwise scales are computed over logical rows.
-    (i.e., a tensor of (M,K) will have scales of shape (M,1).
-
-    However, for a 2D left operand of a grouped GEMM, these logical rows go through multiple distinct
-    groups/subtensors, for which we want to compute scales individually. So we cannot take one set of scales
-    along the logical rows and apply it to the entire tensor.
-
-    Instead, we need to compute scales for each subtensor individually. For a tensor of shape (M,K) this results
-    in scales of shape (M * num_groups, 1).
-
-    Args:
-        A (torch.Tensor): The input tensor to be converted to a jagged float8 tensor.
-
-    Returns:
-        A tuple containing the jagged float8 tensor and the scales used for the conversion.
-    """
-    assert x.ndim == 2, "input tensor must be 2D"
-
-    num_groups = offs.numel()
-    x_fp8 = torch.empty_like(x, dtype=target_dtype)
-    x_scales = torch.empty(
-        x_fp8.size(0) * num_groups, dtype=torch.float32, device=x_fp8.device
-    )
-
-    start_idx = 0
-    next_scale_idx = 0
-    for end_idx in offs.tolist():
-        # Get the subtensor of A for this group, fetching all rows with the next group of rows.
-        subtensor = x[:, start_idx:end_idx]  # (M, local_group_size)
-
-        # Compute local rowwise scales for this subtensor, which are along logical rows for the left operand.
-        subtensor_scales = tensor_to_scale(
-            subtensor,
-            target_dtype,
-            scaling_granularity=ScalingGranularity.AXISWISE,
-            axiswise_dim=-1,
-            round_scales_to_power_of_2=round_scales_to_power_of_2,
-        )
-
-        # Apply scales to subtensor and convert to float8.
-        tensor_scaled = subtensor.to(torch.float32) * subtensor_scales
-        float8_subtensor = to_fp8_saturated(tensor_scaled, target_dtype)
-
-        # Store this portion of the resulting float8 tensor and scales.
-        x_fp8[:, start_idx:end_idx] = float8_subtensor
-        x_scales[next_scale_idx : next_scale_idx + subtensor_scales.numel()] = (
-            subtensor_scales.squeeze()
-        )
-
-        # Update start index for next group.
-        start_idx = end_idx
-        next_scale_idx += subtensor_scales.numel()
-
-    return x_fp8, x_scales
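Since the backward pass now calls the Triton kernels with the same positional inputs but an output_dtype keyword instead of target_dtype, one way to sanity-check the swap is to compare a kernel against the reference helper this commit moves to test/utils.py. The snippet below is a hedged sketch along those lines, not the repo's actual test in test_jagged_float8_scales.py; the shapes and the use of torch.testing.assert_close with default tolerances are my assumptions, while the function names and keyword arguments come from the diffs above.

import torch
from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
    triton_fp8_row_major_jagged_rowwise_scales,
)
from torchao.prototype.scaled_grouped_mm.test.utils import (
    _to_2d_jagged_float8_tensor_rowwise,
)

# Row-major input whose columns are split into groups by offs, as in the backward pass.
x = torch.randn(1024, 256, dtype=torch.bfloat16, device="cuda")
offs = torch.arange(64, 256 + 1, 64, device="cuda", dtype=torch.int32)

ref_fp8, ref_scales = _to_2d_jagged_float8_tensor_rowwise(
    x, offs, target_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=True
)
tri_fp8, tri_scales = triton_fp8_row_major_jagged_rowwise_scales(
    x, offs, output_dtype=torch.float8_e4m3fn, round_scales_to_power_of_2=True
)

# Compare float8 data (upcast for comparison) and the per-group rowwise scales.
torch.testing.assert_close(tri_fp8.float(), ref_fp8.float())
torch.testing.assert_close(tri_scales, ref_scales)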

torchao/prototype/scaled_grouped_mm/test/__init__.py

Whitespace-only changes.
