Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@ def define_node(
[ts.DType.INT8, ts.DType.INT32],
output.tosa_spec,
)

scale_back = 1.0
if inputs[0].dtype == ts.DType.INT8:
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
tosa_graph, inputs, node, self.tosa_spec
)
else:
Expand Down Expand Up @@ -85,7 +84,12 @@ def define_node(
# Scale output back to 8 bit
# pyre-ignore
tqutils.insert_rescale_op_to_int8(
tosa_graph, add_output, scale_back, node, self.tosa_spec
tosa_graph,
add_output,
scale_back,
node,
compute_rescale=False,
tosa_spec=self.tosa_spec,
) # type: ignore[possibly-undefined]


Expand Down
9 changes: 7 additions & 2 deletions backends/arm/operators/op_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def define_node(

scale_back = 1.0
if inputs[0].dtype == ts.DType.INT8:
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
tosa_graph, inputs, node, self.tosa_spec
)
else:
Expand Down Expand Up @@ -86,7 +86,12 @@ def define_node(
# Scale output back to 8 bit
# pyre-ignore
tqutils.insert_rescale_op_to_int8(
tosa_graph, sub_output, scale_back, node, self.tosa_spec
tosa_graph,
sub_output,
scale_back,
node,
compute_rescale=False,
tosa_spec=self.tosa_spec,
) # type: ignore[possibly-undefined]


Expand Down
8 changes: 4 additions & 4 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,10 @@ def any_or_hardtanh_min_zero(n: Node):
]
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
elif node.target in (
torch.ops.aten.add.Tensor,
torch.ops.aten.add_.Tensor,
torch.ops.aten.sub.Tensor,
torch.ops.aten.sub_.Tensor,
torch.ops.aten.matmul.default,
torch.ops.aten.mm.default,
torch.ops.aten.bmm.default,
Expand All @@ -485,10 +489,6 @@ def any_or_hardtanh_min_zero(n: Node):
]
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
elif node.target in (
torch.ops.aten.add.Tensor,
torch.ops.aten.add_.Tensor,
torch.ops.aten.sub.Tensor,
torch.ops.aten.sub_.Tensor,
torch.ops.aten.minimum.default,
torch.ops.aten.maximum.default,
):
Expand Down
110 changes: 110 additions & 0 deletions backends/arm/test/misc/test_conv_relu_residual_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import pytest

import torch
import torch.nn as nn
from executorch.backends.arm.test import common

from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
EthosU85PipelineINT,
TosaPipelineFP,
TosaPipelineINT,
)


# Model with a Conv1D-ReLU sequence and a residual add.
# Tests the annotation of Conv1D-ReLU (to be fused) and the annotation of add.
# ReLU outputs only non-negative numbers, while linear outputs both positive and negative numbers, so they
# should have different quantisation parameters. If the ReLU gets the wrong quantisation parameters (e.g. qmin != zp)
# because of an observer shared with a following operator (e.g. add), the Conv1D-ReLU sequence is not fused
# and is left in FP32. As a result, the test fails.
class AddDifferentRanges(torch.nn.Module):
    """Conv1d -> ReLU followed by a residual add with a Linear branch.

    The input is expected in (N, T, C) layout. It is permuted to (N, C, T)
    for the convolution and permuted back before the residual add, so the
    add combines the ReLU output with the Linear output computed from it.

    ``input_dim`` is accepted by the constructor but is not used.
    """

    def __init__(self, in_channels, out_channels, kernel_size, input_dim):
        super().__init__()
        # Registration order (conv1, relu, linear) is kept so that parameter
        # initialisation consumes the RNG in the same sequence.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size)
        self.relu = torch.nn.ReLU()
        self.linear = nn.Linear(out_channels, out_channels)

    def forward(self, x):
        # (N, T, C) -> (N, C, T) for Conv1d.
        channels_first = x.permute(0, 2, 1)
        activated = self.relu(self.conv1(channels_first))
        # Back to (N, T, C) so Linear acts on the channel dimension.
        residual = activated.permute(0, 2, 1)
        return residual + self.linear(residual)


