Skip to content

Commit 5dd0132

Browse files
authored
Unskip test_choose_qparams_token_asym in 2.6 (#1004)
* Unskip `test_choose_qparams_token_asym` in 2.6

  Summary: Fixes #970. The test was broken by a recent refactor in pytorch: pytorch/pytorch#136807

  Test Plan: CI

  Reviewers:
  Subscribers:
  Tasks:
  Tags:

* fix
1 parent 0cb91ea commit 5dd0132

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

test/quantization/test_quant_primitives.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,14 +202,16 @@ def test_choose_qparams_group_sym_no_clipping_err(self):
202202
self.assertTrue(torch.equal(scale, scale_ref))
203203
self.assertTrue(torch.equal(zero_point, zp_ref))
204204
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "skipping when torch version is 2.3 or lower")
205-
@unittest.skipIf(TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or higher")
206205
@unittest.skipIf(is_fbcode(), "broken in fbcode")
207206
def test_choose_qparams_token_asym(self):
208207
input = torch.randn(10, 10)
209208
mapping_type = MappingType.ASYMMETRIC
210209
dtype = torch.int8
211210
block_size = (1, 10)
212-
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
211+
if TORCH_VERSION_AT_LEAST_2_6:
212+
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float64, zero_point_dtype=torch.int64)
213+
else:
214+
scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, eps=torch.finfo(torch.float32).eps)
213215

214216
scale_ref, zp_ref = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(input, dtype)
215217
scale_ref = scale_ref.squeeze()

0 commit comments

Comments
 (0)