Skip to content

Commit 49b47a2

Browse files
committed
Fix CI
1 parent 0b0a3a8 commit 49b47a2

File tree

4 files changed

+15
-5
lines changed

4 files changed

+15
-5
lines changed

test/quantization/test_quant_primitives.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from torchao.utils import (
2929
TORCH_VERSION_AFTER_2_3,
3030
TORCH_VERSION_AFTER_2_4,
31+
TORCH_VERSION_AFTER_2_5,
3132
is_fbcode,
3233
)
3334

@@ -98,7 +99,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
9899
.to(torch.int32)
99100
.reshape_as(w)
100101
)
101-
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
102+
if TORCH_VERSION_AFTER_2_5:
103+
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
102104

103105
return w_int4x8
104106

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
)
2525
from typing import ClassVar
2626
from dataclasses import dataclass
27+
from torchao.utils import TORCH_VERSION_AFTER_2_5
2728

2829
aten = torch.ops.aten
2930

@@ -500,8 +501,11 @@ def from_plain(
500501
layout_type: LayoutType
501502
):
502503
assert isinstance(layout_type, TensorCoreTiledLayoutType)
503-
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
504-
assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
504+
if TORCH_VERSION_AFTER_2_5:
505+
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
506+
assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype"
507+
else:
508+
assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
505509
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles)
506510
scale = scale.reshape(int_data.shape[0], -1)
507511
zero_point = zero_point.reshape(int_data.shape[0], -1)

torchao/prototype/hqq/hqq_tinygemm_linear.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from hqq.core.utils import *
1313

1414
import torch.nn.functional as F
15+
from torchao.utils import TORCH_VERSION_AFTER_2_5
1516

1617

1718
class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -198,7 +199,8 @@ def hqq_quants_to_torch_quants(
198199
.reshape(shape)
199200
.contiguous()
200201
)
201-
W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
202+
if TORCH_VERSION_AFTER_2_5:
203+
W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
202204

203205
# group_dequantize_tensor_from_qparams
204206
# W_r = W_q*scales + min_val

torchao/quantization/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
dequantize_affine,
1818
int_scaled_matmul,
1919
)
20+
from torchao.utils import TORCH_VERSION_AFTER_2_5
2021

2122
__all__ = [
2223
"compute_error",
@@ -349,7 +350,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
349350
quant_max = 2 ** n_bit - 1
350351

351352
int_data = quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT)
352-
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
353+
if TORCH_VERSION_AFTER_2_5:
354+
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
353355
return int_data
354356

355357
def groupwise_affine_dequantize_tensor_from_qparams(

0 commit comments

Comments (0)