# Type used to parametrize the pipeline classes: a one-tensor input tuple.
input_t = Tuple[torch.Tensor]

# Module-level fixtures shared by every test in this file. The model is built
# before the example input; both draw from the RNG, so the order is kept.
model = AddDifferentRanges(
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    input_dim=10,
)
model_inputs = (torch.randn(1, 10, 3),)

# Test-case name -> value passed as ``per_channel_quantization``.
quant_test_data = {
    "per_channel_quantization=true": True,
    "per_channel_quantization=false": False,
}


def test_tosa_FP():
    """Build a ``TosaPipelineFP`` for the module-level model and run it."""
    TosaPipelineFP[input_t](
        model,
        model_inputs,
        aten_op=[],
        exir_op=[],
        use_to_edge_transform_and_lower=True,
    ).run()


@common.parametrize("per_channel_quantization", quant_test_data)
def test_tosa_INT(per_channel_quantization):
    """Run the model through ``TosaPipelineINT`` with ``qtol=0``.

    Parametrized over ``quant_test_data``, i.e. once with
    ``per_channel_quantization=True`` and once with ``False``.
    """
    TosaPipelineINT[input_t](
        model,
        model_inputs,
        aten_op=[],
        exir_op=[],
        use_to_edge_transform_and_lower=True,
        per_channel_quantization=per_channel_quantization,
        qtol=0,
    ).run()


@pytest.mark.slow
@common.XfailIfNoCorstone300
@common.parametrize("per_channel_quantization", quant_test_data)
def test_tosa_u55_INT(per_channel_quantization):
    """Run the model through ``EthosU55PipelineINT`` with ``run_on_fvp=True``.

    Parametrized over ``quant_test_data`` and marked slow.
    NOTE(review): ``XfailIfNoCorstone300`` presumably xfails the test when the
    Corstone-300 FVP is not available -- confirm against ``common``.
    """
    EthosU55PipelineINT[input_t](
        model,
        model_inputs,
        [],
        [],
        run_on_fvp=True,
        use_to_edge_transform_and_lower=True,
        per_channel_quantization=per_channel_quantization,
        qtol=0,
    ).run()


@pytest.mark.slow
@common.XfailIfNoCorstone320
@common.parametrize("per_channel_quantization", quant_test_data)
def test_tosa_u85_INT(per_channel_quantization):
    """Run the model through ``EthosU85PipelineINT`` with ``run_on_fvp=True``.

    Parametrized over ``quant_test_data`` and marked slow.
    NOTE(review): ``XfailIfNoCorstone320`` presumably xfails the test when the
    Corstone-320 FVP is not available -- confirm against ``common``.
    """
    EthosU85PipelineINT[input_t](
        model,
        model_inputs,
        [],
        [],
        run_on_fvp=True,
        use_to_edge_transform_and_lower=True,
        per_channel_quantization=per_channel_quantization,
        qtol=0,
    ).run()
2 changes: 1 addition & 1 deletion backends/arm/test/models/test_inception_v3_arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_ic3_tosa_BI():
aten_op=[],
exir_op=[],
use_to_edge_transform_and_lower=True,
atol=0.6,
atol=0.65,
qtol=1,
)
pipeline.run()
Expand Down
99 changes: 99 additions & 0 deletions backends/arm/test/models/test_resnet18.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import pytest

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
EthosU55PipelineINT,
EthosU85PipelineINT,
TosaPipelineFP,
TosaPipelineINT,
)

from torchvision import transforms # type: ignore[import-untyped]
from torchvision.models import resnet18, ResNet18_Weights

# Select the pretrained weights with an explicit enum member. torchvision
# documents ``weights`` as a ``ResNet18_Weights`` member (or ``None``); passing
# the enum class itself is not one of the documented values.
# NOTE(review): IMAGENET1K_V1 is believed to be what the class argument
# resolved to via torchvision's deprecated-argument handling -- confirm.
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model = model.eval()

# Channel-wise normalization applied to the random example input.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

model_inputs = (normalize(torch.randn((1, 3, 224, 224))),)

