1212
1313import executorch .backends .cadence .aot .ops_registrations # noqa
1414import torch
15+ from executorch .backends .cadence .aot .memory_planning import (
16+ CadenceMemoryPlanning ,
17+ print_memory_planning_info ,
18+ )
1519from executorch .backends .cadence .aot .quantizer .fusion_pass import QuantFusion
1620from executorch .backends .cadence .aot .quantizer .quantizer import CadenceQuantizer
1721
1822from executorch .backends .cadence .aot .replace_ops import ReplaceSafeSoftmaxWithSoftmax
19- from executorch .backends .cadence .aot .utils import model_gm_has_SDPA , model_is_quantized
23+ from executorch .backends .cadence .aot .utils import (
24+ get_default_memory_config ,
25+ MemoryConfig ,
26+ model_gm_has_SDPA ,
27+ model_is_quantized ,
28+ )
2029from executorch .backends .transforms .decompose_sdpa import (
2130 DecomposeScaledDotProductAttention ,
2231)
2332from executorch .devtools import generate_etrecord
2433from executorch .exir import (
2534 EdgeCompileConfig ,
2635 EdgeProgramManager ,
36+ ExecutorchBackendConfig ,
2737 ExecutorchProgramManager ,
2838 to_edge ,
2939)
3040from executorch .exir .pass_base import PassResult
41+ from executorch .exir .passes import ToOutVarPass
42+ from executorch .exir .passes .sym_shape_eval_pass import HintBasedSymShapeEvalPass
3143from torch ._inductor .decomposition import remove_decompositions
3244from torch .ao .quantization .pt2e .export_utils import model_is_exported
3345from torch .ao .quantization .quantize_pt2e import convert_pt2e , prepare_pt2e
@@ -263,6 +275,10 @@ def export_to_executorch_gen_etrecord(
263275 inputs : tuple [object , ...],
264276 output_dir : Optional [str ] = None ,
265277 opt_level : int = 1 ,
278+ mem_algo : int = 0 ,
279+ alloc_graph_input : bool = True ,
280+ alloc_graph_output : bool = True ,
281+ memory_config : Optional [MemoryConfig ] = None ,
266282 dump_graphs : bool = False ,
267283) -> ExecutorchProgramManager :
268284 cadence_passes = get_cadence_passes (opt_level )
@@ -281,8 +297,36 @@ def export_to_executorch_gen_etrecord(
281297 cadence_prog_manager .exported_program ().graph_module ,
282298 )
283299
300+ if memory_config is None :
301+ memory_config = get_default_memory_config ()
302+
303+ memory_planning_pass = CadenceMemoryPlanning (
304+ memory_config ,
305+ opt_level = opt_level ,
306+ mem_algo = mem_algo ,
307+ alloc_graph_input = alloc_graph_input ,
308+ alloc_graph_output = alloc_graph_output ,
309+ )
310+
284311 # Get executorch program after Cadence specific passes
285- exec_prog : ExecutorchProgramManager = cadence_prog_manager .to_executorch ()
312+ exec_prog : ExecutorchProgramManager = cadence_prog_manager .to_executorch (
313+ ExecutorchBackendConfig (
314+ memory_planning_pass = memory_planning_pass ,
315+ emit_stacktrace = False ,
316+ to_out_var_pass = ToOutVarPass (),
317+ extract_delegate_segments = False ,
318+ sym_shape_eval_pass = HintBasedSymShapeEvalPass (),
319+ ),
320+ )
321+
322+ print_memory_planning_info (
323+ exec_prog ,
324+ memory_config ,
325+ opt_level ,
326+ alloc_graph_input ,
327+ alloc_graph_output ,
328+ )
329+
286330 if output_dir :
287331 _gen_etrecord (edge_prog_manager , exec_prog , Path (output_dir ))
288332 else :
0 commit comments