Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions backends/arm/_passes/annotate_decomposed_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
if quantized_input:
matmul_args = matmul_node.all_input_nodes
for node in matmul_args:
# Find the dq-node connected to this mm/bmm arg
input_node = self._match_partition_to_node(
node, partition.input_nodes
)

# Remove partition input dq-node
input_node.replace_all_uses_with(input_node.all_input_nodes[0])
graph_module.graph.erase_node(input_node)
input_node_qargs = QuantArgs.from_operator(
input_node.target, input_node.args
)

# Insert new dq-node just before the mm/bmm with input_node's qparams
with graph_module.graph.inserting_before(matmul_node):
# Create new dq-node before matmul
dq_node = create_node(
Expand All @@ -90,6 +87,13 @@ def call(self, graph_module: GraphModule) -> PassResult:
dq_node.args = (node, *input_node_qargs)
matmul_node.replace_input_with(node, dq_node)

for partition_input in partition.input_nodes:
# Remove partition input dq-node
partition_input.replace_all_uses_with(
partition_input.all_input_nodes[0]
)
graph_module.graph.erase_node(partition_input)

partition_output = list(partition.output_nodes[0].users)[0]
quantized_output = partition_output.target == q_op
if quantized_output:
Expand Down
19 changes: 0 additions & 19 deletions backends/arm/test/ops/test_bmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,6 @@ class BMM(torch.nn.Module):
def forward(self, x, y):
return torch.bmm(x, y)

class MatMul(torch.nn.Module):
test_data_generators = [
lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
]

def forward(self, x, y):
return torch.matmul(x, y)

class BMMSingleInput(torch.nn.Module):
test_data_generators = [
lambda: (torch.rand(20, 3, 3),),
Expand Down Expand Up @@ -129,16 +120,6 @@ def test_bmm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple]
test_data = test_data_generator()
self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data)

@parameterized.expand(MatMul.test_data_generators)
def test_matmul_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
test_data = test_data_generator()
self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data)

@parameterized.expand(MatMul.test_data_generators)
def test_matmul_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
test_data = test_data_generator()
self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data)

@parameterized.expand(BMM.test_data_generators)
def test_bmm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
test_data = test_data_generator()
Expand Down
197 changes: 197 additions & 0 deletions backends/arm/test/ops/test_matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineBI,
EthosU85PipelineBI,
TosaPipelineBI,
TosaPipelineMI,
)

# Operator names handed to every pipeline below: the ATen-level name and the
# edge-dialect (EXIR) name of matmul.
aten_op_mm = "torch.ops.aten.matmul.default"
exir_op_mm = "executorch_exir_dialects_edge__ops_aten_matmul_default"
# Tuple of input tensors passed to a module's forward(). Variadic because the
# modules in this file take one, two, or three tensors.
input_t1 = Tuple[torch.Tensor, ...]