# Type used to parametrize the pipeline classes: a one-tensor input tuple.
input_t = Tuple[torch.Tensor]


# Test-case name -> value passed as ``per_channel_quantization``.
quant_test_data = {
    "per_channel_quantization=true": True,
    "per_channel_quantization=false": False,
}


def test_resnet_tosa_FP():
    """Build a ``TosaPipelineFP`` for the module-level ResNet-18 and run it."""
    TosaPipelineFP[input_t](
        model,
        model_inputs,
        aten_op=[],
        exir_op=[],
        use_to_edge_transform_and_lower=True,
    ).run()


@common.parametrize("per_channel_quantization", quant_test_data)
def test_resnet_tosa_INT(per_channel_quantization):
    """Run ResNet-18 through ``TosaPipelineINT`` with ``atol=0.5, qtol=1``.

    Parametrized over ``quant_test_data``, i.e. once with
    ``per_channel_quantization=True`` and once with ``False``.
    """
    TosaPipelineINT[input_t](
        model,
        model_inputs,
        aten_op=[],
        exir_op=[],
        use_to_edge_transform_and_lower=True,
        per_channel_quantization=per_channel_quantization,
        atol=0.5,
        qtol=1,
    ).run()


@pytest.mark.slow
@common.XfailIfNoCorstone300
@common.parametrize("per_channel_quantization", quant_test_data)
def test_resnet_u55_INT(per_channel_quantization):
    """Run ResNet-18 through ``EthosU55PipelineINT`` with ``run_on_fvp=True``.

    Parametrized over ``quant_test_data`` and marked slow.
    NOTE(review): ``XfailIfNoCorstone300`` presumably xfails the test when the
    Corstone-300 FVP is not available -- confirm against ``common``.
    """
    EthosU55PipelineINT[input_t](
        model,
        model_inputs,
        aten_ops=[],
        exir_ops=[],
        run_on_fvp=True,
        use_to_edge_transform_and_lower=True,
        per_channel_quantization=per_channel_quantization,
        atol=0.5,
        qtol=1,
    ).run()


@pytest.mark.slow
@pytest.mark.xfail(
    reason="For resnet18 for Ethos-U85, the SRAM memory footprint is very high. The compiler team is investigating."
)
@common.XfailIfNoCorstone320
@common.parametrize("per_channel_quantization", quant_test_data)
def test_resnet_u85_INT(per_channel_quantization):
    """Run ResNet-18 through ``EthosU85PipelineINT`` with ``run_on_fvp=True``.

    Parametrized over ``quant_test_data`` and marked slow. The test is
    currently marked as an expected failure; see the ``xfail`` reason.
    NOTE(review): ``XfailIfNoCorstone320`` presumably xfails the test when the
    Corstone-320 FVP is not available -- confirm against ``common``.
    """
    EthosU85PipelineINT[input_t](
        model,
        model_inputs,
        aten_ops=[],
        exir_ops=[],
        run_on_fvp=True,
        use_to_edge_transform_and_lower=True,
        per_channel_quantization=per_channel_quantization,
        atol=0.5,
        qtol=1,
    ).run()
21 changes: 16 additions & 5 deletions backends/arm/test/ops/test_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
"4d_randn_1": lambda: (torch.randn(1, 1, 4, 4), torch.ones(1, 1, 4, 1)),
"4d_randn_2": lambda: (torch.randn(1, 3, 4, 4), torch.randn(1, 3, 4, 4)),
"4d_randn_big": lambda: (
10000 * torch.randn(1, 1, 4, 4),
(1 << 30) * torch.randn(1, 1, 4, 4),
torch.randn(1, 1, 4, 1),
),
"4d_randn_1_mutltiple_broadcasts": lambda: (
torch.randn(1, 4, 4, 1),
torch.ones(1, 1, 4, 4),
),
"4d_big_small": lambda: (
(10e10) * torch.randn(1, 10, 20, 30),
torch.randn(1, 10, 20, 30),
),
}


Expand All @@ -86,7 +90,7 @@ def test_add_tensor_tosa_FP(test_data: input_t1):

