12 changes: 7 additions & 5 deletions test/quantization/test_quant_api.py
@@ -22,12 +22,14 @@
from torchao.dtypes import (
AffineQuantizedTensor,
)
from torchao.quantization import (
LinearActivationQuantizedTensor,
)
from torchao.quantization.quant_primitives import (
MappingType,
ZeroPointDomain,
)
from torchao.quantization.subclass import (
LinearActQuantizedTensor,
Int8WeightOnlyQuantizedLinearWeight,
Int4WeightOnlyQuantizedLinearWeight,
)
@@ -504,8 +506,8 @@ def test_quantized_tensor_subclass_8da4w(self):
example_inputs = m.example_inputs()
quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size))

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor)
assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

@@ -577,8 +579,8 @@ def test_quantized_tensor_subclass_int8_dyn_quant(self):
example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
quantize_(m, int8_dynamic_activation_int8_weight())

assert isinstance(m.linear1.weight, LinearActQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActQuantizedTensor)
assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor)
assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)

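The test change above tracks a rename: `LinearActQuantizedTensor`, previously imported from `torchao.quantization.subclass`, is now `LinearActivationQuantizedTensor`, exported from `torchao.quantization`. Below is a minimal, self-contained sketch mirroring the updated assertions; the `TinyModel` class is illustrative, and the `quant_api` import path for `quantize_` / `int8_dynamic_activation_int4_weight` is an assumption rather than something shown in this diff.

```python
# Illustrative sketch, not the PR's test code.
import torch
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization import LinearActivationQuantizedTensor
from torchao.quantization.quant_api import (
    int8_dynamic_activation_int4_weight,
    quantize_,
)


class TinyModel(torch.nn.Module):
    """Stand-in for the test's toy model with two linear layers."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(64, 32, bias=False)
        self.linear2 = torch.nn.Linear(32, 8, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))


m = TinyModel().eval()
quantize_(m, int8_dynamic_activation_int4_weight(group_size=32))

# The weights are now LinearActivationQuantizedTensor wrapping an
# AffineQuantizedTensor, matching the assertions in the updated test.
assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
```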
60 changes: 22 additions & 38 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -17,7 +17,8 @@
from torchao.utils import find_multiple
from torchao.dtypes.utils import (
_implements,
_ATEN_OP_OR_TORCH_FN_TABLE,
_dispatch__torch_function__,
_dispatch__torch_dispatch__,
_register_layout_cls,
_get_layout_tensor_constructor,
LayoutType,
@@ -295,17 +296,6 @@ def from_float_static(
def layout_type(self) -> LayoutType:
return self.layout_tensor.layout_type

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs

if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs)

with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)


def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
@@ -347,29 +337,23 @@ def _apply_fn_to_data(self, fn):
strides=self.stride(),
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
# Note: we only added cpu path here for 8da4w, this is for executorch, in the future
# 1. we'll add cpu/cuda version (int4mm etc.)
# 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
# cpu device + et layout --> gives current 8da4w executorch representation
# cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc.
# cuda device + some layout --> gives cuda kernel

# two scenarios where we currently fall back to vanilla mm:
# 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
# kernels in CPU as well, see the note above
# 2 - we're given non-floats - quantizing long to int8 is crazy

if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)
implements = classmethod(_implements)
# Note: we only added cpu path here for 8da4w, this is for executorch, in the future
# 1. we'll add cpu/cuda version (int4mm etc.)
# 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like
# cpu device + et layout --> gives current 8da4w executorch representation
# cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc.
# cuda device + some layout --> gives cuda kernel

raise NotImplementedError(
f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
)
# two scenarios where we currently fall back to vanilla mm:
# 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized
# kernels in CPU as well, see the note above
# 2 - we're given non-floats - quantizing long to int8 is crazy
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
__torch_function__ = classmethod(_dispatch__torch_function__)

def implements(aten_ops_or_torch_fn):
return _implements(AffineQuantizedTensor, aten_ops_or_torch_fn)
implements = AffineQuantizedTensor.implements

def register_layout_cls(layout_type_class: type(LayoutType)):
return _register_layout_cls(AffineQuantizedTensor, layout_type_class)
@@ -827,7 +811,7 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias):


@implements(torch.nn.functional.linear)
def functional_linear(*args, **kwargs):
def _(func, types, *args, **kwargs):
input_tensor, weight_tensor, bias = (
args[0],
args[1],
@@ -846,7 +830,7 @@ def functional_linear(*args, **kwargs):
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

@implements([aten.mm.default, aten.addmm.default])
def aten_mm(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
if not args[0].is_floating_point():
raise NotImplementedError(f"{func} is not implemented for non floating point input")

@@ -885,21 +869,21 @@ def aten_mm(func, *args, **kwargs):
return func(input_tensor, weight_tensor)

@implements([aten.detach.default])
def detach(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


@implements([aten.clone.default])
def clone(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)


@implements([aten._to_copy.default])
def _to_copy(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
return return_and_correct_aliasing(
func,
args,
@@ -908,7 +892,7 @@ def _to_copy(func, *args, **kwargs):
)

@implements([aten.t.default])
def t(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
block_size = args[0].block_size
assert len(block_size) == 2
transposed_block_size = (block_size[1], block_size[0])
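The `__torch_function__` and `__torch_dispatch__` bodies deleted above are replaced by the shared classmethods defined in `torchao/dtypes/utils.py` (next file), so dispatch for `AffineQuantizedTensor` now flows: torch function or aten op → shared dispatcher → lookup in the per-class `_ATEN_OP_OR_TORCH_FN_TABLE` → handler invoked as `handler(func, types, *args, **kwargs)`. A rough sketch of that lookup, for illustration only:

```python
# Illustration of the dispatch flow after this change; not library code.
# The table is populated at import time by the @implements decorators above.
import torch
import torch.nn.functional as F
from torchao.dtypes import AffineQuantizedTensor

aten = torch.ops.aten

table = AffineQuantizedTensor._ATEN_OP_OR_TORCH_FN_TABLE
assert F.linear in table              # routed via __torch_function__
assert aten.detach.default in table   # routed via __torch_dispatch__

# The shared dispatchers call registered handlers as:
#   table[func](func, types, *args, **kwargs)
```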
56 changes: 48 additions & 8 deletions torchao/dtypes/utils.py
@@ -5,19 +5,28 @@
from dataclasses import dataclass

"""
torch_function and torch_dispatch operator dispatch registrations

first key is a tensor subclass type like AffineQuantizedTensor,
second key is a `func` in __torch_function__ or __torch_dispatch__,
value is a function that implements the dispatch
Helper functions for implementing aten op or torch function dispatch
and for dispatching to those implementations.
"""
_ATEN_OP_OR_TORCH_FN_TABLE: Dict[Callable, Dict[Callable, Callable]] = defaultdict(dict)

def _implements(cls, aten_ops_or_torch_fns):
"""Use this decorator to implement a function for an aten ops in __torch_dispatch__
(if user passed in a list of ops)
or torch function in __torch_function__ (if user passed in a single object)

class MyTensor(torch.Tensor):
...
implements = classmethod(_implements)

implements = MyTensor.implements

@implements(torch.nn.functional.linear)
def _(func, types, *args, **kwargs):
...

"""
if not hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE"):
cls._ATEN_OP_OR_TORCH_FN_TABLE = {}

if not isinstance(aten_ops_or_torch_fns, (list, tuple)):
aten_ops_or_torch_fns = [aten_ops_or_torch_fns]
def decorator(func):
@@ -26,10 +35,41 @@ def decorator(func):
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

_ATEN_OP_OR_TORCH_FN_TABLE[cls][op] = wrapper
cls._ATEN_OP_OR_TORCH_FN_TABLE[op] = wrapper
return func
return decorator

def _dispatch__torch_function__(cls, func, types, args=(), kwargs=None):
"""Use this util function for a common `__torch_function__` implementation
that dispatches to ops/functions registered with `_implements`

class MyTensor(torch.Tensor):
...
__torch_function__ = classmethod(_dispatch__torch_function__)
"""
kwargs = {} if kwargs is None else kwargs
if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs)

with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)

def _dispatch__torch_dispatch__(cls, func, types, args, kwargs):
"""Use this util function for a common `__torch_dispatch__` implementation
that dispatches to ops/functions registered with `_implements`

class MyTensor(torch.Tensor):
...
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
"""
if hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE") and \
func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, *args, **kwargs)

raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func}")


"""
Base class for different LayoutType, should not be instantiated directly
"""
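Put together, the helpers above let any tensor subclass opt into the table-driven dispatch with three class attributes plus `@implements` registrations. A minimal sketch of that pattern, which the library files below follow; `MyQuantTensor`, its `_apply_fn_to_data` placeholder, and the detach handler are illustrative stand-ins, not code from this PR.

```python
# Minimal wiring sketch under the assumptions stated above.
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.dtypes.utils import (
    _implements,
    _dispatch__torch_function__,
    _dispatch__torch_dispatch__,
)

aten = torch.ops.aten


class MyQuantTensor(torch.Tensor):
    """Illustrative subclass that opts into the shared, table-driven dispatch."""

    implements = classmethod(_implements)
    __torch_function__ = classmethod(_dispatch__torch_function__)
    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)

    def _apply_fn_to_data(self, fn):
        # Placeholder: a real subclass rebuilds itself from fn(inner tensors),
        # as AffineQuantizedTensor._apply_fn_to_data does above.
        raise NotImplementedError


# Module-level alias, mirroring `implements = AffineQuantizedTensor.implements`.
implements = MyQuantTensor.implements


@implements([aten.detach.default])
def _(func, types, *args, **kwargs):
    # Handlers receive (func, types, *args, **kwargs): exactly how the shared
    # dispatchers invoke entries found in cls._ATEN_OP_OR_TORCH_FN_TABLE.
    return return_and_correct_aliasing(
        func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
    )
```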
17 changes: 6 additions & 11 deletions torchao/prototype/low_bit_optim/subclass_4bit.py
@@ -2,7 +2,7 @@

import torch
from torch import Tensor
from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE
from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__

from .quant_utils import create_dynamic_map, scale_tensor, quantize_4bit_with_qmap, dequant_with_qmap

@@ -85,16 +85,11 @@ def __repr__(self):
f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)

raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported")
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)


@OptimState4bit.implements(aten.copy_.default)
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
dst = args[0]
src = args[1]

@@ -121,14 +116,14 @@ def _(func, *args, **kwargs):


@OptimState4bit.implements(aten.lerp.Scalar)
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
args = [x.dequantize() if isinstance(x, OptimState4bit) else x for x in args]
return func(*args, **kwargs)


# this is needed for DTensor.from_local() and for flattening tensor
@OptimState4bit.implements(aten.view.default)
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
x, shape = args

if tuple(x.shape) == tuple(shape):
@@ -147,7 +142,7 @@ def _(func, *args, **kwargs):
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
])
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
x = args[0]
if not isinstance(x, OptimState4bit):
raise ValueError(f"expecting a OptimState4bit but found {type(x)}")
17 changes: 6 additions & 11 deletions torchao/prototype/low_bit_optim/subclass_8bit.py
@@ -1,6 +1,6 @@
import torch
from torch import Tensor
from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE
from torchao.dtypes.utils import _implements, _dispatch__torch_dispatch__

from .quant_utils import create_dynamic_map, scale_tensor, quantize_8bit_with_qmap, dequant_with_qmap

@@ -71,16 +71,11 @@ def __repr__(self):
f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})"
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]:
return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs)

raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported")
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)


@OptimState8bit.implements(aten.copy_.default)
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
dst = args[0]
src = args[1]

@@ -103,14 +98,14 @@ def _(func, *args, **kwargs):


@OptimState8bit.implements(aten.lerp.Scalar)
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args]
return func(*args, **kwargs)


# this is needed for DTensor.from_local()
@OptimState8bit.implements(aten.view.default)
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
x, shape = args
return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed)

@@ -122,7 +117,7 @@ def _(func, *args, **kwargs):
c10d_functional.wait_tensor.default,
_c10d_functional.wait_tensor.default,
])
def _(func, *args, **kwargs):
def _(func, types, *args, **kwargs):
x = args[0]
if not isinstance(x, OptimState8bit):
raise ValueError(f"expecting a OptimState8bit but found {type(x)}")