66 changes: 8 additions & 58 deletions backends/arm/process_node.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -11,11 +11,6 @@
 import serializer.tosa_serializer as ts
 import torch
 import torch.fx
-
-# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
-from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
-    get_input_qparams,
-)
 from executorch.backends.arm.operators.node_visitor import NodeVisitor
 from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
 from executorch.backends.arm.tosa_quant_utils import (
@@ -24,11 +19,7 @@
     is_node_quantized,
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
-from executorch.backends.arm.tosa_utils import (
-    getNodeArgs,
-    is_bias_node_for_quantized_conv,
-    tosa_shape,
-)
+from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
 from torch.export.exported_program import ExportedProgram


@@ -99,41 +90,6 @@ def process_inputs(
         tosa_graph.addInputTensor(tensor)


-def process_quantized_bias(
-    node: torch.fx.Node,
-    tosa_graph: ts.TosaSerializer,
-    parameter_values,
-):
-    """
-    Serialize bias node that needs to be quantized.
-    """
-    consumer_node = list(node.users)[0]
-    (
-        input_node,
-        weight_node,
-        _,
-    ) = consumer_node.all_input_nodes
-
-    input_qargs = get_input_qparams(  # pyre-ignore[16]: Module `executorch.backends.arm` has no attribute `_passes`.
-        consumer_node
-    )
-
-    input_node_scale = input_qargs[0].scale
-    weight_node_scale = input_qargs[1].scale
-    bias_values_quantized = (
-        (parameter_values / (input_node_scale * weight_node_scale))
-        .round()
-        .astype(np.int32)
-    )
-
-    tosa_graph.addConst(
-        bias_values_quantized.shape,
-        ts.DType.INT32,
-        bias_values_quantized,
-        name=node.name,
-    )
-
-
 def process_inputs_to_parameters(
     node: torch.fx.Node,
     tosa_graph: ts.TosaSerializer,
@@ -148,20 +104,14 @@ def process_inputs_to_parameters(
     assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
     parameter_values = parameter_data.detach().numpy()

-    if is_bias_node_for_quantized_conv(node):
-        # BI bias
-        assert tosa_spec.support_integer(), f"{tosa_spec} doesnt't support integer"
-        process_quantized_bias(node, tosa_graph, parameter_values)
-    else:
-        # MI weights or bias
-        if inputs[0].dtype == torch.float32:
-            assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"
+    if inputs[0].dtype == torch.float32:
+        assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"

-        parameter_values = np.transpose(parameter_values, inputs[0].dim_order)
+    parameter_values = np.transpose(parameter_values, inputs[0].dim_order)

-        tosa_graph.addConst(
-            parameter_values.shape, inputs[0].dtype, parameter_values, name=node.name
-        )
+    tosa_graph.addConst(
+        parameter_values.shape, inputs[0].dtype, parameter_values, name=node.name
+    )


 def process_inputs_to_buffers(
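
For context, the deleted process_quantized_bias helper above quantized a convolution bias at TOSA-serialization time, hard-coding the bias scale to input_scale * weight_scale with a zero point of 0. A minimal NumPy sketch of that arithmetic, using made-up example scales:

import numpy as np

# Hypothetical qparams; in the real flow these came from the consumer
# node's input-activation and weight quantization parameters.
act_scale = 0.02
weight_scale = 0.005
bias_fp32 = np.array([0.13, -0.07, 0.21], dtype=np.float32)

# int32 bias with scale = act_scale * weight_scale and zero point 0,
# matching the deleted helper's round-and-cast:
bias_int32 = np.round(bias_fp32 / (act_scale * weight_scale)).astype(np.int32)
print(bias_int32)  # -> [ 1300  -700  2100]

The rest of the PR moves this derivation into the quantizer (see quantization_config.py below), so the serializer no longer needs a special-case bias path.
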
4 changes: 2 additions & 2 deletions backends/arm/quantizer/quantization_annotator.py
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -196,7 +196,7 @@ def get_quant_properties(  # noqa: C901
     input_act_qspec = quantization_config.get_input_act_qspec()
     weight_qspec = quantization_config.get_weight_qspec()
     output_act_qspec = quantization_config.get_output_act_qspec()
-    bias_qspec = quantization_config.get_bias_qspec()
+    bias_qspec = quantization_config.get_bias_qspec(node)

     quant_properties = _OpQuantProperties()
40 changes: 38 additions & 2 deletions backends/arm/quantizer/quantization_config.py
@@ -1,5 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -9,8 +9,10 @@
 from dataclasses import dataclass

 import torch
+from torch.ao.quantization import ObserverOrFakeQuantize

 from torch.ao.quantization.quantizer import (
+    DerivedQuantizationSpec,
     FixedQParamsQuantizationSpec,
     QuantizationSpec,
 )
@@ -53,8 +55,42 @@ def get_weight_qspec(self) -> QuantizationSpec | None:
         ], f"Unsupported quantization_spec {self.weight} for weight"
         return self.weight

-    def get_bias_qspec(self) -> QuantizationSpec | None:
+    def get_bias_qspec(self, node: torch.fx.Node) -> QuantizationSpec | None:
         """Returns QuantizationSpec 'bias' after asserting that bias.dtype is torch.float."""
+
+        def _derive_qparams_fn(
+            obs_or_fqs: list[ObserverOrFakeQuantize],
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            assert (
+                len(obs_or_fqs) == 2
+            ), "Expecting two obs/fqs, one for activation and one for weight, got: {}".format(
+                len(obs_or_fqs)
+            )
+            act_obs_or_fq = obs_or_fqs[0]
+            weight_obs_or_fq = obs_or_fqs[1]
+            act_scale, act_zp = act_obs_or_fq.calculate_qparams()
+            weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams()
+            return torch.tensor([act_scale * weight_scale]).to(
+                torch.float32
+            ), torch.tensor([0]).to(torch.int32)
+
+        if node.target in [
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.linear.default,
+        ]:
+            input_act = node.args[0]
+            weight = node.args[1]
+            quantization_spec = DerivedQuantizationSpec(
+                derived_from=[(input_act, node), (weight, node)],
+                derive_qparams_fn=_derive_qparams_fn,
+                dtype=torch.int32,
+                quant_min=torch.iinfo(torch.int32).min,
+                quant_max=torch.iinfo(torch.int32).max - 1,
+                qscheme=torch.per_tensor_symmetric,
+            )
+            return quantization_spec
+
         if self.bias is None:
             return None
         assert (
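
With this change, bias quantization parameters for conv1d/conv2d/linear are derived during annotation from the observed activation and weight qparams rather than hard-coded in the TOSA serializer. A small sketch of what _derive_qparams_fn evaluates to, using stand-in observer scales (made-up values, not from the PR):

import torch

# Stand-in scales; in the real flow PyTorch passes the activation and
# weight observers and the function reads their calculate_qparams().
act_scale = 0.02
weight_scale = 0.005

bias_scale = torch.tensor([act_scale * weight_scale]).to(torch.float32)  # tensor([1.0000e-04])
bias_zero_point = torch.tensor([0]).to(torch.int32)

# With these qparams an fp32 bias quantizes to int32 exactly as the
# deleted serializer path did:
bias_fp32 = torch.tensor([0.13, -0.07, 0.21])
bias_int32 = torch.round(bias_fp32 / bias_scale).to(torch.int32)  # tensor([1300, -700, 2100], dtype=torch.int32)
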
4 changes: 2 additions & 2 deletions backends/arm/test/misc/test_debug_feats.py
@@ -197,10 +197,10 @@ def test_collate_tosa_BI_tests(self):
             "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests"
         )
         assert os.path.exists(
-            "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag5.tosa"
+            "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/output_tag6.tosa"
         )
         assert os.path.exists(
-            "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag5.json"
+            "test_collate_tosa_tests/tosa-bi/TestCollateTosaTests/test_collate_tosa_BI_tests/desc_tag6.json"
         )

         os.environ.pop("TOSA_TESTCASES_BASE_PATH")
15 changes: 1 addition & 14 deletions backends/arm/tosa_utils.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -133,19 +133,6 @@ def build_reshape(tosa_fb, input_name, new_shape, output_name):
     tosa_fb.addOperator(TosaOp.Op().RESHAPE, [input_name], [output_name], attr)


-def is_bias_node_for_quantized_conv(node):
-    consumer_node = list(node.users)[0]
-
-    if (
-        consumer_node.target == exir_ops.edge.aten.convolution.default
-        and consumer_node.args[2] == node
-        and consumer_node.meta["val"].dtype == torch.int8
-    ):
-        return True
-
-    return False
-
-
 def is_consumer_node_depthwise_conv2d(node):
     consumer_node = list(node.users)[0]
     if consumer_node.target == exir_ops.edge.aten.convolution.default: