Skip to content

Commit 9890c24

Browse files
pytorchbotSS-JIA
andauthored
[ET-VK] Introduce AOT operator registry (#6511)
Pull Request resolved: #6488 ## Changes Move the following files to the root directory of Vulkan backend: * `backends/vulkan/partitioner/supported_ops.py` -> `backends/vulkan/op_registry.py` * `backends/vulkan/_passes/custom_ops_defs.py` -> `backends/vulkan/custom_ops_lib.py` In the new `op_registry.py` file, the way operator features are specified is reworked to provide much more detail about the features of the operator implementation in Vulkan. See the new `OpFeatures` class for more details. An example of registering a new operator to the export flow is ``` @update_features( [ exir_ops.edge.aten._log_softmax.default, exir_ops.edge.aten._softmax.default, exir_ops.edge.aten.mean.dim, exir_ops.edge.aten.sum.dim_IntList, exir_ops.edge.aten.amax.default, exir_ops.edge.aten.amin.default, ] ) def register_reduce_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( uses_packed_dim=True, ) features.resize_fn = True def check_reduce_node(node: torch.fx.Node) -> bool: dim_list = node.args[1] assert isinstance(dim_list, list) if len(dim_list) != 1: return False keepdim = node.args[2] assert isinstance(keepdim, bool) if not keepdim: return False return True features.check_node_fn = check_reduce_node return features ``` ## Rationale The purpose of these changes is to centralize operator definitions so that there is a common source of truth about the capabilities of operator implementation in Vulkan. This way, the partitioner does not have to implement ad-hoc functions for specific operators (i.e. `is_valid_to_copy`) and graph transforms do not have to maintain their own operator metadata (`USES_WEIGHTS` in `insert_prepack_nodes`). ghstack-source-id: 250279709 @exported-using-ghexport Differential Revision: [D64915640](https://our.internmc.facebook.com/intern/diff/D64915640/) Co-authored-by: Stephen Jia <[email protected]>
1 parent a3579e9 commit 9890c24

File tree

19 files changed

+599
-457
lines changed

19 files changed

+599
-457
lines changed

backends/transforms/fuse_conv_with_clamp.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66

77
import sys
88

9+
import executorch.backends.vulkan.custom_ops_lib # noqa
10+
911
import torch
10-
from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa
11-
conv_with_clamp_op,
12-
)
1312

1413
from executorch.exir.dialects._ops import ops as exir_ops
1514
from executorch.exir.pass_base import ExportPass, PassResult

