diff --git a/backends/apple/mps/mps_preprocess.py b/backends/apple/mps/mps_preprocess.py index 8362774fa97..749f32a04e5 100644 --- a/backends/apple/mps/mps_preprocess.py +++ b/backends/apple/mps/mps_preprocess.py @@ -32,6 +32,9 @@ CompileSpec, PreprocessResult, ) + +from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass +from executorch.exir.program._program import _transform from torch.export.exported_program import ExportedProgram FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" @@ -83,6 +86,9 @@ def preprocess( # FlatBuffer graph, process the `output` nodes and add their id to # the `output_ids` array in the schema. + # TODO: Remove this once we have a better support for the dim-order ops. + edge_program = _transform(edge_program, DimOrderOpsRevertPass()) + mps_graph = MPSGraph( version="0", mps_nodes=[], diff --git a/backends/apple/mps/operators/constant_ops.py b/backends/apple/mps/operators/constant_ops.py index dacb09215cb..f8dcfc66a0c 100644 --- a/backends/apple/mps/operators/constant_ops.py +++ b/backends/apple/mps/operators/constant_ops.py @@ -79,6 +79,25 @@ def define_node( ) +@register_node_visitor +class ToDimOrderEmptyVisitor(NodeVisitor): + target = ["dim_order_ops._empty_dim_order.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + # We should never get here, because DimOrderOpsRevertPass replaces this with an aten.empty.memory_format op + # But if we do, we can't handle it ATM, so raise an exception + raise NotImplementedError( + "dim_order_ops._empty_dim_order.default is not supported yet" + ) + + @register_node_visitor class FullLikeVisitor(NodeVisitor): target = "aten.full_like.default" diff --git a/backends/apple/mps/operators/op_clone.py b/backends/apple/mps/operators/op_clone.py index 2310ae02da7..0d0b7f53633 100644 --- a/backends/apple/mps/operators/op_clone.py +++ b/backends/apple/mps/operators/op_clone.py @@ -33,3 +33,22 @@ def define_node( ) input_id = self.define_tensor(get_input_node(node, 0), mps_graph) self.tensor_to_id[node] = input_id + + +@register_node_visitor +class ToDimOrderCopyVisitor(NodeVisitor): + target = ["dim_order_ops._to_dim_order_copy.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + mps_graph: MPSGraph, + ) -> None: + # We should never get here, because DimOrderOpsRevertPass replaces this with an aten._to_copy op + # But if we do, we can't handle it ATM, so raise an exception + raise NotImplementedError( + "dim_order_ops._to_dim_order_copy.default is not supported yet" + ) diff --git a/backends/apple/mps/test/test_mps.py b/backends/apple/mps/test/test_mps.py index fe64a30f3ce..a981d7ab8eb 100644 --- a/backends/apple/mps/test/test_mps.py +++ b/backends/apple/mps/test/test_mps.py @@ -1829,6 +1829,21 @@ def forward(self, x): Clone(), model_inputs, func_name=inspect.stack()[0].function[5:] ) + def test_mps_backend_to_copy(self): + class Copy(torch.nn.Module): + def forward(self, x): + return ( + torch.ops.aten._to_copy.default( + x + 2, memory_format=torch.contiguous_format + ) + + x + ) + + model_inputs = (torch.randn(1, 3, 3),) + self.lower_and_test_with_partitioner( + Copy(), model_inputs, func_name=inspect.stack()[0].function[5:] + ) + def test_mps_backend_floor(self): class Floor(torch.nn.Module): def forward(self, x): diff --git a/backends/apple/mps/test/test_mps_utils.py b/backends/apple/mps/test/test_mps_utils.py index 39ce5df5115..36f9b229e5c 100644 --- a/backends/apple/mps/test/test_mps_utils.py +++ b/backends/apple/mps/test/test_mps_utils.py @@ -26,10 +26,7 @@ # Config for Capturing the weights, will be moved in the future -# TODO(T182928844): Delegate dim order op to backend. -_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig( - _check_ir_validity=False, _skip_dim_order=True -) +_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(_check_ir_validity=False) class ansi_colors: @@ -219,7 +216,6 @@ def lower_module_and_test_output( dynamic_shapes=dynamic_shapes, edge_compile_config=EdgeCompileConfig( _check_ir_validity=False, - _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. ), ) @@ -253,7 +249,6 @@ def lower_module_and_test_output( ), compile_config=exir.EdgeCompileConfig( _check_ir_validity=False, - _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. ), ).to_executorch( config=ExecutorchBackendConfig(extract_delegate_segments=False)