Commit ba6f428

[CPU] add Float8OpaqueTensor for dynamic float8 act float8 weight (#3075)
* [CPU] add Float8OpaqueTensor for dynamic float8 act float8 weight
* Update _normalize_granularity
* Update torchao/quantization/quant_api.py
* Fix CI
* remove unnecessary changes
* Refine code
* Refine code
* Refine code
1 parent 0292cb8 commit ba6f428
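
For context, the new "opaque" packing format added here is exercised end to end by the test file in this commit. The following is a minimal usage sketch on CPU based on that test; the module and shapes are illustrative, not taken from the commit itself.

import torch

from torchao import quantize_
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
)

# Illustrative module; any model with nn.Linear layers is quantized the same way.
model = torch.nn.Sequential(torch.nn.Linear(256, 256, dtype=torch.bfloat16)).eval()

# float8_packing_format="opaque" selects the new Float8OpaqueTensor weight
# representation, which dispatches linear to the CPU kernel
# torch.ops.torchao.float8_linear_cpu instead of the plain float8 path.
quantize_(
    model,
    Float8DynamicActivationFloat8WeightConfig(
        activation_dtype=torch.float8_e4m3fn,
        granularity=PerRow(),
        float8_packing_format="opaque",
    ),
)

with torch.no_grad():
    out = model(torch.randn(16, 256, dtype=torch.bfloat16))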

6 files changed: +530 −40 lines changed

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import tempfile
+import unittest
+
+import torch
+from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import (
+    TestCase,
+    run_tests,
+)
+
+from torchao import quantize_
+from torchao.quantization import (
+    Float8DynamicActivationFloat8WeightConfig,
+    PerGroup,
+    PerRow,
+    PerTensor,
+)
+from torchao.quantization.utils import compute_error
+from torchao.testing.model_architectures import ToyTwoLinearModel
+from torchao.utils import (
+    torch_version_at_least,
+)
+
+
+def get_config(granularity):
+    return Float8DynamicActivationFloat8WeightConfig(
+        activation_dtype=torch.float8_e4m3fn,
+        granularity=granularity,
+        float8_packing_format="opaque",
+    )
+
+
+@common_utils.instantiate_parametrized_tests
+class TestFloat8OpaqueTensor(TestCase):
+    """Test cases for Float8OpaqueTensor on CPU"""
+
+    def setUp(self):
+        torch.set_grad_enabled(False)
+
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
+        reason="cpp kernels not built",
+    )
+    @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
+    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
+    @common_utils.parametrize("x_dim", [2, 3])
+    @common_utils.parametrize("bias", [True, False])
+    @common_utils.parametrize("bs", [1, 160])
+    @common_utils.parametrize(
+        "x_granularity",
+        [PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
+    )
+    @common_utils.parametrize(
+        "w_granularity",
+        [PerTensor(), PerRow(), PerGroup(32), PerGroup(64), PerGroup(128)],
+    )
+    def test_dynamic_float8_linear(
+        self, dtype, x_dim, bias, bs, x_granularity, w_granularity
+    ):
+        if isinstance(x_granularity, PerGroup):
+            if not isinstance(w_granularity, PerGroup):
+                return
+            if w_granularity.group_size != x_granularity.group_size:
+                return
+        device = "cpu"
+        m = ToyTwoLinearModel(256, 256, 256, dtype, device, bias).eval()
+        example_inputs = m.example_inputs(batch_size=bs)
+        if x_dim == 3:
+            example_inputs = (example_inputs[0].unsqueeze(0),)
+        y = m(*example_inputs)
+
+        quantize_(
+            m,
+            get_config([x_granularity, w_granularity]),
+        )
+        y1 = m(*example_inputs)
+        assert compute_error(y, y1) > 20
+        y2, code = torch._inductor.utils.run_and_get_code(
+            torch.compile(m, fullgraph=True, dynamic=True),
+            *example_inputs,
+        )
+        # ensure the expected op is in the code
+        assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
+        assert compute_error(y, y2) > 20
+
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
+        reason="cpp kernels not built",
+    )
+    @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
+    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
+    @common_utils.parametrize("x_dim", [2, 3])
+    @common_utils.parametrize("bias", [True, False])
+    @common_utils.parametrize("bs", [4, 128])
+    def test_dynamic_float8_linear_fallback_path(self, dtype, x_dim, bias, bs):
+        """
+        Test the fallback implementation with a shape that is not supported by the optimized kernel
+        """
+        device = "cpu"
+        m = ToyTwoLinearModel(120, 120, 120, dtype, device, bias).eval()
+        example_inputs = m.example_inputs(batch_size=bs)
+        if x_dim == 3:
+            example_inputs = (example_inputs[0].unsqueeze(0),)
+        y = m(*example_inputs)
+
+        quantize_(
+            m,
+            get_config(PerRow()),
+        )
+        y1 = m(*example_inputs)
+        assert compute_error(y, y1) > 20
+        y2, code = torch._inductor.utils.run_and_get_code(
+            torch.compile(m, fullgraph=True, dynamic=True),
+            *example_inputs,
+        )
+        # ensure the expected op is in the code
+        assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
+        assert compute_error(y, y2) > 20
+
+    @unittest.skipIf(
+        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
+        reason="cpp kernels not built",
+    )
+    @common_utils.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
+    def test_module_path(self, dtype):
+        linear = torch.nn.Linear(128, 256, dtype=dtype)
+        quantize_(linear, get_config(PerRow()))
+        self.assertEqual(
+            str(type(linear.weight)),
+            "<class 'torchao.quantization.Float8OpaqueTensor'>",
+        )
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Float8OpaqueTensor'>",
+            )
+
+
+if __name__ == "__main__":
+    run_tests()

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -92,6 +92,7 @@
     quantize_affine,
 )
 from .quantize_.workflows import (
+    Float8OpaqueTensor,
     Float8Tensor,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
@@ -174,6 +175,7 @@
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
     "Int4OpaqueTensor",
+    "Float8OpaqueTensor",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",

torchao/quantization/quant_api.py

Lines changed: 68 additions & 40 deletions
@@ -73,6 +73,8 @@
     KernelPreference,
 )
 from torchao.quantization.quantize_.workflows import (
+    Float8OpaqueTensor,
+    Float8PackingFormat,
     Float8Tensor,
     Int4ChooseQParamsAlgorithm,
     Int4MarlinSparseTensor,
@@ -1774,14 +1776,23 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     kernel_preference: KernelPreference = KernelPreference.AUTO
     set_inductor_config: bool = True
     version: int = 2
+    float8_packing_format: Float8PackingFormat = Float8PackingFormat.PLAIN
 
     def __post_init__(self):
         torch._C._log_api_usage_once(
             "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
         )
-        activation_granularity, weight_granularity = _normalize_granularity(
-            self.granularity
-        )
+        if (
+            self.version == 2
+            and self.float8_packing_format == Float8PackingFormat.OPAQUE
+        ):
+            activation_granularity, weight_granularity = (
+                Float8OpaqueTensor._normalize_and_check_granularity(self.granularity)
+            )
+        else:
+            activation_granularity, weight_granularity = _normalize_granularity(
+                self.granularity
+            )
         self.granularity = [activation_granularity, weight_granularity]
 
         default_use_fast_accum = True
@@ -1811,44 +1822,48 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     activation_value_lb = config.activation_value_lb
     activation_value_ub = config.activation_value_ub
     kernel_preference = config.kernel_preference
+    float8_packing_format = config.float8_packing_format
 
     # Ensure works on device
-    _check_hardware_support(granularity)
     activation_granularity, weight_granularity = granularity
 
-    # Note: right now we assume it's weights of conv2d and conv3d purely based
-    # on the dimension of weight, currently there is no conflict with linear 2d
-    # and moe weights 3d
-    # if we need to support conv1d, which also has 3d weight, we may have to
-    # pass around the module as well to distinguish between conv1d and 3d moe weight
-    if weight.dim() in [4, 5]:
-        # weights for conv2d or 3d
-        assert isinstance(activation_granularity, PerTensor) and isinstance(
-            weight_granularity, PerTensor
-        ), "4D/5D tensor only supports per tensor activation and weight quantization"
-
-        # conv3d weight dim: (C_out, C_in, K1, K2, K3)
-        # conv2d weight dim: (C_out, C_in, K1, K2)
-        # skip quantization when either C_out or C_in
-        # is not a multiple of 16
-        if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
-            return weight
+    if float8_packing_format == Float8PackingFormat.PLAIN:
+        # Note: right now we assume it's weights of conv2d and conv3d purely based
+        # on the dimension of weight, currently there is no conflict with linear 2d
+        # and moe weights 3d
+        # if we need to support conv1d, which also has 3d weight, we may have to
+        # pass around the module as well to distinguish between conv1d and 3d moe weight
+        if weight.dim() in [4, 5]:
+            # weights for conv2d or 3d
+            assert isinstance(activation_granularity, PerTensor) and isinstance(
+                weight_granularity, PerTensor
+            ), (
+                "4D/5D tensor only supports per tensor activation and weight quantization"
+            )
 
-    elif not _fp8_mm_compat(weight):
-        # TODO(future PR): this should really throw an exception instead of silently
-        # not doing what the user asked
-        return weight
+            # conv3d weight dim: (C_out, C_in, K1, K2, K3)
+            # conv2d weight dim: (C_out, C_in, K1, K2)
+            # skip quantization when either C_out or C_in
+            # is not a multiple of 16
+            if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
+                return weight
 
-    if isinstance(weight_granularity, PerRow):
-        assert weight.dtype == torch.bfloat16, (
-            "PerRow quantization only works for bfloat16 precision input weight"
-        )
+        elif not _fp8_mm_compat(weight):
+            # TODO(future PR): this should really throw an exception instead of silently
+            # not doing what the user asked
+            return weight
+
+        if isinstance(weight_granularity, PerRow):
+            assert weight.dtype == torch.bfloat16, (
+                "PerRow quantization only works for bfloat16 precision input weight"
+            )
 
     if config.version == 1:
         warnings.warn(
             "Config Deprecation: version 1 of Float8DynamicActivationFloat8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2649 for more details"
         )
 
+        _check_hardware_support(granularity)
         block_size = get_block_size(weight.shape[-2:], weight_granularity)
         if weight.dim() == 3:
            block_size = tuple([1] + list(block_size))
@@ -1879,14 +1894,26 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         kernel_preference=kernel_preference,
     )
 
-    quantized_weight = Float8Tensor.from_hp(
-        weight,
-        float8_dtype=weight_dtype,
-        granularity=weight_granularity,
-        mm_config=mm_config,
-        kernel_preference=kernel_preference,
-        act_quant_kwargs=act_quant_kwargs,
-    )
+    if float8_packing_format == Float8PackingFormat.PLAIN:
+        quantized_weight = Float8Tensor.from_hp(
+            weight,
+            float8_dtype=weight_dtype,
+            granularity=weight_granularity,
+            mm_config=mm_config,
+            kernel_preference=kernel_preference,
+            act_quant_kwargs=act_quant_kwargs,
+        )
+    elif float8_packing_format == Float8PackingFormat.OPAQUE:
+        block_size = get_block_size(weight.shape, weight_granularity)
+        quantized_weight = Float8OpaqueTensor.from_hp(
+            weight,
+            block_size=block_size,
+            act_quant_kwargs=act_quant_kwargs,
+        )
+    else:
+        raise ValueError(
+            f"Unsupported float8 packing format: {float8_packing_format}"
+        )
 
     return quantized_weight
 
@@ -1898,9 +1925,10 @@ def _float8_dynamic_activation_float8_weight_transform(
     *,
     parameter_name: str = "weight",
 ):
-    assert is_sm_at_least_89() or is_MI300(), (
-        "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
-    )
+    if config.float8_packing_format == Float8PackingFormat.PLAIN:
+        assert is_sm_at_least_89() or is_MI300(), (
+            "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
+        )
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,7 @@
+from .float8.float8_opaque_tensor import (
+    Float8OpaqueTensor,
+)
+from .float8.float8_packing_format import Float8PackingFormat
 from .float8.float8_tensor import (
     Float8Tensor,
     QuantizeTensorToFloat8Kwargs,
@@ -37,7 +41,9 @@
     "Int4MarlinSparseTensor",
     "Int4PlainInt32Tensor",
     "Int4TilePackedTo4dTensor",
+    "Float8OpaqueTensor",
     "Float8Tensor",
+    "Float8PackingFormat",
     "QuantizeTensorToFloat8Kwargs",
     "Int4OpaqueTensor",
     "Int4ChooseQParamsAlgorithm",