backends/transforms/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def define_common_targets():
7070
deps = [
7171
":utils",
7272
"//caffe2:torch",
73-
"//executorch/backends/vulkan/_passes:custom_ops_defs",
73+
"//executorch/backends/vulkan:custom_ops_lib",
7474
"//executorch/exir:pass_base",
7575
"//executorch/exir:sym_util",
7676
"//executorch/exir/dialects:lib",

backends/vulkan/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,14 @@ set(vulkan_standard_shaders_cpp ${generated_spv_cpp})
8383
set(SCHEMA_INCLUDE_DIR ${CMAKE_BINARY_DIR}/schema/include)
8484

8585
set(GENERATED_HEADER
86-
${SCHEMA_INCLUDE_DIR}/executorch/backends/vulkan/schema_generated.h
86+
${SCHEMA_INCLUDE_DIR}/executorch/backends/vulkan/serialization/schema_generated.h
8787
)
8888

8989
add_custom_command(
9090
OUTPUT ${GENERATED_HEADER}
9191
COMMAND
9292
${FLATC_EXECUTABLE} --cpp --cpp-std c++11 --scoped-enums -o
93-
"${SCHEMA_INCLUDE_DIR}/executorch/backends/vulkan/" ${_vulkan_schema__srcs}
93+
"${SCHEMA_INCLUDE_DIR}/executorch/backends/vulkan/serialization/" ${_vulkan_schema__srcs}
9494
WORKING_DIRECTORY ${EXECUTORCH_ROOT}
9595
COMMENT "Generating vulkan_schema headers"
9696
VERBATIM

backends/vulkan/TARGETS

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,4 @@
1-
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
21
load(":targets.bzl", "define_common_targets")
3-
42
oncall("executorch")
53

64
define_common_targets(is_fbcode = True)
7-
8-
runtime.python_library(
9-
name = "vulkan_preprocess",
10-
srcs = [
11-
"serialization/vulkan_graph_builder.py",
12-
"serialization/vulkan_graph_schema.py",
13-
"serialization/vulkan_graph_serialize.py",
14-
"vulkan_preprocess.py",
15-
],
16-
resources = [
17-
"serialization/schema.fbs",
18-
],
19-
visibility = [
20-
"//executorch/...",
21-
"//executorch/vulkan/...",
22-
"@EXECUTORCH_CLIENTS",
23-
],
24-
deps = [
25-
"//executorch/backends/transforms:addmm_mm_to_linear",
26-
"//executorch/backends/transforms:fuse_batch_norm_with_conv",
27-
"//executorch/backends/transforms:fuse_conv_with_clamp",
28-
"//executorch/backends/transforms:fuse_dequant_linear",
29-
"//executorch/backends/transforms:fuse_view_copy",
30-
"//executorch/backends/transforms:remove_clone_ops",
31-
"//executorch/backends/vulkan/_passes:vulkan_passes",
32-
"//executorch/exir:graph_module",
33-
"//executorch/exir/_serialize:_bindings",
34-
"//executorch/exir/_serialize:lib",
35-
"//executorch/exir/backend:backend_details",
36-
],
37-
)

backends/vulkan/_passes/TARGETS

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,6 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
33

44
oncall("executorch")
55

6-
runtime.python_library(
7-
name = "custom_ops_defs",
8-
srcs = [
9-
"custom_ops_defs.py",
10-
],
11-
visibility = [
12-
"//executorch/...",
13-
"@EXECUTORCH_CLIENTS",
14-
],
15-
deps = [
16-
"//caffe2:torch",
17-
],
18-
)
19-
20-
python_unittest(
21-
name = "test_custom_ops",
22-
srcs = [
23-
"test_custom_ops.py",
24-
],
25-
deps = [
26-
":custom_ops_defs",
27-
"//caffe2:torch",
28-
],
29-
)
30-
316
runtime.python_library(
327
name = "insert_prepack_nodes",
338
srcs = ["insert_prepack_nodes.py"],
@@ -62,7 +37,7 @@ runtime.python_library(
6237
"//executorch/backends/...",
6338
],
6439
deps = [
65-
":custom_ops_defs",
40+
"//executorch/backends/vulkan:custom_ops_lib",
6641
"//pytorch/ao:torchao",
6742
]
6843
)

backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,39 +6,27 @@
66

77
# pyre-strict
88

9-
from typing import List
10-
11-
import executorch.backends.vulkan._passes.custom_ops_defs # noqa
9+
import executorch.backends.vulkan.custom_ops_lib # noqa
1210

1311
import torch
1412

13+
from executorch.backends.vulkan.op_registry import handles_own_prepacking
14+
1515
from executorch.exir.dialects._ops import ops as exir_ops
1616

1717
from torch._export.utils import is_buffer, is_param
1818
from torch.export import ExportedProgram
1919

20-
USES_WEIGHTS: List[torch._ops.OpOverload] = [
21-
exir_ops.edge.aten.embedding.default,
22-
exir_ops.edge.aten.convolution.default,
23-
exir_ops.edge.et_vk.conv_with_clamp.default,
24-
exir_ops.edge.aten.linear.default,
25-
exir_ops.edge.aten._weight_int8pack_mm.default,
26-
exir_ops.edge.et_vk.linear_weight_int4.default,
27-
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
28-
exir_ops.edge.aten.native_layer_norm.default,
29-
"llama::sdpa_with_kv_cache",
30-
]
31-
3220

3321
def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
3422
"""
3523
Insert `et_vk.prepack` nodes for constant tensors in the graph. The prepack operator
3624
is responsible for transferring the tensor data, which is serialized with the model,
3725
to a GPU tensor object during the prepacking stage of model execution.
3826
39-
Some operators, listed in `USES_WEIGHTS` above, are performance sensitive and will
40-
prefer to handle prepacking within the operator. For these ops, the constant tensor
41-
data will be passed directly as an argument into the operator implementation.
27+
Some operators are performance sensitive and will prefer to handle prepacking within
28+
the operator. For these ops, the constant tensor data will be passed directly as an
29+
argument into the operator implementation.
4230
"""
4331

4432
def is_get_attr_node(node: torch.fx.Node) -> bool:
@@ -58,22 +46,21 @@ def is_param_node(node: torch.fx.Node) -> bool:
5846
or is_constant(node)
5947
)
6048

61-
def is_non_weight_param_tensor(node: torch.fx.Node) -> bool:
49+
def prepack_not_required(node: torch.fx.Node) -> bool:
6250
if not is_param_node(node):
63-
return False
51+
return True
6452

6553
for user in node.users:
66-
if user.op == "call_function" and (
67-
# pyre-ignore [16]
68-
user.target in USES_WEIGHTS
69-
or user.target.name() in USES_WEIGHTS
54+
if user.op == "call_function" and handles_own_prepacking(
55+
# pyre-ignore
56+
user.target
7057
):
71-
return False
58+
return True
7259

73-
return True
60+
return False
7461

7562
for node in program.graph_module.graph.nodes:
76-
if not is_non_weight_param_tensor(node):
63+
if prepack_not_required(node):
7764
continue
7865

7966
with program.graph_module.graph.inserting_after(node):

backends/vulkan/_passes/int4_weight_only_quantizer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import logging
22
from typing import Any, Callable, Dict, Optional, Type
33

4+
import executorch.backends.vulkan.custom_ops_lib # noqa
5+
46
import torch
57
import torch.nn.functional as F
68

7-
from executorch.backends.vulkan._passes.custom_ops_defs import ( # noqa
8-
linear_weight_int4_op,
9-
)
10-
119
from torchao.quantization.GPTQ import _check_linear_int4_k
1210
from torchao.quantization.unified import Quantizer
1311
from torchao.quantization.utils import groupwise_affine_quantize_tensor

backends/vulkan/_passes/test_custom_ops.py

Lines changed: 0 additions & 124 deletions
This file was deleted.
File renamed without changes.

0 commit comments

Comments
 (0)