Skip to content

Commit f9b685d

Browse files
committed
Fix torch.intx support in FakeQuantizeConfig
**Summary:** Fixes the following error when passing `torch.intx` to `FakeQuantizeConfig`. These dtypes were introduced in PyTorch 2.6+: ``` ValueError: Unsupported dtype 'torch.int4', choose from [torch.int8, torch.uint8, <TorchAODType.INT1: 1>, <TorchAODType.INT2: 2>, <TorchAODType.INT3: 3>, <TorchAODType.INT4: 4>, <TorchAODType.INT5: 5>, <TorchAODType.INT6: 6>, <TorchAODType.INT7: 7>, torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7] ``` **Test Plan:** python test/quantization/test_qat.py -k test_fake_quantize_config_torch_intx
1 parent b5b739b commit f9b685d

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

test/quantization/test_qat.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
from torchao.utils import (
6060
TORCH_VERSION_AT_LEAST_2_3,
6161
TORCH_VERSION_AT_LEAST_2_4,
62+
TORCH_VERSION_AT_LEAST_2_6,
6263
)
6364

6465
# TODO: put this in a common test utils file
@@ -1262,6 +1263,26 @@ def test_quantize_api_errors(self):
12621263
lambda m, _: isinstance(m, torch.nn.ReLU),
12631264
)
12641265

1266+
@unittest.skipIf(
1267+
not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is lower than 2.6"
1268+
)
1269+
def test_fake_quantize_config_torch_intx(self):
1270+
"""
1271+
Test that `FakeQuantizeConfig` works with torch.intx.
1272+
"""
1273+
group_size = 16
1274+
config1 = FakeQuantizeConfig(TorchAODType.INT4, group_size=group_size)
1275+
config2 = FakeQuantizeConfig(torch.int4, group_size=group_size)
1276+
linear1 = FakeQuantizedLinear(32, 64, weight_config=config1)
1277+
linear2 = FakeQuantizedLinear(32, 64, weight_config=config2)
1278+
linear2.weight = linear1.weight
1279+
torch.manual_seed(self.SEED)
1280+
x = torch.randn((1, 32)).to(torch.float)
1281+
x2 = copy.deepcopy(x)
1282+
out1 = linear1(*x)
1283+
out2 = linear2(*x2)
1284+
torch.testing.assert_close(out1, out2, atol=0, rtol=0)
1285+
12651286

12661287
if __name__ == "__main__":
12671288
unittest.main()

torchao/quantization/quant_primitives.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from torchao.utils import (
1919
TORCH_VERSION_AT_LEAST_2_3,
2020
TORCH_VERSION_AT_LEAST_2_5,
21+
TORCH_VERSION_AT_LEAST_2_6,
2122
_is_float8_type,
2223
_register_custom_op,
2324
)
@@ -162,6 +163,31 @@ class TorchAODType(Enum):
162163
}
163164
)
164165

166+
# torch.intX available only in PyTorch 2.6+
167+
if TORCH_VERSION_AT_LEAST_2_6:
168+
_SUB_BYTE_INT_BOUNDS.update(
169+
{
170+
torch.int1: (-(2**0), 2**0 - 1),
171+
torch.int2: (-(2**1), 2**1 - 1),
172+
torch.int3: (-(2**2), 2**2 - 1),
173+
torch.int4: (-(2**3), 2**3 - 1),
174+
torch.int5: (-(2**4), 2**4 - 1),
175+
torch.int6: (-(2**5), 2**5 - 1),
176+
torch.int7: (-(2**6), 2**6 - 1),
177+
}
178+
)
179+
_DTYPE_TO_BIT_WIDTH.update(
180+
{
181+
torch.int1: 1,
182+
torch.int2: 2,
183+
torch.int3: 3,
184+
torch.int4: 4,
185+
torch.int5: 5,
186+
torch.int6: 6,
187+
torch.int7: 7,
188+
}
189+
)
190+
165191
_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS)
166192
_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_INT_BOUNDS)
167193
assert _DTYPE_TO_BIT_WIDTH.keys() == _DTYPE_TO_QVALUE_BOUNDS.keys()

0 commit comments

Comments
 (0)