From b4d664e54b2cb706a20a6d908382319ec97a7207 Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Thu, 2 Oct 2025 15:03:20 +0200 Subject: [PATCH] Arm backend: Fix Mypy error related to _QuantProperty.qspec Previously, _QuantProperty.qspec had the type hint `type[QuantizationSpecBase] | List[type[QuantizationSpecBase]]`, which implies that _QuantProperty.qspec should be a class object. However, in torchao the class `QuantizationAnnotation` has this property: `output_qspec: Optional[QuantizationSpecBase] = None` which is set to _QuantProperty.qspec through a series of function calls. `output_qspec` should, as the type hinting implies, be an instance of a class of `QuantizationSpecBase`, not a class object. Therefore, change `type[QuantizationSpecBase] | List[type[QuantizationSpecBase]]` to `QuantizationSpecBase | List[QuantizationSpecBase]`. This allows us to remove a bunch of mypy ignores. Change-Id: Idbb5ca4ba9ab17e8805b1e4d647e46e86f434b69 Signed-off-by: Sebastian Larsson --- .../arm/quantizer/quantization_annotator.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index ebc91c22bbb..61255e8d001 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -37,7 +37,7 @@ class _QuantProperty: """Specify how the input/output at 'index' must be quantized.""" index: int - qspec: type[QuantizationSpecBase] | List[type[QuantizationSpecBase]] + qspec: QuantizationSpecBase | List[QuantizationSpecBase] optional: bool = False mark_annotated: bool = False @@ -510,24 +510,24 @@ def any_or_hardtanh_min_zero(n: Node): quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), _QuantProperty( - 1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec # type: ignore[arg-type] + 1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec ), ] - quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type] + quant_properties.quant_output = _QuantProperty(0, shared_qspec) elif node.target in (torch.ops.aten.where.self,): shared_qspec = SharedQuantizationSpec(node.args[1]) # type: ignore[arg-type] quant_properties.quant_inputs = [ - _QuantProperty(1, shared_qspec), # type: ignore[arg-type] - _QuantProperty(2, shared_qspec), # type: ignore[arg-type] + _QuantProperty(1, shared_qspec), + _QuantProperty(2, shared_qspec), ] - quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type] + quant_properties.quant_output = _QuantProperty(0, shared_qspec) elif node.target in _one_to_one_shared_input_or_input_act_qspec: input_qspec = ( SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type] - if is_output_annotated(node.args[0]) # type: ignore + if is_output_annotated(node.args[0]) # type: ignore[arg-type] else input_act_qspec ) - quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)] # type: ignore[arg-type] + quant_properties.quant_inputs = [_QuantProperty(0, input_qspec)] quant_properties.quant_output = _QuantProperty( 0, SharedQuantizationSpec((node.args[0], node)) # type: ignore[arg-type] ) @@ -545,7 +545,7 @@ def any_or_hardtanh_min_zero(n: Node): if len(node.args[0]) == 0: raise ValueError("Expected non-empty list for node.args[0]") - shared_qspec = SharedQuantizationSpec((node.args[0][0], node)) + shared_qspec = SharedQuantizationSpec((node.args[0][0], node)) # type: ignore[arg-type] quant_properties.quant_inputs = [ _QuantProperty( 0, @@ -555,7 +555,7 @@ def any_or_hardtanh_min_zero(n: Node): ], ) ] - quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type] + quant_properties.quant_output = _QuantProperty(0, shared_qspec) elif node.target in _one_to_one: quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) @@ -575,7 +575,7 @@ def any_or_hardtanh_min_zero(n: Node): quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), _QuantProperty( - 1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec # type: ignore[arg-type] + 1, input_act_qspec if node.args[0] == node.args[1] else shared_qspec ), ] quant_properties.quant_output = None @@ -588,11 +588,11 @@ def any_or_hardtanh_min_zero(n: Node): quant_properties.quant_inputs = [] quant_properties.quant_output = _QuantProperty(0, output_act_qspec) elif node.target in [operator.getitem]: - if not is_output_annotated(node.args[0]): # type: ignore[attr-defined, arg-type] + if not is_output_annotated(node.args[0]): # type: ignore[arg-type] return None shared_qspec = SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type] - quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] # type: ignore[arg-type] - quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type] + quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] + quant_properties.quant_output = _QuantProperty(0, shared_qspec) else: return None