Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 51 additions & 52 deletions backends/arm/_passes/size_adjust_conv2d_pass.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,74 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from typing import cast, Optional
from typing import cast

import torch.fx
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch._ops import OpOverload


def conv_remainder(input_length, pad, dilation, weight, stride):
"""
Returns the size
Returns the remainder of input_length; given the padding, dilation, stride,
and kernel size.
"""
return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride


def insert_q_dq_pair(
graph: torch.fx.Graph,
anchor: torch.fx.Node,
q_params: tuple,
):
with graph.inserting_after(anchor):
q = create_node(
graph=graph,
op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
args=(), # We add the argument last
)
q.meta = anchor.meta

with graph.inserting_after(q):
dq = create_node(
graph=graph,
op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
args=(q,) + q_params,
)
dq.meta = q.meta

anchor.replace_all_uses_with(dq)
# We add this last so the replace all uses above does not replace the quantized
# node's first use
q.args = (anchor,) + q_params
return dq


def create_node(
graph: torch.fx.Graph,
op_target: OpOverload,
args: tuple = (),
kwargs: Optional[dict] = None,
):
return graph.create_node(
"call_function",
op_target,
args=args,
kwargs=kwargs or {},
)


class SizeAdjustConv2DPass(ExportPass):
"""
Adjust the convolution input size to match perfectly with the
weight size, padding, stride and dilation parameters.
This is done by inserting a slice op to remove the uneven end of the input.
Adjust the convolution input size to match the kernel size, padding, stride,
and dilation parameters. Pytorch allows the input and kernel shape to not
"match", in which case the remaining rows/columns are truncated. However,
matching the size is a requirement in the TOSA specification. In case the
input and kernel shape do not match, the following is done to meet the
specification:

1) The padding is truncated (done in the node visitor)
2) (if neccessary) The input is truncated (done in this pass)."

A simple example would be a 2x2 kernel (no padding, stride=2) and a 5x5
input:

┌───┬───┬───┬───┬───┐ ┌───┬───┬───┬───┬───┐ ┌───┬───┬───┬───┬───┐
│ X │ X │ │ │ │ │ │ │ X │ X │ │ │ │ │ │ │ - │
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
│ X │ X │ │ │ │ │ │ │ X │ X │ │ │ │ │ │ │ - │
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
│ │ │ │ │ │ -> │ │ │ │ │ │ -> │ X │ X │ │ │ │ ->
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
│ │ │ │ │ │ │ │ │ │ │ │ │ X │ X │ │ │ │
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
│ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │
└───┴───┴───┴───┴───┘ └───┴───┴───┴───┴───┘ └───┴───┴───┴───┴───┘
First pass second pass third pass

┌───┬───┬───┬───┬───┐ ┌───┬───┬───┬───┬───┐
│ │ │ │ │ │ │ │ │ │ │ - │
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
│ │ │ │ │ │ │ │ │ │ │ - │
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
│ │ │ X │ X │ │ -> │ │ │ │ │ - │
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
│ │ │ X │ X │ │ │ │ │ │ │ - │
├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤
│ │ │ │ │ │ │ - │ - │ - │ - │ - │
└───┴───┴───┴───┴───┘ └───┴───┴───┴───┴───┘
Fourth pass Unvisited cells

Cells that are never visited are marked with `-` and are never considered
when the kernel traverses over the input, hence they can be removed.

To match the shape of the kernel (and all parameters) with the input, a
slice op is inserted to remove the remaining edges (rows and columns) of the
input.
"""

conv2d_op = exir_ops.edge.aten.convolution.default
Expand Down Expand Up @@ -109,9 +110,7 @@ def call(self, graph_module: torch.fx.GraphModule):
with graph_module.graph.inserting_before(node):
last_node = cast(torch.fx.Node, input_node)
for args in slice_args:
slice_node = graph.create_node(
"call_function", self.slice_op, (last_node,) + args
)
slice_node = create_node(graph, self.slice_op, (last_node,) + args)
last_node = slice_node
conv_node.replace_input_with(cast(torch.fx.Node, input_node), last_node)
modified_graph = True
Expand Down
45 changes: 45 additions & 0 deletions backends/arm/test/ops/test_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,47 @@ def forward(self, x):
batches=1,
)

conv1d_7_1x3x16_st2_pd1_dl2 = Conv1d(
in_channels=3,
out_channels=3,
kernel_size=7,
stride=2,
padding=1,
dilation=2,
length=16,
batches=1,
)
conv1d_7_1x3x15_st1_pd0_dl1 = Conv1d(
in_channels=3,
out_channels=3,
kernel_size=7,
stride=1,
padding=0,
dilation=1,
length=15,
batches=1,
)
conv1d_5_1x3x14_st5_pd0_dl1 = Conv1d(
in_channels=3,
out_channels=3,
kernel_size=5,
stride=5,
padding=0,
dilation=1,
length=14,
batches=1,
)
conv1d_5_1x3x9_st5_pd0_dl1 = Conv1d(
in_channels=3,
out_channels=3,
kernel_size=5,
stride=5,
padding=0,
dilation=1,
length=9,
batches=1,
)

