Skip to content

Commit 0b0a3a8

Browse files
committed
fix CI
1 parent 5099846 commit 0b0a3a8

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,6 @@ def from_float(
232232

233233
scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
234234
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
235-
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
236235
int_data = layout_type.post_process(int_data)
237236

238237
layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type))
@@ -501,6 +500,7 @@ def from_plain(
501500
layout_type: LayoutType
502501
):
503502
assert isinstance(layout_type, TensorCoreTiledLayoutType)
503+
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
504504
assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack expects `uint8` dtype"
505505
packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles)
506506
scale = scale.reshape(int_data.shape[0], -1)

0 commit comments

Comments (0)