15 changes: 9 additions & 6 deletions backends/xnnpack/partition/graphs/bilinear_2d.py
@@ -37,12 +37,15 @@ def forward(self, x):
]
for align_corners in [True, False]:
for config in capture_configs:
edge = exir.capture(
bilinear2d(align_corners), sample_inputs, config
).to_edge(
config=get_xnnpack_edge_compile_config(),
)
_bilinear2d_graphs[edge.exported_program.graph_module] = align_corners
for skip_dim_order_flag in [True, False]:
edge = exir.capture(
bilinear2d(align_corners), sample_inputs, config
).to_edge(
config=get_xnnpack_edge_compile_config(
skip_dim_order=skip_dim_order_flag
)
)
_bilinear2d_graphs[edge.exported_program.graph_module] = align_corners
return _bilinear2d_graphs


3 changes: 3 additions & 0 deletions backends/xnnpack/passes/__init__.py
@@ -27,6 +27,7 @@
from executorch.exir.pass_base import ExportPass

from executorch.exir.passes.const_prop_pass import ConstPropPass
from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass

from executorch.exir.program._program import _transform
from torch._export.pass_base import PassType
@@ -50,6 +51,8 @@ def __init__(
if not passes:
# All the XNNPACK passes
self.passes = [
# TODO - remove this pass once we have better support for dim_order ops lowering
DimOrderOpsRevertPass,
ConvertToUpsampleBilinear2d,
ConvertToLinearPass,
ConvertToSDPAPass,
7 changes: 4 additions & 3 deletions backends/xnnpack/test/ops/bilinear2d.py
@@ -65,12 +65,15 @@ def forward(self, x):
)
return a

# Since dim order may or may not be enabled, use these ops only for
# check_not: the list contains both `to_copy` and `to_dim_order_copy`.
ops = {
"executorch_exir_dialects_edge__ops_aten_sub_Tensor",
"executorch_exir_dialects_edge__ops_aten_mul_Tensor",
"executorch_exir_dialects_edge__ops_aten_index_Tensor",
"executorch_exir_dialects_edge__ops_aten_arange_start_step",
"executorch_exir_dialects_edge__ops_aten__to_copy_default",
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default",
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
"executorch_exir_dialects_edge__ops_aten_clamp_default",
}
@@ -81,7 +84,6 @@ def test_fp32_static_resize_bilinear2d(self):
Tester(self.StaticResizeBilinear2dModule(), example_inputs)
.export()
.to_edge()
.check(self.ops)
.partition()
.check_not(self.ops)
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
@@ -90,13 +92,12 @@ def test_fp32_static_resize_bilinear2d(self):
.run_method_and_compare_outputs()
)

def test_fp32_static_resize_bilinear2d_with_align_cornesr(self):
def test_fp32_static_resize_bilinear2d_with_align_corners(self):
example_inputs = (torch.randn(2, 3, 4, 5),)
(
Tester(self.StaticResizeBilinear2dModuleWithAlignCorners(), example_inputs)
.export()
.to_edge()
.check(self.ops)
.partition()
.check_not(self.ops)
.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
8 changes: 6 additions & 2 deletions backends/xnnpack/utils/configs.py
@@ -12,8 +12,12 @@


### XNNPACK Configs ###
def get_xnnpack_edge_compile_config() -> exir.EdgeCompileConfig:
return exir.EdgeCompileConfig(_check_ir_validity=False, _skip_dim_order=True)
def get_xnnpack_edge_compile_config(
skip_dim_order: bool = True,
) -> exir.EdgeCompileConfig:
return exir.EdgeCompileConfig(
_check_ir_validity=False, _skip_dim_order=skip_dim_order
)


def get_transform_passes(additional_passes=None) -> List[PassType]:
12 changes: 12 additions & 0 deletions exir/passes/dim_order_ops_registry.py
@@ -45,3 +45,15 @@ def _to_dim_order_copy_out_impl(*args, **kwargs):
DimOrderOpsMap = {
"aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
}

"""
Defines a map of dim_order ops to the corresponding aten memory format ops for quick lookup
"""
MemoryFormatOpsMap = {
"dim_order_ops._to_dim_order_copy.default": exir_ops.edge.aten._to_copy.default,
}

# If we are replacing an aten op with a dim_order op, we must have a 1:1 mapping through these dicts.
assert len(DimOrderOpsMap) == len(MemoryFormatOpsMap)

# TODO stricter check for 1:1 mapping
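
A stricter check could verify that the two maps are mutual inverses rather than only comparing lengths; a minimal sketch, assuming (as the passes below do) that each edge op's `__name__` matches the string keys used in these dicts:

# Hypothetical sketch of a stricter 1:1 check; not part of this diff.
for aten_name, dim_order_op in DimOrderOpsMap.items():
    reverted = MemoryFormatOpsMap.get(dim_order_op.__name__)
    assert (
        reverted is not None and reverted.__name__ == aten_name
    ), f"{aten_name} and {dim_order_op.__name__} do not form a 1:1 pair"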
60 changes: 57 additions & 3 deletions exir/passes/memory_format_ops_pass.py
@@ -9,13 +9,19 @@

import torch
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.dim_order_utils import get_dim_order
from executorch.exir.dim_order_utils import get_dim_order, get_memory_format
from executorch.exir.pass_base import ExportPass, ProxyValue
from executorch.exir.passes.dim_order_ops_registry import DimOrderOpsMap
from executorch.exir.passes.dim_order_ops_registry import (
DimOrderOpsMap,
MemoryFormatOpsMap,
)

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)

# TODO - these passes are too specialized on a single to_copy op.
# We should be able to replace (or revert) any of the dim_order ops in the future.


class MemoryFormatOpsPass(ExportPass):
"""
@@ -53,7 +59,55 @@ def call_operator(self, op, args, kwargs, meta):
f" _to_dim_order_copy = dim_order: {nkwargs['dim_order']}"
)

t = DimOrderOpsMap[op.__name__]
t = DimOrderOpsMap.get(op.__name__, None)
assert t is not None, f"{op.__name__} not found in DimOrderOpsMap"

return super().call_operator(
t,
args,
nkwargs,
meta,
)


class DimOrderOpsRevertPass(ExportPass):
"""
This pass reverts dim_order ops back to the corresponding memory format ops.
"""

def call_operator(self, op, args, kwargs, meta):
if not (isinstance(op, EdgeOpOverload) and op.__name__ in MemoryFormatOpsMap):
return super().call_operator(
op,
args,
kwargs,
meta,
)

# new kwargs with memory_format, and no dim_order for the new op
nkwargs = dict(copy.deepcopy(kwargs)) # orig kwargs are immutable

# we can always get the rank, assuming it is specialized
if isinstance(args[0], ProxyValue) and args[0].is_tensor():
ndim = args[0].to_tensor().dim()
elif isinstance(args[0], torch.Tensor):
ndim = args[0].dim()
else:
assert 0, f"Expecting a Tensor or a ProxyValue buy got {type(args[0])}"

# get the "to" memory format for the EdgeOp
default_dim_order = list(range(ndim))
dim_order = nkwargs.pop("dim_order", default_dim_order)

nkwargs["memory_format"] = get_memory_format(dim_order)

logger.debug(
f" _to_dim_order_copy = dim_order: {dim_order}."
f"_to_copy = rank: {ndim}, memory_format: {nkwargs['memory_format']}."
)

t = MemoryFormatOpsMap.get(op.__name__, None)
assert t is not None, f"{op.__name__} not found in MemoryFormatOpsMap"

return super().call_operator(
t,
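
For reference, a minimal usage sketch of the new revert pass, mirroring the new test below; `edge_program` is an assumed edge-dialect program, not a name from this diff:

# Hypothetical usage sketch; the call pattern follows the new test in exir/tests/test_passes.py.
from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass

result = DimOrderOpsRevertPass()(edge_program.graph_module)
assert result is not None
reverted_gm = result.graph_module  # _to_dim_order_copy nodes rewritten back to aten._to_copy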
100 changes: 99 additions & 1 deletion exir/tests/test_passes.py
@@ -37,10 +37,11 @@
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
insert_write_back_for_buffers_pass,
)

from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
from executorch.exir.passes.normalize_view_copy_base_pass import (
NormalizeViewCopyBasePass,
)

from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass
@@ -1676,3 +1677,100 @@ def forward(self, text_tokens):
)
new_ep = constant_prop_pass(edge_manager._edge_programs["forward"])
_ = copy.deepcopy(new_ep.module_call_graph)

def test_dim_order_revert_pass(self) -> None:
aten_op_str = "torch.ops.aten._to_copy.default"
edge_aten_op_str = "executorch_exir_dialects_edge__ops_aten__to_copy_default"
edge_dim_order_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"

class Module(torch.nn.Module):
"""
A simple module whose `to` ops convert to channels last and then back to contiguous.
Assumes a contiguous input.
"""

def __init__(self):
super().__init__()

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(memory_format=torch.channels_last).to(
memory_format=torch.contiguous_format
) + x.to(memory_format=torch.channels_last).to(
memory_format=torch.contiguous_format
)

@staticmethod
def to_copy_count():
return 4

def _do_checks(
test_str: str, allowed: str, allowed_count: int, not_allowed_list: List[str]
) -> None:
for not_allowed in not_allowed_list:
FileCheck().check_count(allowed, allowed_count, exactly=True).check_not(
not_allowed
).run(test_str)

m = Module()
n = m.to_copy_count()
input = torch.randn([2, 3, 4, 5]).to(memory_format=torch.contiguous_format)

# 1. vanilla export, no edge ops
ep = export(
m,
(input,),
)
_do_checks(
ep.graph_module.code,
aten_op_str,
n,
[edge_aten_op_str, edge_dim_order_op_str],
)

# 2a. to edge without dim orders, we should see edge aten ops but not dim order ops
edge_prog = to_edge(
ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=True)
)._edge_programs["forward"]
_do_checks(
edge_prog.graph_module.code,
edge_aten_op_str,
n,
[aten_op_str, edge_dim_order_op_str],
)

# 3a. expect no change after the pass, we should see edge aten ops but not dim order ops
new_res = DimOrderOpsRevertPass()(edge_prog.graph_module)
self.assertIsNotNone(new_res)
_do_checks(
new_res.graph_module.code,
edge_aten_op_str,
n,
[aten_op_str, edge_dim_order_op_str],
)

# 2b. let's try with dim order enabled, we should see edge dim order ops but not edge aten ops
edge_prog_dim_order = to_edge(
ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=False)
)._edge_programs["forward"]
_do_checks(
edge_prog_dim_order.graph_module.code,
edge_dim_order_op_str,
n,
[aten_op_str, edge_aten_op_str],
)

# 3b. expect edge aten ops after the pass; we should not see the edge dim order ops
new_res_dim_order = DimOrderOpsRevertPass()(edge_prog_dim_order.graph_module)
self.assertIsNotNone(new_res_dim_order)
_do_checks(
new_res_dim_order.graph_module.code,
edge_aten_op_str,
n,
[aten_op_str, edge_dim_order_op_str],
)

output_no_dim_order = new_res.graph_module(input)
output_no_dim_order_revert = new_res_dim_order.graph_module(input)
self.assertTrue(
torch.allclose(output_no_dim_order[0], output_no_dim_order_revert[0])
)