two_conv1d_nobias = Conv1d(
nbr_conv=2,
length=256,
Expand Down Expand Up @@ -214,6 +255,10 @@ def forward(self, x):
("2_1x2x14_st2", conv1d_2_1x2x14_st2),
("5_3x2x128_st1", conv1d_5_3x2x128_st1),
("3_1x3x224_st2_pd1", conv1d_3_1x3x224_st2_pd1),
("7_1x3x16_st2_pd1_dl2_needs_adjust_pass", conv1d_7_1x3x16_st2_pd1_dl2),
("7_1x3x15_st1_pd0_dl1_needs_adjust_pass", conv1d_7_1x3x15_st1_pd0_dl1),
("5_1x3x14_st5_pd0_dl1_needs_adjust_pass", conv1d_5_1x3x14_st5_pd0_dl1),
("5_1x3x9_st5_pd0_dl1_needs_adjust_pass", conv1d_5_1x3x9_st5_pd0_dl1),
("two_conv1d_nobias", two_conv1d_nobias),
("two_conv1d", two_conv1d),
]
Expand Down
105 changes: 104 additions & 1 deletion backends/arm/test/ops/test_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,101 @@ def forward(self, x):
batches=1,
)

conv2d_7x7_1x3x16x16_st2_pd1_dl2 = Conv2d(
in_channels=3,
out_channels=3,
kernel_size=(7, 7),
stride=2,
padding=1,
dilation=2,
width=16,
height=16,
batches=1,
)

conv2d_7x7_1x3x15x15_st1_pd0_dl1 = Conv2d(
in_channels=3,
out_channels=3,
kernel_size=(7, 7),
stride=1,
padding=0,
dilation=1,
width=15,
height=15,
batches=1,
)

conv2d_5x5_1x3x14x14_st5_pd0_dl1 = Conv2d(
in_channels=3,
out_channels=3,
kernel_size=(5, 5),
stride=5,
padding=0,
dilation=1,
width=14,
height=14,
batches=1,
)

conv2d_5x5_1x3x9x9_st5_pd0_dl1 = Conv2d(
in_channels=3,
out_channels=3,
kernel_size=(5, 5),
stride=5,
padding=0,
dilation=1,
width=9,
height=9,
batches=1,
)

conv2d_3x3_1x3x8x9_st3_pd0_dl1 = Conv2d(
in_channels=3,
out_channels=3,
kernel_size=(3, 3),
stride=3,
padding=0,
dilation=1,
width=8,
height=9,
batches=1,
)

conv2d_3x3_1x3x9x8_st3_pd0_dl1 = Conv2d(
in_channels=3,
out_channels=3,
kernel_size=(3, 3),
stride=3,
padding=0,
dilation=1,
width=8,
height=9,
batches=1,
)

conv2d_3x4_1x3x7x7_st3_pd0_dl1 = Conv2d(
in_channels=3,
out_channels=3,
kernel_size=(3, 4),
stride=3,
padding=0,
dilation=1,
width=7,
height=7,
batches=1,
)

conv2d_4x3_1x3x7x7_st3_pd0_dl1 = Conv2d(
in_channels=3,
out_channels=3,
kernel_size=(4, 3),
stride=3,
padding=0,
dilation=1,
width=7,
height=7,
batches=1,
)

two_conv2d_nobias = Conv2d(
nbr_conv=2,
Expand Down Expand Up @@ -236,7 +331,15 @@ def forward(self, x):
("3x3_1x3x12x12_st2_pd1", conv2d_3x3_1x3x12x12_st2_pd1),
("1x1_1x2x128x128_st1", conv2d_1x1_1x2x128x128_st1),
("2x2_1x1x14x13_st2_needs_adjust_pass", conv2d_2x2_1x1x14x13_st2),
("conv2d_5x5_1x3x14x15_st3_pd1_needs_adjust_pass", conv2d_5x5_1x3x14x15_st3_pd1),
("5x5_1x3x14x15_st3_pd1_needs_adjust_pass", conv2d_5x5_1x3x14x15_st3_pd1),
("7x7_1x3x16x16_st2_pd1_dl2_needs_adjust_pass", conv2d_7x7_1x3x16x16_st2_pd1_dl2),
("7x7_1x3x15x15_st1_pd0_dl1_needs_adjust_pass", conv2d_7x7_1x3x15x15_st1_pd0_dl1),
("5x5_1x3x14x14_st5_pd0_dl1_needs_adjust_pass", conv2d_5x5_1x3x14x14_st5_pd0_dl1),
("5x5_1x3x9x9_st5_pd0_dl1_needs_adjust_pass", conv2d_5x5_1x3x9x9_st5_pd0_dl1),
("3x3_1x3x9x8_st3_pd0_dl1_needs_adjust_pass", conv2d_3x3_1x3x9x8_st3_pd0_dl1),
("3x3_1x3x8x9_st3_pd0_dl1_needs_adjust_pass", conv2d_3x3_1x3x8x9_st3_pd0_dl1),
("3x4_1x3x7x7_st3_pd0_dl1_needs_adjust_pass", conv2d_3x4_1x3x7x7_st3_pd0_dl1),
("4x3_1x3x7x7_st3_pd0_dl1_needs_adjust_pass", conv2d_4x3_1x3x7x7_st3_pd0_dl1),
("5x5_3x2x128x128_st1", conv2d_5x5_3x2x128x128_st1),
("3x3_1x3x224x224_st2_pd1", conv2d_3x3_1x3x224x224_st2_pd1),
("two_conv2d_nobias", two_conv2d_nobias),
Expand Down