66
77# pyre-strict
88
9- from typing import Any , cast , Dict , Sequence , Tuple
9+ from typing import Any , cast , Dict , List , Optional , Sequence , Tuple , Type
1010
1111import torch
12+ import torch .fx
13+ import torch .utils ._pytree as pytree
14+ from executorch .backends .cadence .aot .pass_utils import (
15+ CadencePassAttribute ,
16+ create_cadence_pass_filter ,
17+ register_cadence_pass ,
18+ )
1219from executorch .backends .cadence .aot .utils import get_edge_overload_packet
20+ from executorch .backends .transforms .remove_clone_ops import RemoveCloneOpsTransform
1321from executorch .exir .dialects ._ops import ops as exir_ops
1422from executorch .exir .pass_base import ExportPass , NodeMetadata , PassResult , ProxyValue
23+ from executorch .exir .pass_manager import PassManager , PassType
1524from executorch .exir .passes import dead_code_elimination_pass
25+ from executorch .exir .passes .scalar_to_tensor_pass import ScalarToTensorPass
1626from executorch .exir .passes .spec_prop_pass import SpecPropPass
1727from torch ._subclasses import FakeTensor
1828from torch .utils ._pytree import tree_map_only
1929
30+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class RemoveCloneOpsTransformImported(ExportPass):
    """Wrap the backend-shared RemoveCloneOpsTransform as a Cadence pass."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Delegate clone-op removal to the generic transform shared across backends.
        clone_removal_passes: List[PassType] = [RemoveCloneOpsTransform()]
        pm_result = PassManager(passes=clone_removal_passes)(graph_module)
        # Removing clones can strand nodes; sweep them away before returning.
        dead_code_elimination_pass(pm_result.graph_module)
        return pm_result
40+
41+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class InitializePipeline(ExportPass):
    """
    Initialize the Jarvis pipeline. This should invariably be the first pass to
    run.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Start from a minimal graph: drop unused nodes first.
        dead_code_elimination_pass(graph_module)
        # Then propagate tensor specs across the cleaned graph.
        spec_result = SpecPropPass()(graph_module)
        assert spec_result is not None
        return spec_result
54+
55+
@register_cadence_pass(CadencePassAttribute(opt_level=0))
class FinalizePipeline(ExportPass):
    """
    The final cleanup pass after running the Jarvis pipeline.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Normalize leftover scalars to tensors, then re-propagate specs.
        cleanup_passes: List[PassType] = [
            ScalarToTensorPass(),
            SpecPropPass(),
        ]
        outcome = PassManager(passes=cleanup_passes)(graph_module)
        # Final sweep: remove any nodes the cleanup passes orphaned.
        dead_code_elimination_pass(outcome.graph_module)
        return outcome
70+
71+
# Similar to what's done in executorch/exir/pass_base.py:
# `Argument` aliases Any because op arguments may be of any runtime type.
Argument = Any  # pyre-ignore
2375
76+ @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
2477class ReplacePT2QuantWithCadenceQuantPass (ExportPass ):
2578 """
2679 Replace the pt2 quantization ops with custom cadence quantization ops.
@@ -44,6 +97,7 @@ def call_operator(
4497 )
4598
4699
100+ @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
47101class ReplacePT2DequantWithCadenceDequantPass (ExportPass ):
48102 """
49103 Replace the pt2 dequantization ops with custom cadence dequantization ops.
@@ -67,6 +121,7 @@ def call_operator(
67121 )
68122
69123
124+ @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
70125class ReplaceScalarTensorWithFullPass (ExportPass ):
71126 """
72127 aten.scalar_tensor can be replaced by aten.full with a shape of [1].
@@ -96,6 +151,7 @@ def call_operator(
96151 )
97152
98153
154+ @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
99155class ReplaceSqueezeAndUnsqueezeWithViewPass (ExportPass ):
100156 """
101157 When the shape is static, replace squeeze_copy and unsqueeze_copy ops with
@@ -131,7 +187,8 @@ def call_operator(
131187 )
132188
133189
134- class RemoveZeroSizedCatArgsPass (ExportPass ):
190+ @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
191+ class RemoveZeroSizedCatArgsPass (ExportPass ): # is this the latest?
135192 def call_operator (
136193 self ,
137194 op , # pyre-ignore
@@ -176,6 +233,7 @@ def call_operator(
176233 return super ().call_operator (op , args , kwargs , meta )
177234
178235
236+ @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
179237class RemoveNopExpandOpPass (ExportPass ):
180238 """
181239 For an expand op, if the operator shape matches the expand shape, then the
@@ -205,6 +263,7 @@ def call_operator(
205263 return super ().call_operator (op , args , kwargs , meta )
206264
207265
266+ @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
208267class ReplaceLogicalNotBooleanWhereWithWherePass (ExportPass ):
209268 """
210269 A where op with a logical_not and a boolean tensor can be replaced
@@ -255,20 +314,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
255314 return result
256315
257316
258- class InitializePipeline (ExportPass ):
259- """
260- Initialize the Jarvis pipeline. This should invariably be the first pass to
261- run.
262- """
263-
264- def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
265- dead_code_elimination_pass (graph_module )
266- result = SpecPropPass ()(graph_module )
267- assert result is not None
268- return result
269-
270-
271- class ReplaceSafeSoftmaxWithSoftmax (ExportPass ):
317+ @register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
318+ class ReplaceSafeSoftmaxWithSoftmax (ExportPass ): # keep
272319 """
273320 Replace _safe_softmax with _softmax
274321 """
@@ -292,3 +339,33 @@ def call_operator(
292339 kwargs ,
293340 meta ,
294341 )
342+
343+
def get_passes_in_default_order() -> List[Type[PassType]]:
    """Return the Cadence pass classes, flattened, in their canonical run order."""
    ordered_passes = [
        InitializePipeline,
        RemoveZeroSizedCatArgsPass,
        ReplaceLogicalNotBooleanWhereWithWherePass,
        ReplaceScalarTensorWithFullPass,
        RemoveCloneOpsTransformImported,
        RemoveNopExpandOpPass,
        ReplaceSqueezeAndUnsqueezeWithViewPass,
        ReplacePT2QuantWithCadenceQuantPass,
        ReplacePT2DequantWithCadenceDequantPass,
        # TODO: add the rest of the passes here.
    ]
    # tree_flatten lets nested pass groups be added to the list later on.
    flat_passes, _spec = pytree.tree_flatten(ordered_passes)
    return flat_passes
358+
359+
def get_cadence_passes(
    opt_level: int,
) -> List[Optional[PassResult]]:
    """Instantiate every registered pass enabled at `opt_level`, preserving order."""
    candidate_passes = get_passes_in_default_order()
    keep = create_cadence_pass_filter(opt_level)
    selected = []
    # pyre-fixme[6]: `filter` receives pass classes, not callables on GraphModule.
    for pass_cls in filter(keep, candidate_passes):
        # pyre-fixme[20]: `PassBase.__call__` expects argument `graph_module`.
        selected.append(pass_cls())
    return selected
0 commit comments