Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 10 additions & 39 deletions backends/cadence/aot/ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,75 +1125,46 @@ def quantized_relu_common(


def quantized_relu_variant(
per_tensor: bool,
dtype: torch.dtype | None = None,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
"""Create a quantized relu variant with type checking."""

def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
def variant(
X: torch.Tensor,
X_zero_point: torch.Tensor | int,
X_zero_point: int,
out_zero_point: int,
out_multiplier: torch.Tensor | int,
out_shift: torch.Tensor | int,
out_multiplier: int,
out_shift: int,
) -> torch.Tensor:
if per_tensor:
if dtype and X.dtype != dtype:
raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")

assert isinstance(out_shift, int)
assert isinstance(out_multiplier, int)
_out_shift = out_shift
_out_multiplier = out_multiplier
else:
assert isinstance(out_multiplier, torch.Tensor)
if out_multiplier.numel() > 1:
raise ValueError("Only scalar out_multiplier is supported")

assert isinstance(out_shift, torch.Tensor)
if out_shift.numel() > 1:
raise ValueError("Only scalar out_shift is supported")

assert isinstance(X_zero_point, torch.Tensor)
if X_zero_point.shape != X.shape:
raise ValueError(
f"X_zero_point shape must be {X.shape}. Got {X_zero_point.shape}"
)

_out_multiplier = int(out_multiplier.item())
_out_shift = int(out_shift.item())
if dtype and X.dtype != dtype:
raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}")

return quantized_relu_common(
X,
X_zero_point,
out_zero_point,
_out_multiplier,
_out_shift,
out_multiplier,
out_shift,
)

return variant

return decorator


# NOTE: the generic (non-per-tensor) `quantized_relu` overload was removed;
# only per-tensor variants remain, all generated by quantized_relu_variant.
@impl(m, "quantized_relu.per_tensor")
@quantized_relu_variant()
def quantized_relu_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_relu_asym8s_asym8s.per_tensor")
@quantized_relu_variant(torch.int8)  # enforce signed 8-bit (asym8s) input
def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_relu_asym8u_asym8u.per_tensor")
@quantized_relu_variant(torch.uint8)  # enforce unsigned 8-bit (asym8u) input
def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ...


Expand Down
Loading