Hi, my inference code hits an exception at https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor_ops.py#L228 when I use Int8DynamicActivationInt4Weight. Inference ends up slower than plain bf16 because the op falls back and dequantizes back to bf16. I'm on torch 2.5.0+cu124. The exception is also hit when I disable torch.compile().
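For context on why it gets slow: my reading of the dispatch around that line is that `_quantized_linear_op` is tried first, and when no quantized kernel accepts the input/weight combination, both sides are dequantized and a plain `F.linear` runs in high precision. A paraphrased sketch of that pattern (simplified, hypothetical names; not verbatim from the file):

```python
import torch

def linear_dispatch(input_tensor, weight_tensor, bias=None):
    # Paraphrased sketch of the fallback I believe sits at that line.
    try:
        # try a real quantized linear kernel first
        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
    except NotImplementedError:
        # no kernel matched: dequantize both sides and run a plain
        # high-precision linear -- correct, but slower than bf16 alone
        if hasattr(input_tensor, "dequantize"):
            input_tensor = input_tensor.dequantize()
        if hasattr(weight_tensor, "dequantize"):
            weight_tensor = weight_tensor.dequantize()
        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
```

My repro script: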
```python
import torch

from torchao.quantization import (
    quantize_,
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    FromIntXQuantizationAwareTrainingConfig,
    IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.quant_primitives import (
    TorchAODType,
)


class PytorchLinear(torch.nn.Module):
    def __init__(self, in_features=4096, out_features=12288):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features).cuda()

    def forward(self, x):
        return self.linear(x)


model = PytorchLinear().cuda().to(torch.bfloat16)
input_tensors = [
    torch.randn((70000, 4096), dtype=torch.bfloat16, device="cuda")
    for _ in range(100)
]

# prepare: insert fake quantization ops for QAT
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=32)
quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)

# convert: transform fake quantization ops into actual quantized ops,
# swap `FakeQuantizedLinear` back to `torch.nn.Linear`, and insert
# quantized activation and weight tensor subclasses
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))

model = torch.compile(model.eval(), mode="max-autotune-no-cudagraphs")

# CUDA events for precise timing
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# warmup
with torch.no_grad():
    for i in range(10):
        model(input_tensors[i])
torch.cuda.synchronize()

start_event.record()
with torch.no_grad():
    for i in range(50):
        model(input_tensors[i + 10])
# record end time
end_event.record()
# wait for completion
torch.cuda.synchronize()

# compute elapsed time (in milliseconds)
elapsed_time = start_event.elapsed_time(end_event) / 50  # average per iteration
print(f"Avg Inference Time: {elapsed_time:.3f} ms")
```
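For comparison, this is the kind of bf16 baseline I timed the same way (a sketch; it reuses `input_tensors` and the CUDA events from the script above):

```python
# bf16 baseline, timed with the same warmup + CUDA-event pattern as above
baseline = PytorchLinear().cuda().to(torch.bfloat16)
baseline = torch.compile(baseline.eval(), mode="max-autotune-no-cudagraphs")

with torch.no_grad():
    for i in range(10):  # warmup
        baseline(input_tensors[i])
torch.cuda.synchronize()

start_event.record()
with torch.no_grad():
    for i in range(50):
        baseline(input_tensors[i + 10])
end_event.record()
torch.cuda.synchronize()
print(f"bf16 baseline: {start_event.elapsed_time(end_event) / 50:.3f} ms")
```

The quantized model comes out slower than this baseline, which matches the dequantize fallback being taken on every call.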