66
77# pyre-strict
88
9- from typing import List
10-
11- import executorch .backends .vulkan ._passes .custom_ops_defs # noqa
9+ import executorch .backends .vulkan .custom_ops_lib # noqa
1210
1311import torch
1412
13+ from executorch .backends .vulkan .op_registry import handles_own_prepacking
14+
1515from executorch .exir .dialects ._ops import ops as exir_ops
1616
1717from torch ._export .utils import is_buffer , is_param
1818from torch .export import ExportedProgram
1919
20- USES_WEIGHTS : List [torch ._ops .OpOverload ] = [
21- exir_ops .edge .aten .embedding .default ,
22- exir_ops .edge .aten .convolution .default ,
23- exir_ops .edge .et_vk .conv_with_clamp .default ,
24- exir_ops .edge .aten .linear .default ,
25- exir_ops .edge .aten ._weight_int8pack_mm .default ,
26- exir_ops .edge .et_vk .linear_weight_int4 .default ,
27- exir_ops .edge .aten ._native_batch_norm_legit_no_training .default ,
28- exir_ops .edge .aten .native_layer_norm .default ,
29- "llama::sdpa_with_kv_cache" ,
30- ]
31-
3220
3321def insert_prepack_nodes (program : ExportedProgram ) -> ExportedProgram :
3422 """
3523 Insert `et_vk.prepack` nodes for constant tensors in the graph. The prepack operator
3624 is responsible for transferring the tensor data, which is serialized with the model,
3725 to a GPU tensor object during the prepacking stage of model execution.
3826
39- Some operators, listed in `USES_WEIGHTS` above, are performance sensitive and will
40- prefer to handle prepacking within the operator. For these ops, the constant tensor
41- data will be passed directly as an argument into the operator implementation.
27+ Some operators are performance sensitive and will prefer to handle prepacking within
28+ the operator. For these ops, the constant tensor data will be passed directly as an
29+ argument into the operator implementation.
4230 """
4331
4432 def is_get_attr_node (node : torch .fx .Node ) -> bool :
@@ -58,22 +46,21 @@ def is_param_node(node: torch.fx.Node) -> bool:
5846 or is_constant (node )
5947 )
6048
61- def is_non_weight_param_tensor (node : torch .fx .Node ) -> bool :
49+ def prepack_not_required (node : torch .fx .Node ) -> bool :
6250 if not is_param_node (node ):
63- return False
51+ return True
6452
6553 for user in node .users :
66- if user .op == "call_function" and (
67- # pyre-ignore [16]
68- user .target in USES_WEIGHTS
69- or user .target .name () in USES_WEIGHTS
54+ if user .op == "call_function" and handles_own_prepacking (
55+ # pyre-ignore
56+ user .target
7057 ):
71- return False
58+ return True
7259
73- return True
60+ return False
7461
7562 for node in program .graph_module .graph .nodes :
76- if not is_non_weight_param_tensor (node ):
63+ if prepack_not_required (node ):
7764 continue
7865
7966 with program .graph_module .graph .inserting_after (node ):
0 commit comments