From 63ed233cd5158d535f7421d894f2411c417c3e7b Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Mon, 28 Apr 2025 14:49:30 +0200 Subject: [PATCH] Add support for single input matmul AnnotateDecomposedMatmul makes sure that a decomposed matmul will have two dq-nodes before and a q-node after its mm/bmm-node. Previously it assumed that the partition always had two input nodes (two dq-nodes), but this is not the case for a single input matmul, e.g. torch.matmul(x, x). In such a case we must copy the dq-node and insert it before the mm/bmm's two inputs. Before pass: -> expand -> view -> / \ x -> dq bmm -> view -> q \ / -> expand -> view -> After pass: -> expand -> view -> dq / \ x bmm -> q -> view \ / -> expand -> view -> dq Signed-off-by: Oscar Andersson Change-Id: I5ac381ccd712a535736fa16d1ee864dc76ae2b30 --- .../arm/_passes/annotate_decomposed_matmul.py | 14 +- backends/arm/test/ops/test_bmm.py | 19 -- backends/arm/test/ops/test_matmul.py | 197 ++++++++++++++++++ 3 files changed, 206 insertions(+), 24 deletions(-) create mode 100644 backends/arm/test/ops/test_matmul.py diff --git a/backends/arm/_passes/annotate_decomposed_matmul.py b/backends/arm/_passes/annotate_decomposed_matmul.py index c45cd63f9cf..72c42f0f829 100644 --- a/backends/arm/_passes/annotate_decomposed_matmul.py +++ b/backends/arm/_passes/annotate_decomposed_matmul.py @@ -70,17 +70,14 @@ def call(self, graph_module: GraphModule) -> PassResult: if quantized_input: matmul_args = matmul_node.all_input_nodes for node in matmul_args: + # Find the dq-node connected to this mm/bmm arg input_node = self._match_partition_to_node( node, partition.input_nodes ) - - # Remove partition input dq-node - input_node.replace_all_uses_with(input_node.all_input_nodes[0]) - graph_module.graph.erase_node(input_node) input_node_qargs = QuantArgs.from_operator( input_node.target, input_node.args ) - + # Insert new dq-node just before the mm/bmm with input_node's qparams with graph_module.graph.inserting_before(matmul_node): # 
Create new dq-node before matmul dq_node = create_node( @@ -90,6 +87,13 @@ def call(self, graph_module: GraphModule) -> PassResult: dq_node.args = (node, *input_node_qargs) matmul_node.replace_input_with(node, dq_node) + for partition_input in partition.input_nodes: + # Remove partition input dq-node + partition_input.replace_all_uses_with( + partition_input.all_input_nodes[0] + ) + graph_module.graph.erase_node(partition_input) + partition_output = list(partition.output_nodes[0].users)[0] quantized_output = partition_output.target == q_op if quantized_output: diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py index 247f5a166b8..375e77cb9b0 100644 --- a/backends/arm/test/ops/test_bmm.py +++ b/backends/arm/test/ops/test_bmm.py @@ -32,15 +32,6 @@ class BMM(torch.nn.Module): def forward(self, x, y): return torch.bmm(x, y) - class MatMul(torch.nn.Module): - test_data_generators = [ - lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)), - lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)), - ] - - def forward(self, x, y): - return torch.matmul(x, y) - class BMMSingleInput(torch.nn.Module): test_data_generators = [ lambda: (torch.rand(20, 3, 3),), @@ -129,16 +120,6 @@ def test_bmm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple] test_data = test_data_generator() self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data) - @parameterized.expand(MatMul.test_data_generators) - def test_matmul_tosa_MI(self, test_data_generator: Callable[[], Tuple]): - test_data = test_data_generator() - self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data) - - @parameterized.expand(MatMul.test_data_generators) - def test_matmul_tosa_BI(self, test_data_generator: Callable[[], Tuple]): - test_data = test_data_generator() - self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data) - @parameterized.expand(BMM.test_data_generators) def test_bmm_tosa_BI(self, test_data_generator: Callable[[], Tuple]): test_data = 
test_data_generator() diff --git a/backends/arm/test/ops/test_matmul.py b/backends/arm/test/ops/test_matmul.py new file mode 100644 index 00000000000..11a4786c4af --- /dev/null +++ b/backends/arm/test/ops/test_matmul.py @@ -0,0 +1,197 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineBI, + EthosU85PipelineBI, + TosaPipelineBI, + TosaPipelineMI, +) + +aten_op_mm = "torch.ops.aten.matmul.default" +exir_op_mm = "executorch_exir_dialects_edge__ops_aten_matmul_default" +input_t1 = Tuple[torch.Tensor, torch.Tensor] # Input x + + +class MatMul(torch.nn.Module): + test_data_generators = { + "rand_rand_3d": lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)), + "rand_rand_4d": lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)), + } + + def forward(self, x: torch.Tensor, y: torch.Tensor): + return torch.matmul(x, y) + + +class MatMulSingleInput(torch.nn.Module): + test_data_generators = { + "rand_3d": lambda: (torch.rand(2, 5, 5),), + "rand_4d": lambda: (torch.rand(1, 2, 5, 5),), + } + + def forward(self, x: torch.Tensor): + return torch.matmul(x, x) + + +class MatMulCombo(torch.nn.Module): + test_data_generators = { + "rand_rand_rand_3d": lambda: ( + torch.rand(2, 5, 5), + torch.rand(2, 5, 2), + torch.rand(2, 2, 5), + ), + "rand_rand_rand_4d": lambda: ( + torch.rand(1, 2, 5, 5), + torch.rand(1, 2, 5, 2), + torch.rand(1, 2, 2, 5), + ), + } + + def forward(self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor): + y1 = torch.matmul(x1, x1) + y2 = torch.matmul(x2, x3) + return y1 + y2 + + +@common.parametrize("test_data", MatMul.test_data_generators) +def test_matmul_tosa_MI(test_data: input_t1): + pipeline = TosaPipelineMI[input_t1](MatMul(), test_data(), 
aten_op_mm, exir_op_mm) + pipeline.run() + + +@common.parametrize("test_data", MatMulSingleInput.test_data_generators) +def test_matmul_single_input_tosa_MI(test_data: input_t1): + pipeline = TosaPipelineMI[input_t1]( + MatMulSingleInput(), test_data(), aten_op_mm, exir_op_mm + ) + pipeline.run() + + +@common.parametrize("test_data", MatMulCombo.test_data_generators) +def test_matmul_combo_tosa_MI(test_data: input_t1): + pipeline = TosaPipelineMI[input_t1]( + MatMulCombo(), test_data(), aten_op_mm, exir_op_mm + ) + pipeline.run() + + +@common.parametrize("test_data", MatMul.test_data_generators) +def test_matmul_tosa_BI(test_data: input_t1): + pipeline = TosaPipelineBI[input_t1]( + MatMul(), test_data(), aten_op_mm, exir_op_mm, qtol=1 + ) + pipeline.run() + + +@common.parametrize("test_data", MatMulSingleInput.test_data_generators) +def test_matmul_single_input_tosa_BI(test_data: input_t1): + pipeline = TosaPipelineBI[input_t1]( + MatMulSingleInput(), + test_data(), + aten_op_mm, + exir_op_mm, + qtol=1, + ) + pipeline.run() + + +@common.parametrize("test_data", MatMulCombo.test_data_generators) +def test_matmul_combo_tosa_BI(test_data: input_t1): + pipeline = TosaPipelineBI[input_t1]( + MatMulCombo(), + test_data(), + aten_op_mm, + exir_op_mm, + qtol=1, + ) + pipeline.run() + + +@common.parametrize("test_data", MatMul.test_data_generators) +@common.XfailIfNoCorstone300 +def test_matmul_u55_BI(test_data: input_t1): + pipeline = EthosU55PipelineBI[input_t1]( + MatMul(), + test_data(), + aten_op_mm, + exir_op_mm, + run_on_fvp=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_data", MatMulSingleInput.test_data_generators) +@common.XfailIfNoCorstone300 +def test_matmul_single_input_u55_BI(test_data: input_t1): + pipeline = EthosU55PipelineBI[input_t1]( + MatMulSingleInput(), + test_data(), + aten_op_mm, + exir_op_mm, + run_on_fvp=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + 
+@common.parametrize("test_data", MatMulCombo.test_data_generators) +@common.XfailIfNoCorstone300 +def test_matmul_combo_u55_BI(test_data: input_t1): + pipeline = EthosU55PipelineBI[input_t1]( + MatMulCombo(), + test_data(), + aten_op_mm, + exir_op_mm, + run_on_fvp=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_data", MatMul.test_data_generators) +@common.XfailIfNoCorstone320 +def test_matmul_u85_BI(test_data: input_t1): + pipeline = EthosU85PipelineBI[input_t1]( + MatMul(), + test_data(), + aten_op_mm, + exir_op_mm, + run_on_fvp=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_data", MatMulSingleInput.test_data_generators) +@common.XfailIfNoCorstone320 +def test_matmul_single_input_u85_BI(test_data: input_t1): + pipeline = EthosU85PipelineBI[input_t1]( + MatMulSingleInput(), + test_data(), + aten_op_mm, + exir_op_mm, + run_on_fvp=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_data", MatMulCombo.test_data_generators) +@common.XfailIfNoCorstone320 +def test_matmul_combo_u85_BI(test_data: input_t1): + pipeline = EthosU85PipelineBI[input_t1]( + MatMulCombo(), + test_data(), + aten_op_mm, + exir_op_mm, + run_on_fvp=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run()