QuantizedLinearNotImplementedError when inference with Int8DynamicActivationInt4WeightConfig #1909

@goldhuang

Description

Hi, my inference code hits the exception at https://github.com/pytorch/ao/blob/main/torchao/dtypes/affine_quantized_tensor_ops.py#L228 when I use Int8DynamicActivationInt4WeightConfig. Inference ends up slower than bf16 inference because the op falls back and dequantizes back to bf16.
I'm on torch 2.5.0+cu124.
The exception is also hit when I disable torch.compile().
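
For context, that line is the generic fallback path: torchao walks a table of specialized quantized-linear kernels, and when no dispatch condition matches it raises QuantizedLinearNotImplementedError, catches it, dequantizes, and runs a plain linear. A rough paraphrase of that pattern (a sketch for illustration, not the exact torchao source):

# Sketch of the dispatch/fallback pattern around
# affine_quantized_tensor_ops.py#L228 (paraphrased; see the linked file).
def _quantized_linear_op(input_tensor, weight_tensor, bias):
    for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items():
        if dispatch_condition(input_tensor, weight_tensor, bias):
            return impl(input_tensor, weight_tensor, bias)
    # no specialized int8-activation / int4-weight kernel matched
    raise QuantizedLinearNotImplementedError(
        "No specialized dispatch found for quantized linear op"
    )

# The F.linear override catches the error and falls back, which is why this
# is slower than plain bf16: it pays for quantization *and* the dequantize
# plus the bf16 matmul.
try:
    out = _quantized_linear_op(input_tensor, weight_tensor, bias)
except QuantizedLinearNotImplementedError:
    out = torch.nn.functional.linear(
        input_tensor.dequantize(), weight_tensor.dequantize(), bias
    )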

import torch
from torchao.quantization import (
    quantize_,
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    FromIntXQuantizationAwareTrainingConfig,
    IntXQuantizationAwareTrainingConfig,
)
from torchao.quantization.quant_primitives import (
    TorchAODType,
)


class PytorchLinear(torch.nn.Module):
    def __init__(self, in_features=4096, out_features=12288):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features).cuda()

    def forward(self, x):
        return self.linear(x)

model = PytorchLinear().cuda().to(torch.bfloat16)

input_tensors = [torch.randn((70000, 4096), dtype=torch.bfloat16, device="cuda") for _ in range(100)]

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(TorchAODType.INT4, group_size=32)
quantize_(
    model,
    IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)

# convert: transform fake quantization ops into actual quantized ops,
# swapping `FakeQuantizedLinear` back to `torch.nn.Linear` and inserting
# quantized activation and weight tensor subclasses
quantize_(model, FromIntXQuantizationAwareTrainingConfig())
quantize_(model, Int8DynamicActivationInt4WeightConfig(group_size=32))

model = torch.compile(model.eval(), mode="max-autotune-no-cudagraphs")

# CUDA events for precise timing
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# Warmup iterations (also triggers torch.compile)
with torch.no_grad():
    for i in range(10):
        model(input_tensors[i])
torch.cuda.synchronize()

start_event.record()
with torch.no_grad():
    for i in range(50):
        model(input_tensors[i+10])
# Record end time
end_event.record()

# Wait for completion
torch.cuda.synchronize()

# Compute elapsed time (in milliseconds)
elapsed_time = start_event.elapsed_time(end_event) / 50  # Average per iteration

print(f"Avg Inference Time: {elapsed_time:.3f} ms")
