From 7569b08ca6a9bd7967d74ddb089e3f5f0e8f869b Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Thu, 9 Oct 2025 16:20:23 -0700 Subject: [PATCH] Removed support for non-per-tensor quantized relu (#14788) Summary: Not supporting quantized relu default, so removing it from ref_implementations Reviewed By: zonglinpeng Differential Revision: D83874866 --- backends/cadence/aot/ref_implementations.py | 49 +++----------- .../aot/tests/test_ref_implementations.py | 64 ++++++------------- 2 files changed, 31 insertions(+), 82 deletions(-) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 4f612e3bab4..6a13a4424da 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -1125,7 +1125,6 @@ def quantized_relu_common( def quantized_relu_variant( - per_tensor: bool, dtype: torch.dtype | None = None, ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: """Create a quantized relu variant with type checking.""" @@ -1133,43 +1132,20 @@ def quantized_relu_variant( def decorator(_: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: def variant( X: torch.Tensor, - X_zero_point: torch.Tensor | int, + X_zero_point: int, out_zero_point: int, - out_multiplier: torch.Tensor | int, - out_shift: torch.Tensor | int, + out_multiplier: int, + out_shift: int, ) -> torch.Tensor: - if per_tensor: - if dtype and X.dtype != dtype: - raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}") - - assert isinstance(out_shift, int) - assert isinstance(out_multiplier, int) - _out_shift = out_shift - _out_multiplier = out_multiplier - else: - assert isinstance(out_multiplier, torch.Tensor) - if out_multiplier.numel() > 1: - raise ValueError("Only scalar out_multiplier is supported") - - assert isinstance(out_shift, torch.Tensor) - if out_shift.numel() > 1: - raise ValueError("Only scalar out_shift is supported") - - assert isinstance(X_zero_point, torch.Tensor) - if X_zero_point.shape != X.shape: - raise ValueError( - f"X_zero_point shape must be {X.shape}. Got {X_zero_point.shape}" - ) - - _out_multiplier = int(out_multiplier.item()) - _out_shift = int(out_shift.item()) + if dtype and X.dtype != dtype: + raise ValueError(f"X dtype must be {dtype}. Got {X.dtype}") return quantized_relu_common( X, X_zero_point, out_zero_point, - _out_multiplier, - _out_shift, + out_multiplier, + out_shift, ) return variant @@ -1177,23 +1153,18 @@ def variant( return decorator -@impl(m, "quantized_relu") -@quantized_relu_variant(False) -def quantized_relu() -> torch.Tensor: ... - - @impl(m, "quantized_relu.per_tensor") -@quantized_relu_variant(True) +@quantized_relu_variant() def quantized_relu_per_tensor() -> torch.Tensor: ... @impl(m, "quantized_relu_asym8s_asym8s.per_tensor") -@quantized_relu_variant(True, torch.int8) +@quantized_relu_variant(torch.int8) def quantized_relu_asym8s_asym8s_per_tensor() -> torch.Tensor: ... @impl(m, "quantized_relu_asym8u_asym8u.per_tensor") -@quantized_relu_variant(True, torch.uint8) +@quantized_relu_variant(torch.uint8) def quantized_relu_asym8u_asym8u_per_tensor() -> torch.Tensor: ... diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 5856c9def66..f679bae9485 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -1080,61 +1080,39 @@ def test_quantized_conv_per_tensor( ) for dtype in [torch.uint8] ], - # Test case 4: Non-per-tensor - *[ - ( - "non_per_tensor", - torch.tensor([-1, -2, -3, 1, 2, 3], dtype=dtype), # input - torch.tensor([0, 0, 0, 1, 1, 1]), # X_zero_point - 5, # out_zero_point - torch.tensor([1073741824]), # out_multiplier (0.5 * 2^31) - torch.tensor([1]), # out_shift (multiply by 2^1 = 2) - dtype, # dtype - torch.tensor([5, 5, 5, 5, 4, 3], dtype=dtype), - ) - for dtype in [torch.int8] - ], ] ) def test_quantized_relu( self, name: str, X: torch.Tensor, - X_zero_point: torch.Tensor | int, + X_zero_point: int, out_zero_point: int, - out_multiplier: torch.Tensor | int, - out_shift: torch.Tensor | int, + out_multiplier: int, + out_shift: int, dtype: torch.dtype, expected_output: torch.Tensor, ) -> None: - if isinstance(X_zero_point, int): - assert isinstance(out_multiplier, int) - assert isinstance(out_shift, int) - - match dtype: - case torch.int8: - quantized_relu = ( - torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor - ) - case torch.uint8: - quantized_relu = ( - torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor - ) - case _: - quantized_relu = torch.ops.cadence.quantized_relu_per_tensor + match dtype: + case torch.int8: + quantized_relu = ( + torch.ops.cadence.quantized_relu_asym8s_asym8s.per_tensor + ) + case torch.uint8: + quantized_relu = ( + torch.ops.cadence.quantized_relu_asym8u_asym8u.per_tensor + ) + case _: + quantized_relu = torch.ops.cadence.quantized_relu_per_tensor - output = quantized_relu( - X, - X_zero_point, - out_zero_point, - out_multiplier, - out_shift, - ) - else: - output = torch.ops.cadence.quantized_relu( - X, X_zero_point, out_zero_point, out_multiplier, out_shift - ) + output = quantized_relu( + X, + X_zero_point, + out_zero_point, + out_multiplier, + out_shift, + ) # Verify output properties self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")