
Commit aeff75b

Lint test dtypes (#1305)
1 parent 6234116 commit aeff75b

9 files changed: +369 -224 lines changed


ruff.toml

Lines changed: 1 addition & 2 deletions
@@ -10,8 +10,7 @@ include = [
     "torchao/prototype/low_bit_optim/**.py",
     "test/float8/**/*.py",
     "test/quantization/test_observer.py",
-    "test/dtypes/test_affine_quantized_float.py",
-    "test/dtypes/test_nf4.py",
+    "test/dtypes/**/*.py",
     "test/prototype/low_bit_optim/**.py",
     "torchao/utils.py",

test/dtypes/test_affine_quantized.py

Lines changed: 56 additions & 38 deletions
@@ -1,24 +1,24 @@
+import tempfile
+import unittest
+
+import torch
+from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import (
     TestCase,
     run_tests,
 )
+
+from torchao.dtypes import SemiSparseLayout
 from torchao.quantization import (
+    float8_weight_only,
     int4_weight_only,
-    int8_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
-    int8_dynamic_activation_int8_semi_sparse_weight,
-    float8_weight_only,
+    int8_weight_only,
 )
 from torchao.quantization.quant_primitives import MappingType
-from torchao.dtypes import SemiSparseLayout
-from torch.testing._internal import common_utils
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

-import torch
-import unittest
-import tempfile
-
 is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


@@ -33,7 +33,9 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
         base_functions.append(int4_weight_only(group_size=32))

     if do_sparse:
-        base_functions.append(int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()))
+        base_functions.append(
+            int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
+        )

     if is_cuda_8_9:
         base_functions.append(float8_weight_only())
@@ -44,11 +46,11 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool):
 class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_tensor_core_layout_transpose(self):
-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        t = l.weight
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        t = linear.weight
         shape = t.shape
         apply_int4_weight_only_quant = int4_weight_only(group_size=32)
-        ql = apply_int4_weight_only_quant(l)
+        ql = apply_int4_weight_only_quant(linear)
         aqt = ql.weight
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
@@ -64,8 +66,8 @@ def test_tensor_core_layout_transpose(self):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
     def test_weights_only(self, apply_quant):
-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        ql = apply_quant(l)
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        ql = apply_quant(linear)
         with tempfile.NamedTemporaryFile() as f:
             torch.save(ql.state_dict(), f)
             f.seek(0)
@@ -78,33 +80,32 @@ def test_weights_only(self, apply_quant):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
     def test_to_device(self, apply_quant):
-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(l)
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(linear)
         ql.to("cuda")

-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(l)
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(linear)
         ql.to(device="cuda")

-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        ql = apply_quant(l)
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(linear)
         ql.cuda()

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_register_new_dispatch(self):
+        from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
         from torchao.dtypes.affine_quantized_tensor_ops import (
-            register_aqt_quantized_linear_dispatch,
             deregister_aqt_quantized_linear_dispatch,
+            register_aqt_quantized_linear_dispatch,
         )
-        from torchao.dtypes import to_affine_quantized_intx
-        from torchao.dtypes import AffineQuantizedTensor
         from torchao.quantization.quant_primitives import MappingType

         def dispatch_condition(input_tensor, weight_tensor, bias):
             return (
-                isinstance(weight_tensor, AffineQuantizedTensor) and
-                weight_tensor.quant_min == 0 and
-                weight_tensor.quant_max == 2**6-1
+                isinstance(weight_tensor, AffineQuantizedTensor)
+                and weight_tensor.quant_min == 0
+                and weight_tensor.quant_max == 2**6 - 1
             )

         def impl(input_tensor, weight_tensor, bias):
@@ -115,23 +116,35 @@ def impl(input_tensor, weight_tensor, bias):
         register_aqt_quantized_linear_dispatch(dispatch_condition, impl)

         def apply_uint6_weight_only_quant(linear):
-            linear.weight = torch.nn.Parameter(to_affine_quantized_intx(linear.weight, MappingType.ASYMMETRIC, (1, linear.weight.shape[-1]), torch.uint8, 0, 2**6-1), requires_grad=False)
+            linear.weight = torch.nn.Parameter(
+                to_affine_quantized_intx(
+                    linear.weight,
+                    MappingType.ASYMMETRIC,
+                    (1, linear.weight.shape[-1]),
+                    torch.uint8,
+                    0,
+                    2**6 - 1,
+                ),
+                requires_grad=False,
+            )
             return linear

-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        apply_uint6_weight_only_quant(l)
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        apply_uint6_weight_only_quant(linear)

         example_input = torch.randn(1, 128, dtype=torch.bfloat16, device="cuda")
-        with self.assertRaisesRegex(AssertionError, "dispatching to my impl for uint6 weight only quant"):
-            l(example_input)
+        with self.assertRaisesRegex(
+            AssertionError, "dispatching to my impl for uint6 weight only quant"
+        ):
+            linear(example_input)

         deregister_aqt_quantized_linear_dispatch(dispatch_condition)

     @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_print_quantized_module(self, apply_quant):
-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        ql = apply_quant(l)
+        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        ql = apply_quant(linear)
         assert "AffineQuantizedTensor" in str(ql)


@@ -143,20 +156,25 @@ class TestAffineQuantizedBasic(TestCase):
     @common_utils.parametrize("device", COMMON_DEVICES)
     @common_utils.parametrize("dtype", COMMON_DTYPES)
     def test_flatten_unflatten(self, apply_quant, device, dtype):
-        l = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-        ql = apply_quant(l)
+        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        ql = apply_quant(linear)
         lp_tensor = ql.weight
         tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
-        tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
+        tensor_data_dict = {
+            name: getattr(lp_tensor, name) for name in tensor_data_name_dict
+        }
         outer_size = lp_tensor.size()
         outer_stride = lp_tensor.stride()
-        reconstructed = type(lp_tensor).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
+        reconstructed = type(lp_tensor).__tensor_unflatten__(
+            tensor_data_dict, tensor_attributes, outer_size, outer_stride
+        )
         example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
         ref = ql(*example_inputs)
         ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
         reconstruct_res = ql(*example_inputs)
         self.assertEqual(reconstruct_res, ref)

+
 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)
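
Note: the test_register_new_dispatch hunks above are split across several chunks; pulled together, the pattern they exercise looks roughly like the sketch below. The body of impl is an assumption inferred from the assertRaisesRegex message (the real body sits outside the hunks shown); everything else mirrors the test code and assumes a CUDA device with torchao installed.

import torch
from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_intx
from torchao.dtypes.affine_quantized_tensor_ops import (
    deregister_aqt_quantized_linear_dispatch,
    register_aqt_quantized_linear_dispatch,
)
from torchao.quantization.quant_primitives import MappingType


def dispatch_condition(input_tensor, weight_tensor, bias):
    # Route only AQT weights quantized to the uint6 range [0, 2**6 - 1].
    return (
        isinstance(weight_tensor, AffineQuantizedTensor)
        and weight_tensor.quant_min == 0
        and weight_tensor.quant_max == 2**6 - 1
    )


def impl(input_tensor, weight_tensor, bias):
    # Assumed placeholder body: the test only verifies that this path is taken.
    raise AssertionError("dispatching to my impl for uint6 weight only quant")


register_aqt_quantized_linear_dispatch(dispatch_condition, impl)

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
linear.weight = torch.nn.Parameter(
    to_affine_quantized_intx(
        linear.weight,
        MappingType.ASYMMETRIC,
        (1, linear.weight.shape[-1]),
        torch.uint8,
        0,
        2**6 - 1,
    ),
    requires_grad=False,
)

# Any forward call through the quantized linear now hits the registered impl
# (with the placeholder body above it raises the AssertionError the test expects):
# linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))

deregister_aqt_quantized_linear_dispatch(dispatch_condition)
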

test/dtypes/test_affine_quantized_tensor_parallel.py

Lines changed: 38 additions & 32 deletions
@@ -1,28 +1,28 @@
-import torch
 import unittest
-from torch.testing._internal.common_utils import run_tests
+
+import torch
+from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
 from torch.testing._internal import common_utils
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+)
+
 from torchao.quantization import (
+    float8_dynamic_activation_float8_weight,
+    float8_weight_only,
     int4_weight_only,
     int8_weight_only,
-    float8_weight_only,
-    float8_dynamic_activation_float8_weight,
 )
 from torchao.quantization.observer import PerRow, PerTensor
-import torch.distributed as dist
-from torch.distributed._tensor import DTensor, Replicate, Shard, DeviceMesh
-from torch.testing._internal.distributed._tensor.common_dtensor import (
-    DTensorTestBase,
-    with_comms,
-    NUM_DEVICES,
-)
 from torchao.quantization.quant_api import quantize_
-from torchao.dtypes import AffineQuantizedTensor
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

+
 class TestAffineQuantizedTensorParallel(DTensorTestBase):
-    """Basic test case for tensor subclasses
-    """
+    """Basic test case for tensor subclasses"""
+
     QUANT_METHOD_FN = staticmethod(int8_weight_only)
     QUANT_METHOD_KWARGS = {}

@@ -40,9 +40,7 @@ def colwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
         # Construct DTensor from local shard
         dtensor = DTensor.from_local(local_shard, mesh, [Shard(0)])
         # Replace parameter in module
-        m.linear.weight = torch.nn.Parameter(
-            dtensor, requires_grad=False
-        )
+        m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False)
         return m

     @staticmethod
@@ -59,9 +57,7 @@ def rowwise_shard(m: torch.nn.Module, mesh: DeviceMesh) -> torch.nn.Module:
         # Construct DTensor from local shard
         dtensor = DTensor.from_local(local_shard, mesh, [Shard(1)], run_check=True)
         # Replace parameter in module
-        m.linear.weight = torch.nn.Parameter(
-            dtensor, requires_grad=False
-        )
+        m.linear.weight = torch.nn.Parameter(dtensor, requires_grad=False)
         return m

     def quantize(self, m: torch.nn.Module) -> torch.nn.Module:
@@ -79,7 +75,9 @@ def _test_tp(self, dtype):
         class M(torch.nn.Module):
             def __init__(self, in_features, out_features, **kwargs) -> None:
                 super().__init__(**kwargs)
-                self.linear = torch.nn.Linear(in_features, out_features, bias=False, device="cuda")
+                self.linear = torch.nn.Linear(
+                    in_features, out_features, bias=False, device="cuda"
+                )

             def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return self.linear(x)
@@ -91,11 +89,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         proj_up = M(1024, 2048).to(device).to(dtype)
         proj_dn = M(2048, 1024).to(device).to(dtype)
         example_input = 100 * torch.randn(128, 1024, device=device, dtype=dtype)
-        y = proj_dn(proj_up(example_input))
+        proj_dn(proj_up(example_input))
         # Quantize the model
         up_quant = self.quantize(proj_up)
         dn_quant = self.quantize(proj_dn)
-        y_q = dn_quant(up_quant(example_input))
+        dn_quant(up_quant(example_input))

         mesh = self.build_device_mesh()
         mesh.device_type = "cuda"
@@ -105,11 +103,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         dn_dist = self.rowwise_shard(dn_quant, mesh)

         # We need to turn inputs into DTensor form as well -- just a format change
-        input_dtensor = DTensor.from_local(
-            example_input, mesh, [Replicate()]
-        )
+        input_dtensor = DTensor.from_local(example_input, mesh, [Replicate()])

-        y_d = dn_dist(up_dist(input_dtensor))
+        dn_dist(up_dist(input_dtensor))

         if not TORCH_VERSION_AT_LEAST_2_6:
             # Need torch 2.6 to support compiled tensor parallelism
@@ -118,7 +114,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         up_compiled = torch.compile(up_dist)
         y_up = up_compiled(input_dtensor)
         dn_compiled = torch.compile(dn_dist)
-        y_dn = dn_compiled(y_up)
+        dn_compiled(y_up)


 class TestInt8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
@@ -142,11 +138,13 @@ class TestInt4woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel)
     def test_tp(self, dtype):
         return self._test_tp(dtype)

+
 common_utils.instantiate_parametrized_tests(TestInt8woAffineQuantizedTensorParallel)
 common_utils.instantiate_parametrized_tests(TestInt4woAffineQuantizedTensorParallel)

 # Run only on H100
 if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0):
+
     class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
         QUANT_METHOD_FN = staticmethod(float8_weight_only)
         COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
@@ -157,7 +155,9 @@ class TestFloat8woAffineQuantizedTensorParallel(TestAffineQuantizedTensorParalle
         def test_tp(self, dtype):
             return self._test_tp(dtype)

-    class TestFloat8dqTensorAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
+    class TestFloat8dqTensorAffineQuantizedTensorParallel(
+        TestAffineQuantizedTensorParallel
+    ):
         QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
         QUANT_METHOD_KWARGS = {"granularity": PerTensor()}
         COMMON_DTYPES = [torch.bfloat16, torch.float16, torch.float32]
@@ -168,7 +168,9 @@ class TestFloat8dqTensorAffineQuantizedTensorParallel(TestAffineQuantizedTensorP
         def test_tp(self, dtype):
             return self._test_tp(dtype)

-    class TestFloat8dqRowAffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
+    class TestFloat8dqRowAffineQuantizedTensorParallel(
+        TestAffineQuantizedTensorParallel
+    ):
         QUANT_METHOD_FN = staticmethod(float8_dynamic_activation_float8_weight)
         QUANT_METHOD_KWARGS = {"granularity": PerRow()}
         COMMON_DTYPES = [torch.bfloat16]
@@ -179,7 +181,11 @@ class TestFloat8dqRowAffineQuantizedTensorParallel(TestAffineQuantizedTensorPara
         def test_tp(self, dtype):
             return self._test_tp(dtype)

-    common_utils.instantiate_parametrized_tests(TestFloat8dqTensorAffineQuantizedTensorParallel)
-    common_utils.instantiate_parametrized_tests(TestFloat8dqRowAffineQuantizedTensorParallel)
+    common_utils.instantiate_parametrized_tests(
+        TestFloat8dqTensorAffineQuantizedTensorParallel
+    )
+    common_utils.instantiate_parametrized_tests(
+        TestFloat8dqRowAffineQuantizedTensorParallel
+    )
 if __name__ == "__main__":
     run_tests()
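
Note: the classes at the bottom of this file all follow the same recipe: subclass TestAffineQuantizedTensorParallel, override QUANT_METHOD_FN / QUANT_METHOD_KWARGS / COMMON_DTYPES, and register the class with instantiate_parametrized_tests. A rough sketch of adding another variant is below; the class name is hypothetical, the decorators on test_tp are an assumption (they sit outside the hunks shown), and it assumes the base class's quantize() forwards QUANT_METHOD_KWARGS to QUANT_METHOD_FN, as the existing subclasses suggest.

class TestInt4woGroup32AffineQuantizedTensorParallel(TestAffineQuantizedTensorParallel):
    # Quantization factory to exercise and the kwargs assumed to be forwarded to it.
    QUANT_METHOD_FN = staticmethod(int4_weight_only)
    QUANT_METHOD_KWARGS = {"group_size": 32}
    # dtypes the parametrized test_tp will run over.
    COMMON_DTYPES = [torch.bfloat16]

    @common_utils.parametrize("dtype", COMMON_DTYPES)
    @with_comms
    def test_tp(self, dtype):
        return self._test_tp(dtype)


common_utils.instantiate_parametrized_tests(
    TestInt4woGroup32AffineQuantizedTensorParallel
)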