class MatMul(torch.nn.Module):
    """Module that multiplies two distinct input tensors with ``torch.matmul``."""

    # Named zero-argument factories. Each call builds a fresh (x, y) pair of
    # random tensors: one pair of rank 3 and one pair of rank 4.
    test_data_generators = {
        "rand_rand_3d": lambda: (
            torch.rand(2, 3, 5),
            torch.rand(2, 5, 2),
        ),
        "rand_rand_4d": lambda: (
            torch.rand(1, 2, 3, 5),
            torch.rand(1, 2, 5, 2),
        ),
    }

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Return the matrix product of ``x`` and ``y``."""
        product = torch.matmul(x, y)
        return product


class MatMulSingleInput(torch.nn.Module):
    """Module that multiplies one input tensor with itself via ``torch.matmul``."""

    # Named zero-argument factories. Each returns a one-element tuple holding
    # a random tensor whose last two dimensions are square (5 x 5).
    test_data_generators = {
        "rand_3d": lambda: (
            torch.rand(2, 5, 5),
        ),
        "rand_4d": lambda: (
            torch.rand(1, 2, 5, 5),
        ),
    }

    def forward(self, x: torch.Tensor):
        """Return the matrix product of ``x`` with itself."""
        squared = torch.matmul(x, x)
        return squared


class MatMulCombo(torch.nn.Module):
    """Module that adds a self-matmul of one input to a matmul of two others."""

    # Named zero-argument factories. Each returns an (x1, x2, x3) triple of
    # random tensors shaped so that x1 @ x1 and x2 @ x3 have the same shape.
    test_data_generators = {
        "rand_rand_rand_3d": lambda: (
            torch.rand(2, 5, 5),
            torch.rand(2, 5, 2),
            torch.rand(2, 2, 5),
        ),
        "rand_rand_rand_4d": lambda: (
            torch.rand(1, 2, 5, 5),
            torch.rand(1, 2, 5, 2),
            torch.rand(1, 2, 2, 5),
        ),
    }

    def forward(self, x1: torch.Tensor, x2: torch.Tensor, x3: torch.Tensor):
        """Return ``matmul(x1, x1) + matmul(x2, x3)``."""
        self_product = torch.matmul(x1, x1)
        cross_product = torch.matmul(x2, x3)
        return self_product + cross_product


@common.parametrize("test_data", MatMul.test_data_generators)
def test_matmul_tosa_MI(test_data: input_t1):
    """Run ``MatMul`` through ``TosaPipelineMI`` for each generated input.

    ``test_data`` is one of the zero-argument generators from
    ``MatMul.test_data_generators``; it is called here to build the inputs.
    """
    module = MatMul()
    inputs = test_data()
    TosaPipelineMI[input_t1](module, inputs, aten_op_mm, exir_op_mm).run()


@common.parametrize("test_data", MatMulSingleInput.test_data_generators)
def test_matmul_single_input_tosa_MI(test_data: input_t1):
    """Run ``MatMulSingleInput`` through ``TosaPipelineMI`` for each input.

    ``test_data`` is one of the zero-argument generators from
    ``MatMulSingleInput.test_data_generators``; it is called here to build
    the inputs.
    """
    module = MatMulSingleInput()
    inputs = test_data()
    TosaPipelineMI[input_t1](module, inputs, aten_op_mm, exir_op_mm).run()


@common.parametrize("test_data", MatMulCombo.test_data_generators)
def test_matmul_combo_tosa_MI(test_data: input_t1):
    """Run ``MatMulCombo`` through ``TosaPipelineMI`` for each generated input.

    ``test_data`` is one of the zero-argument generators from
    ``MatMulCombo.test_data_generators``; it is called here to build the
    inputs.
    """
    module = MatMulCombo()
    inputs = test_data()
    TosaPipelineMI[input_t1](module, inputs, aten_op_mm, exir_op_mm).run()


@common.parametrize("test_data", MatMul.test_data_generators)
def test_matmul_tosa_BI(test_data: input_t1):
    """Run ``MatMul`` through ``TosaPipelineBI`` with ``qtol=1``.

    ``test_data`` is one of the zero-argument generators from
    ``MatMul.test_data_generators``; it is called here to build the inputs.
    """
    module = MatMul()
    inputs = test_data()
    TosaPipelineBI[input_t1](
        module,
        inputs,
        aten_op_mm,
        exir_op_mm,
        qtol=1,
    ).run()


@common.parametrize("test_data", MatMulSingleInput.test_data_generators)
def test_matmul_single_input_tosa_BI(test_data: input_t1):
    """Run ``MatMulSingleInput`` through ``TosaPipelineBI`` with ``qtol=1``.

    ``test_data`` is one of the zero-argument generators from
    ``MatMulSingleInput.test_data_generators``; it is called here to build
    the inputs.
    """
    # This is the BI variant, so it must build the BI pipeline like the other
    # ``*_tosa_BI`` tests in this file; it previously built TosaPipelineMI.
    pipeline = TosaPipelineBI[input_t1](
        MatMulSingleInput(),
        test_data(),
        aten_op_mm,
        exir_op_mm,
        qtol=1,
    )
    pipeline.run()


@common.parametrize("test_data", MatMulCombo.test_data_generators)
def test_matmul_combo_tosa_BI(test_data: input_t1):
    """Run ``MatMulCombo`` through ``TosaPipelineBI`` with ``qtol=1``.

    ``test_data`` is one of the zero-argument generators from
    ``MatMulCombo.test_data_generators``; it is called here to build the
    inputs.
    """
    module = MatMulCombo()
    inputs = test_data()
    TosaPipelineBI[input_t1](
        module, inputs, aten_op_mm, exir_op_mm, qtol=1
    ).run()


@common.parametrize("test_data", MatMul.test_data_generators)
@common.XfailIfNoCorstone300
def test_matmul_u55_BI(test_data: input_t1):
    """Run ``MatMul`` through ``EthosU55PipelineBI`` with ``run_on_fvp=True``.

    ``test_data`` is one of the zero-argument generators from
    ``MatMul.test_data_generators``; it is called here to build the inputs.
    """
    module = MatMul()
    inputs = test_data()
    EthosU55PipelineBI[input_t1](
        module,
        inputs,
        aten_op_mm,
        exir_op_mm,
        run_on_fvp=True,
        use_to_edge_transform_and_lower=True,
    ).run()


@common.parametrize("test_data", MatMulSingleInput.test_data_generators)
@common.XfailIfNoCorstone300
def test_matmul_single_input_u55_BI(test_data: input_t1):
    """Run ``MatMulSingleInput`` through ``EthosU55PipelineBI`` on the FVP.

    ``test_data`` is one of the zero-argument generators from
    ``MatMulSingleInput.test_data_generators``; it is called here to build
    the inputs. The pipeline is created with ``run_on_fvp=True``.
    """
    module = MatMulSingleInput()
    inputs = test_data()
    EthosU55PipelineBI[input_t1](
        module,
        inputs,
        aten_op_mm,
        exir_op_mm,
        run_on_fvp=True,
        use_to_edge_transform_and_lower=True,
    ).run()


@common.parametrize("test_data", MatMulCombo.test_data_generators)
@common.XfailIfNoCorstone300
def test_matmul_combo_u55_BI(test_data: input_t1):
    """Run ``MatMulCombo`` through ``EthosU55PipelineBI`` on the FVP.

    ``test_data`` is one of the zero-argument generators from
    ``MatMulCombo.test_data_generators``; it is called here to build the
    inputs. The pipeline is created with ``run_on_fvp=True``.
    """
    module = MatMulCombo()
    inputs = test_data()
    EthosU55PipelineBI[input_t1](
        module,
        inputs,
        aten_op_mm,
        exir_op_mm,
        run_on_fvp=True,
        use_to_edge_transform_and_lower=True,
    ).run()


@common.parametrize("test_data", MatMul.test_data_generators)
@common.XfailIfNoCorstone320
def test_matmul_u85_BI(test_data: input_t1):
    """Run ``MatMul`` through ``EthosU85PipelineBI`` with ``run_on_fvp=True``.

    ``test_data`` is one of the zero-argument generators from
    ``MatMul.test_data_generators``; it is called here to build the inputs.
    """
    module = MatMul()
    inputs = test_data()
    EthosU85PipelineBI[input_t1](
        module,
        inputs,
        aten_op_mm,
        exir_op_mm,
        run_on_fvp=True,
        use_to_edge_transform_and_lower=True,
    ).run()


@common.parametrize("test_data", MatMulSingleInput.test_data_generators)
@common.XfailIfNoCorstone320
def test_matmul_single_input_u85_BI(test_data: input_t1):
    """Run ``MatMulSingleInput`` through ``EthosU85PipelineBI`` on the FVP.

    ``test_data`` is one of the zero-argument generators from
    ``MatMulSingleInput.test_data_generators``; it is called here to build
    the inputs. The pipeline is created with ``run_on_fvp=True``.
    """
    module = MatMulSingleInput()
    inputs = test_data()
    EthosU85PipelineBI[input_t1](
        module,
        inputs,
        aten_op_mm,
        exir_op_mm,
        run_on_fvp=True,
        use_to_edge_transform_and_lower=True,
    ).run()


@common.parametrize("test_data", MatMulCombo.test_data_generators)
@common.XfailIfNoCorstone320
def test_matmul_combo_u85_BI(test_data: input_t1):
    """Run ``MatMulCombo`` through ``EthosU85PipelineBI`` on the FVP.

    ``test_data`` is one of the zero-argument generators from
    ``MatMulCombo.test_data_generators``; it is called here to build the
    inputs. The pipeline is created with ``run_on_fvp=True``.
    """
    module = MatMulCombo()
    inputs = test_data()
    EthosU85PipelineBI[input_t1](
        module,
        inputs,
        aten_op_mm,
        exir_op_mm,
        run_on_fvp=True,
        use_to_edge_transform_and_lower=True,
    ).run()
Loading