@common.parametrize("test_data", Add.test_data)
def test_add_tensor_tosa_INT(test_data: input_t1):
pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op)
pipeline = TosaPipelineINT[input_t1](Add(), test_data(), aten_op, exir_op, qtol=0)
pipeline.run()


Expand All @@ -111,9 +115,16 @@ def test_add_tensor_tosa_INT_i32(test_data: input_t1):
quant_max=2**31 - 1,
quant_min=-(2**31),
)
output_act_qspec = QuantizationSpec(
torch.int32,
observer,
qscheme=torch.per_tensor_symmetric,
quant_max=2**31 - 1,
quant_min=-(2**31),
)
# This quantization_config will be set as global config.
quantization_config = arm_quantizer.QuantizationConfig(
input_act_qspec, None, None, None
input_act_qspec, output_act_qspec, None, None
)
quantize_stage = Quantize(quantizer, quantization_config)
pipeline.change_args("quantize", quantize_stage)
Expand Down Expand Up @@ -157,13 +168,13 @@ def test_add_tensor_tosa_FP_3(test_data: input_t2):

@common.parametrize("test_data", Add3.test_data)
def test_add_tensor_tosa_INT_3(test_data: input_t2):
pipeline = TosaPipelineINT[input_t2](Add3(), test_data(), aten_op, exir_op)
pipeline = TosaPipelineINT[input_t2](Add3(), test_data(), aten_op, exir_op, qtol=0)
pipeline.run()


@common.parametrize("test_data", Add2.test_data)
def test_add_tensor_tosa_INT_2(test_data: input_t2):
pipeline = TosaPipelineINT[input_t2](Add2(), test_data(), aten_op, exir_op)
pipeline = TosaPipelineINT[input_t2](Add2(), test_data(), aten_op, exir_op, qtol=0)
pipeline.run()


Expand Down
14 changes: 7 additions & 7 deletions backends/arm/test/ops/test_conv_combos.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,28 +47,28 @@ def __init__(self):
# (t, c, n, s) = (6, 96, 1, 1)
# 1. 1x1 CONV2d + ReLU6 (Pointwise)
self.pointwise_conv2d = torch.nn.Conv2d(
in_channels=32, out_channels=128, kernel_size=1, stride=1, groups=1
in_channels=16, out_channels=96, kernel_size=1, stride=1, groups=1
) ## (1, 96, 81, 81)
self.batch_norm2d_16 = torch.nn.BatchNorm2d(128, affine=False)
self.batch_norm2d_16 = torch.nn.BatchNorm2d(96, affine=False)
self.relu6 = torch.nn.ReLU6()

# 2. 3x3 DepthwiseConv2d + ReLu6
self.depthwise_conv2d = torch.nn.Conv2d(
in_channels=128,
out_channels=128,
in_channels=96,
out_channels=96,
kernel_size=3,
padding=1,
stride=1,
groups=128,
groups=96,
) ## (1, 96, H, W)

# 3. Linear 1x1 Conv2d
self.pointwise_conv2d_linear = torch.nn.Conv2d(
in_channels=128, out_channels=32, kernel_size=1, stride=1, groups=1
in_channels=96, out_channels=16, kernel_size=1, stride=1, groups=1
) ## (1, 16, 81, 81)

def get_inputs(self) -> Tuple[torch.Tensor]:
return (torch.randn(1, 32, 81, 81),)
return (torch.randn(1, 16, 81, 81),)

def forward(self, x):
input = x
Expand Down
3 changes: 3 additions & 0 deletions backends/arm/test/ops/test_group_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ def test_native_group_norm_tosa_INT(test_data):
"test_data",
test_data_suite,
xfails={
"rand_4_6_8_groups_2_eps_no_affine": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
"rand_4_6_groups_1": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
"rand_4_6_groups_2": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
"randn_1_12_8_6_groups_12": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
"rand_6_8_10_12_groups_1": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
"rand_6_8_10_12_groups_4_no_affine": "MLETORCH-925: Fix numerical issue for aten.native_group_norm",
Expand Down
Loading
Loading