2 changes: 1 addition & 1 deletion benchmarks/microbenchmarks/utils.py
@@ -218,7 +218,7 @@ def string_to_config(
)
if "marlin" in quantization:
if "qqq" in quantization:
-from torchao.dtypes import MarlinQQQLayout
+from torchao.prototype.dtypes import MarlinQQQLayout

return Int8DynamicActivationInt4WeightConfig(
group_size=128,
4 changes: 2 additions & 2 deletions docs/source/api_ref_dtypes.rst
@@ -23,8 +23,6 @@ Layouts and Tensor Subclasses
FloatxTensorCoreLayout
MarlinSparseLayout
UintxLayout
-MarlinQQQTensor
-MarlinQQQLayout
Int4CPULayout
CutlassSemiSparseLayout

@@ -53,6 +51,8 @@ Prototype
BlockSparseLayout
CutlassInt4PackedLayout
Int8DynamicActInt4WeightCPULayout
+MarlinQQQTensor
+MarlinQQQLayout

..
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
1 change: 1 addition & 0 deletions test/dtypes/test_uintx.py
@@ -182,6 +182,7 @@ def test_uintx_api_deprecation():
),
("CutlassInt4PackedLayout", "torchao.dtypes.uintx.cutlass_int4_packed_layout"),
("BlockSparseLayout", "torchao.dtypes.uintx.block_sparse_layout"),
("MarlinQQQLayout", "torchao.dtypes.uintx.marlin_qqq_tensor"),
]

for api_name, module_path in deprecated_apis:
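The entry added here follows the same pattern as the other deprecated APIs in the list: the old module path must still resolve the name while emitting a DeprecationWarning. A standalone sketch of that check (the real test_uintx_api_deprecation may structure the assertion differently):

```python
# Sketch of the deprecation check; assumes the legacy module warns when it is
# (re)imported or when the attribute is accessed, as the test list implies.
import importlib
import sys
import warnings

def assert_deprecated(api_name: str, module_path: str) -> None:
    sys.modules.pop(module_path, None)  # force a fresh import so the warning fires
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        module = importlib.import_module(module_path)
        getattr(module, api_name)  # the old name must still be resolvable
    assert any(issubclass(w.category, DeprecationWarning) for w in caught), (
        f"{module_path}.{api_name} should emit a DeprecationWarning"
    )

assert_deprecated("MarlinQQQLayout", "torchao.dtypes.uintx.marlin_qqq_tensor")
```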
2 changes: 1 addition & 1 deletion test/quantization/test_marlin_qqq.py
@@ -10,7 +10,7 @@
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests

-from torchao.dtypes import MarlinQQQLayout
+from torchao.prototype.dtypes import MarlinQQQLayout
from torchao.quantization.marlin_qqq import (
pack_to_marlin_qqq,
unpack_from_marlin_qqq,
2 changes: 1 addition & 1 deletion torchao/_models/llama/generate.py
@@ -460,7 +460,7 @@ def ffn_or_attn_only(mod, fqn):
)
if "marlin" in quantization:
if "qqq" in quantization:
-from torchao.dtypes import MarlinQQQLayout
+from torchao.prototype.dtypes import MarlinQQQLayout

quantize_(
model,
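For downstream users the only change is the import path; the quantization call itself is unchanged. A minimal usage sketch against the new location (the `layout` keyword on Int8DynamicActivationInt4WeightConfig and the fp16 CUDA model are assumptions for illustration; the full argument list used in generate.py is not shown in this hunk):

```python
# Minimal sketch, not the exact generate.py invocation: assumes
# Int8DynamicActivationInt4WeightConfig takes a `layout` kwarg and that a
# CUDA fp16 model is available for the Marlin QQQ kernel.
from torch import nn

from torchao.prototype.dtypes import MarlinQQQLayout  # new import path
from torchao.quantization import Int8DynamicActivationInt4WeightConfig, quantize_

model = nn.Sequential(nn.Linear(1024, 1024)).half().cuda()
quantize_(
    model,
    Int8DynamicActivationInt4WeightConfig(
        group_size=128,            # matches the value used in this diff
        layout=MarlinQQQLayout(),  # assumed kwarg for selecting the layout
    ),
)
```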
8 changes: 5 additions & 3 deletions torchao/dtypes/__init__.py
@@ -16,19 +16,21 @@
from .uintx import (
Int4CPULayout,
Int4XPULayout,
-MarlinQQQLayout,
-MarlinQQQTensor,
MarlinSparseLayout,
PackedLinearInt8DynamicActivationIntxWeightLayout,
QDQLayout,
SemiSparseLayout,
TensorCoreTiledLayout,
UintxLayout,
-to_marlinqqq_quantized_intx,
)
from .uintx.block_sparse_layout import BlockSparseLayout
from .uintx.cutlass_int4_packed_layout import CutlassInt4PackedLayout
from .uintx.dyn_int8_act_int4_wei_cpu_layout import Int8DynamicActInt4WeightCPULayout
+from .uintx.marlin_qqq_tensor import (
+MarlinQQQLayout,
+MarlinQQQTensor,
+to_marlinqqq_quantized_intx,
+)
from .utils import (
Layout,
PlainLayout,
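After this change both spellings of the import resolve, which is why the test above checks for a deprecation warning rather than an ImportError. A quick sketch of the two paths (the warning behaviour on the legacy path is an assumption based on the test, not shown in this diff):

```python
# Sketch: the prototype path is the new canonical location; the legacy path is
# kept importable via the re-export in torchao/dtypes/__init__.py above.
from torchao.prototype.dtypes import MarlinQQQLayout             # new location
from torchao.dtypes import MarlinQQQLayout as LegacyMarlinQQQ    # legacy re-export
```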
8 changes: 4 additions & 4 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -39,10 +39,6 @@
_linear_fp_act_uint4_weight_int8_zero_check,
_linear_fp_act_uint4_weight_int8_zero_impl,
)
-from torchao.dtypes.uintx.marlin_qqq_tensor import (
-_linear_int8_act_int4_weight_marlin_qqq_check,
-_linear_int8_act_int4_weight_marlin_qqq_impl,
-)
from torchao.dtypes.uintx.marlin_sparse_layout import (
_linear_fp_act_int4_weight_sparse_marlin_check,
_linear_fp_act_int4_weight_sparse_marlin_impl,
@@ -94,6 +90,10 @@
_linear_int8_act_int4_weight_cpu_check,
_linear_int8_act_int4_weight_cpu_impl,
)
+from torchao.prototype.dtypes.uintx.marlin_qqq_tensor import (
+_linear_int8_act_int4_weight_marlin_qqq_check,
+_linear_int8_act_int4_weight_marlin_qqq_impl,
+)
from torchao.quantization.quant_primitives import (
ZeroPointDomain,
_dequantize_affine_no_zero_point,
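The pair moved in this hunk is a (check, impl) couple for the AffineQuantizedTensor linear dispatch: the check decides whether an (input, weight, bias) triple should be routed to the Marlin QQQ kernel, and the impl performs the matmul. A simplified, self-contained sketch of that dispatch pattern; the table and registration helper names below are illustrative, not torchao's internal API:

```python
# Illustrative only: a minimal dispatch table in the spirit of
# affine_quantized_tensor_ops.py; names here are hypothetical.
from typing import Callable, List, Tuple

_LINEAR_DISPATCH: List[Tuple[Callable, Callable]] = []

def register_linear_dispatch(check: Callable, impl: Callable) -> None:
    # Pair a predicate over (input, weight, bias) with the kernel that handles it.
    _LINEAR_DISPATCH.append((check, impl))

def quantized_linear(input_tensor, weight_tensor, bias):
    # Run the first implementation whose check accepts the arguments.
    for check, impl in _LINEAR_DISPATCH:
        if check(input_tensor, weight_tensor, bias):
            return impl(input_tensor, weight_tensor, bias)
    raise NotImplementedError("no quantized linear dispatch matched")

# After this PR the Marlin QQQ pair comes from the prototype namespace:
#   from torchao.prototype.dtypes.uintx.marlin_qqq_tensor import (
#       _linear_int8_act_int4_weight_marlin_qqq_check,
#       _linear_int8_act_int4_weight_marlin_qqq_impl,
#   )
```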