
Commit ea32965

[int8 woq] make the scale type the same as input for bf16 autocast
1 parent 5787e9e commit ea32965

File tree: 1 file changed (+4 −6 lines)


torchao/dtypes/affine_quantized_tensor.py

Lines changed: 4 additions & 6 deletions
@@ -730,17 +730,15 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):
         w_vals_int8_t = weight_qtensor.layout_tensor.int_data.t()
         scale = weight_qtensor.layout_tensor.scale
         orig_dtype = input_tensor.dtype
-        y = (
-            torch.mm(
+        m = torch.mm(
             input_tensor.reshape(-1, input_tensor.shape[-1]),
             w_vals_int8_t.to(input_tensor.dtype),
         )
-            * scale
-        )
+        y = m * scale.to(m.dtype)
         y = y.reshape(*input_tensor.shape[:-1], y.shape[-1])
         if bias is not None:
-            y += bias
-        return y.to(orig_dtype)
+            y += bias.to(m.dtype)
+        return y
 
         # is_cpu and is_mps only, some issue with is_contiguous() currently
         # return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), w_vals_int8_t, weight_qtensor.layout_tensor.scale)
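Why the cast matters: under bf16 autocast, torch.mm produces a bfloat16 result, while the per-channel scale (and bias) are typically stored in float32. Multiplying the bf16 matmul output by an fp32 scale promotes the result back to float32, which is why the old code needed the trailing y.to(orig_dtype). Casting scale and bias to m.dtype keeps the whole epilogue in the matmul dtype. The following is a minimal standalone sketch of that effect, not the torchao code path; int8_woq_linear and the tensor shapes are made up for illustration.

import torch

def int8_woq_linear(x, w_int8, scale, bias=None):
    # int8 weight-only quantized linear: dequantize on the fly by scaling
    # the matmul output. w_int8 has shape (out_features, in_features).
    m = torch.mm(x.reshape(-1, x.shape[-1]), w_int8.t().to(x.dtype))
    y = m * scale.to(m.dtype)      # cast scale to the matmul dtype (bf16 under autocast)
    y = y.reshape(*x.shape[:-1], y.shape[-1])
    if bias is not None:
        y = y + bias.to(m.dtype)   # keep bias in the same dtype as well
    return y

x = torch.randn(4, 16)
w_int8 = torch.randint(-128, 128, (8, 16), dtype=torch.int8)
scale = torch.rand(8)              # per-channel scale stored in float32
bias = torch.randn(8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = int8_woq_linear(x, w_int8, scale, bias)

print(y.dtype)  # torch.bfloat16; without the .to(m.dtype) casts it would be float32

With the casts in place the output already has the autocast dtype, so no final y.to(orig_dtype) is needed, matching the change in this commit.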
