
Commit acd154f

Author: ssjia
[ET-VK] Miscellaneous fixes
Pull Request resolved: #14732

Collecting fixes for various models/ops in this diff/PR. They have all been squashed into this single change to make it easier to cherry-pick.

# Fixes

## Wav2Letter

Type: Output correctness failure

This is caused by a bug in SwiftShader and is not reproducible on any other platform. Specifically, the issue is in the softmax shader; the exact cause is unknown, but it is related to using shared memory within shaders. The workaround is to use separate shared memory arrays for the shared max and the shared sum.

## ConvNeXT

Type: Exception during runtime

This is caused by an incompatible memory layout being used for mean2d. More technically, the packed dimension of the tensor cannot be one of the dims being reduced. The current operator registry system had no way to select valid tensor representations based on the actual arguments of an op. To fix this, we introduce a mechanism for ops to specify valid representations once a node's arguments are known. Once the model is exported with a supported memory layout, the model test passes.

## Inception_V3/ViT

Type: Exception during runtime

The root cause was an interaction between the fuse batch norm pass and how `vulkan_preprocess.py` was applying passes. Essentially, the fuse batch norm pass creates a new param node for the fused weight, which reuses the original name of the tensor, which may contain capital letters. However, after re-tracing the graph, the node's name was being lowercased. `vulkan_preprocess` was using _copy_module to update the exported program's graph module in place, which did not update the exported program's graph signature with the new lowercase name after retracing. The solution was to migrate vulkan_preprocess.py to use the _transform() API instead of _copy_module.

There was also a small bug in Pool.cpp where a `bool` was used to pass a UBO field that is received as an `int`.

## DenseNet 161 (w/ dynamic shapes)

Type: Output mismatch

Cause: the native_batch_norm op doesn't support dynamic shapes, but the backend test runner doesn't set the correct compile option to filter out ops without dynamic shape support. Since resize for batch norm is easy to implement, fix by implementing resize for batch norm.

ghstack-source-id: 313794474
Differential Revision: [D83703496](https://our.internmc.facebook.com/intern/diff/D83703496/)
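To make the `_transform()` migration concrete, here is a minimal sketch, not the actual `vulkan_preprocess.py` code, of applying a pass list through `ExportedProgram._transform()`. It assumes ExecuTorch's `ExportPass` base class, and `apply_passes` is a hypothetical helper name:

```python
# Sketch only: apply passes via ExportedProgram._transform() so the graph
# signature is rebuilt together with the graph module (e.g. picking up a
# re-traced placeholder whose name was lowercased), instead of mutating
# ep.graph_module in place with _copy_module.
from typing import Sequence

from executorch.exir.pass_base import ExportPass
from torch.export import ExportedProgram


def apply_passes(ep: ExportedProgram, passes: Sequence[ExportPass]) -> ExportedProgram:
    for p in passes:
        # Each call returns a new ExportedProgram with a regenerated graph signature.
        ep = ep._transform(p)
    return ep
```

The relevant difference from `_copy_module` is that `_transform()` re-creates the exported program, so the graph signature stays consistent with any renamed nodes after retracing.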
1 parent 19258d2 commit acd154f

15 files changed: +253, -129 lines

.github/workflows/pull.yml

Lines changed: 6 additions & 1 deletion
@@ -970,11 +970,16 @@ jobs:
       PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh --build

       # Test models serially
-      models="mv2 mv3 edsr resnet18 resnet50 dl3"
+      models="mv2 mv3 edsr resnet18 resnet50 dl3 w2l ic3 ic4 convnext_small vit"
       for model in $models; do
         python -m examples.vulkan.export --model_name=$model --test
       done
+
+      # For selected vision models, test with dynamic shapes
+      models="mv2 mv3 resnet18 resnet50 ic3 ic4 densenet161"
+      for model in $models; do
+        python -m examples.vulkan.export --model_name=$model --test -d
+      done

   test-vulkan-operators-linux:
     name: test-vulkan-operators-linux

backends/vulkan/_passes/fold_qdq.py

Lines changed: 2 additions & 3 deletions
@@ -17,9 +17,8 @@ class FoldQDQPass(ExportPass):
     valid quant op patterns have already been fused before this pass.
     """

-    def __init__(self, edge_program: torch.export.ExportedProgram):
-        super(FoldQDQPass, self).__init__()
-        self.edge_program = edge_program
+    def __init__(self):
+        super().__init__()

     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:

backends/vulkan/_passes/fuse_patterns.py

Lines changed: 7 additions & 3 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from typing import Optional
+
 import executorch.backends.vulkan.patterns as vk_patterns

 import torch
@@ -13,13 +15,15 @@


 class FusePatternsPass(ExportPass):
-    def __init__(self, exported_program: ExportedProgram) -> None:
+    def __init__(self) -> None:
         super().__init__()
-        self.program = exported_program
+        self._exported_program: Optional[ExportedProgram] = None

     def call(self, graph_module: torch.fx.GraphModule):
+        assert self._exported_program is not None
+
         total_replaced = vk_patterns.replace_all_fusable_subgraphs(
-            self.program, graph_module
+            self._exported_program, graph_module
         )

         if total_replaced > 0:

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 6 additions & 4 deletions
@@ -211,18 +211,20 @@ def fuse_into_linear_qcnw_node(


 class FuseQuantizedOpsTransform(ExportPass):
-    def __init__(self, exported_program: ExportedProgram) -> None:
+    def __init__(self) -> None:
         super().__init__()
-        self.program = exported_program
+        self._exported_program: Optional[ExportedProgram] = None

     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        assert self._exported_program is not None
+
         for node in graph_module.graph.nodes:
             # Check for linear_qcnw pattern (weight-only quantization)
-            qcnw_details = matches_linear_qcnw_pattern(self.program, node)
+            qcnw_details = matches_linear_qcnw_pattern(self._exported_program, node)
             if qcnw_details is not None:
                 qcnw_method, qcnw_nbits = qcnw_details
                 fuse_into_linear_qcnw_node(
-                    self.program, graph_module, node, qcnw_method, qcnw_nbits
+                    self._exported_program, graph_module, node, qcnw_method, qcnw_nbits
                 )
                 continue

backends/vulkan/op_registry.py

Lines changed: 52 additions & 23 deletions
@@ -16,8 +16,6 @@

 import torch

-from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout
-
 from executorch.exir.dialects._ops import ops as exir_ops

 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -48,6 +46,9 @@ class OpFeatures:
         # Optional check function used during partitioning to determine if a node's
         # inputs are supported by the operator implementation.
         "are_node_inputs_supported_fn",
+        # Optional function to determine valid representation sets for input and outputs
+        # once a node's actual inputs are known.
+        "pick_io_storage_fn",
     ]

     def __init__(
@@ -61,6 +62,7 @@
         supports_resize: bool = False,
         supports_prepacking: bool = False,
         are_node_inputs_supported_fn: Optional[Callable] = allow_node,
+        pick_io_storage_fn: Optional[Callable] = None,
     ):
         self.inputs_storage: utils.TensorRepSetList = utils.TensorRepSetList(
             inputs_storage if inputs_storage is not None else []
@@ -77,15 +79,21 @@
         self.supports_prepacking = supports_prepacking

         self.are_node_inputs_supported_fn = are_node_inputs_supported_fn
+        self.pick_io_storage_fn = pick_io_storage_fn

     def make_op_repsets(
         self,
         op_node: torch.fx.Node,
         texture_limits: utils.ImageExtents = utils.DEFAULT_TEXTURE_LIMITS,
     ) -> utils.OpRepSets:
-        return utils.OpRepSets(
-            self.inputs_storage, self.outputs_storage, op_node, texture_limits
-        )
+        inputs_storage = self.inputs_storage
+        outputs_storage = self.outputs_storage
+        if self.pick_io_storage_fn is not None:
+            i_storage, o_storage = self.pick_io_storage_fn(op_node)
+            inputs_storage = utils.TensorRepSetList(i_storage)
+            outputs_storage = utils.TensorRepSetList(o_storage)
+
+        return utils.OpRepSets(inputs_storage, outputs_storage, op_node, texture_limits)


 #######################
@@ -410,28 +418,16 @@ def register_softmax_op():
 )
 def register_reduce_op():
     def check_reduce_node(node: torch.fx.Node) -> bool:
+        # Only one argument implies that the reduction is over the entire tensor, which
+        # is not supported yet.
+        if len(node.args) == 1:
+            return False
+
         dim_list = node.args[1]
+        # Only 1D and 2D reductions are supported at the moment.
         if isinstance(dim_list, list) and len(dim_list) > 2:
             return False

-        if isinstance(dim_list, list) and len(dim_list) == 2:
-            # Try to get the memory layout for this node
-            try:
-                memory_layout = utils.get_node_memory_layout(node)
-
-                # If we have memory layout information, check if any dimension in dim_list corresponds to a packed dimension
-                if (
-                    memory_layout is not None
-                    and memory_layout != VkMemoryLayout.DEFAULT_LAYOUT
-                ):
-                    # For now only default layout is supported for 2D reduction.
-                    # Because we can't determine if the input is NCHW or NHWC here,
-                    # assume the reduction dimension is packed so we cannot support it.
-                    return False
-            except (AssertionError, KeyError, AttributeError):
-                # If we can't get memory layout information, we'll assume the dims aren't packed
-                pass
-
         def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
             for arg in node.args:
                 if isinstance(arg, bool):
@@ -446,10 +442,41 @@ def try_find_keepdim_arg(node: torch.fx.Node) -> bool:

         return True

+    def pick_io_storage_for_reduce(node: torch.fx.Node):
+        inputs_storage = utils.ANY_TEXTURE
+        outputs_storage = utils.ANY_TEXTURE
+
+        input_tensor = node.args[0]
+        ndim = input_tensor.meta["val"].ndim
+        dim_list = node.args[1]
+        if isinstance(dim_list, list) and len(dim_list) == 2:
+            reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
+            reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
+
+            possible_packed_dims = {0, 1, 2}
+            possible_packed_dims.discard(reduce_dim1_whcn)
+            possible_packed_dims.discard(reduce_dim2_whcn)
+
+            packed_dim = possible_packed_dims.pop()
+            assert packed_dim in [0, 1, 2]
+
+            if packed_dim == 0:
+                inputs_storage = utils.WIDTH_PACKED_TEXTURE
+                outputs_storage = utils.WIDTH_PACKED_TEXTURE
+            elif packed_dim == 1:
+                inputs_storage = utils.HEIGHT_PACKED_TEXTURE
+                outputs_storage = utils.HEIGHT_PACKED_TEXTURE
+            else:
+                inputs_storage = utils.CHANNELS_PACKED_TEXTURE
+                outputs_storage = utils.CHANNELS_PACKED_TEXTURE
+
+        return inputs_storage, outputs_storage
+
     return OpFeatures(
         inputs_storage=utils.ANY_TEXTURE,
         supports_resize=True,
         are_node_inputs_supported_fn=check_reduce_node,
+        pick_io_storage_fn=pick_io_storage_for_reduce,
     )

@@ -716,6 +743,7 @@ def register_ported_ops_with_prepacking():
     return OpFeatures(
         inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
         supports_prepacking=True,
+        supports_resize=True,
     )

@@ -746,6 +774,7 @@ def register_ported_ops_with_prepacking_all_dims():
     return OpFeatures(
         inputs_storage=utils.ANY_TEXTURE,
         supports_prepacking=True,
+        supports_resize=True,
     )

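To make the mean2d constraint from the ConvNeXT fix concrete, below is a small standalone sketch of the packed-dim selection performed by `pick_io_storage_for_reduce` above. The helper `nchw_dim_to_whcn_dim` is re-implemented here purely for illustration, under the assumption that it maps an NCHW dim index to WHCN order (`whcn = ndim - 1 - nchw`); in the backend the real `utils` helpers are used.

```python
# Standalone illustration of the reduce-op layout rule: the packed (texel) dim
# must not be one of the reduced dims, so whichever of W(0)/H(1)/C(2) is left
# over becomes the packed dim.
def nchw_dim_to_whcn_dim(nchw_dim: int, ndim: int) -> int:
    # Assumed mapping for illustration: the last NCHW dim (W) maps to 0, etc.
    if nchw_dim < 0:
        nchw_dim += ndim
    return ndim - 1 - nchw_dim


def pick_packed_dim(dim_list: list[int], ndim: int) -> int:
    # For a 2D reduction exactly one of {W, H, C} remains; that dim gets packed.
    candidates = {0, 1, 2}
    for d in dim_list:
        candidates.discard(nchw_dim_to_whcn_dim(d, ndim))
    return candidates.pop()


# mean2d over H and W of an NCHW tensor (ConvNeXT-style mean([-2, -1])):
# dims -2/-1 map to WHCN 1/0, leaving channels (2) as the only valid packed dim,
# hence CHANNELS_PACKED_TEXTURE is selected for the inputs and outputs.
assert pick_packed_dim([-2, -1], ndim=4) == 2
```
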
backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 6 additions & 4 deletions
@@ -36,7 +36,7 @@
     Partitioner,
     PartitionResult,
 )
-from executorch.exir.backend.utils import tag_constant_data
+from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
 from executorch.exir.dialects._ops import ops as exir_ops

 from torch.export.exported_program import ExportedProgram
@@ -254,9 +254,10 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool:  # noqa: C901
             self.log_skip(node, "permute node of non compatible linear node")
             return False

-        is_in_local_scalar_dense_chain, dst_node_is_compatible = (
-            self.is_in_local_scalar_dense_chain(node)
-        )
+        (
+            is_in_local_scalar_dense_chain,
+            dst_node_is_compatible,
+        ) = self.is_in_local_scalar_dense_chain(node)
         if is_in_local_scalar_dense_chain and dst_node_is_compatible:
             return True
         elif is_in_local_scalar_dense_chain:
@@ -419,6 +420,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         logger.info(f"Found {pl} Vulkan subgraphs to be partitioned.")

         tag_constant_data(exported_program)
+        tag_mutated_buffer(exported_program)

         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags

backends/vulkan/patterns/quantized_linear.py

Lines changed: 7 additions & 5 deletions
@@ -92,9 +92,11 @@ def __init__(self, mm_node: torch.fx.Node) -> None:
             return

         # Identify input node
-        self.fp_input_node, self.quantize_input_node, dq_node = (
-            utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0])
-        )
+        (
+            self.fp_input_node,
+            self.quantize_input_node,
+            dq_node,
+        ) = utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0])
         assert self.fp_input_node is not None
         self.all_nodes.append(self.fp_input_node)

@@ -386,7 +388,7 @@ def make_linear_dq8ca_q4gsw_op(
     weight_sums_node = create_constant_placeholder(
         exp_program=ep,
         graph=graph_module.graph,
-        kind=InputKind.CONSTANT_TENSOR,
+        kind=InputKind.PARAMETER,
         name=sums_name,
         data=sum_per_quant_group,
     )
@@ -429,7 +431,7 @@ def make_linear_q8ta_q8csw_custom_op(
     weight_sums_node = create_constant_placeholder(
         exp_program=ep,
         graph=graph_module.graph,
-        kind=InputKind.CONSTANT_TENSOR,
+        kind=InputKind.PARAMETER,
         name=sums_name,
         data=sum_per_output_channel,
     )

backends/vulkan/runtime/graph/ops/glsl/full.yaml

Lines changed: 1 addition & 0 deletions
@@ -14,5 +14,6 @@ full:
     DTYPE:
       - VALUE: half
       - VALUE: float
+      - VALUE: int32
   shader_variants:
     - NAME: full

backends/vulkan/runtime/graph/ops/glsl/softmax.glsl

Lines changed: 14 additions & 13 deletions
@@ -42,7 +42,8 @@ layout(constant_id = 5) const int group_dim = 1;
 // work group will write into its assigned element in the shared array.
 #define MAX_NTHREADS 16

-shared vec4 shared_vecs[MAX_NTHREADS];
+shared vec4 shared_max[MAX_NTHREADS];
+shared vec4 shared_sum[MAX_NTHREADS];

 #include "indexing_utils.h"

@@ -102,13 +103,13 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
       i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
    max_elements = max(max_elements, load_texel(tin, scan_pos));
  }
-  shared_vecs[smi] = max_elements;
+  shared_max[smi] = max_elements;
  barrier();
  // Iterate over the partial maximums to obtain the overall maximum
  group_i = tid.y * NWORKERS;
-  max_elements = shared_vecs[group_i++];
+  max_elements = shared_max[group_i++];
  for (int i = 1; i < NWORKERS; ++i, group_i++) {
-    max_elements = max(max_elements, shared_vecs[group_i]);
+    max_elements = max(max_elements, shared_max[group_i]);
  }

  scan_pos[reduce_dim] = tid.x;
@@ -118,13 +119,13 @@ void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
       i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
    denominators += exp(load_texel(tin, scan_pos) - max_elements);
  }
-  shared_vecs[smi] = denominators;
+  shared_sum[smi] = denominators;
  barrier();
  // Iterate over the partial sums to obtain the overall sum
  group_i = tid.y * NWORKERS;
-  denominators = shared_vecs[group_i++];
+  denominators = shared_sum[group_i++];
  for (int i = 1; i < NWORKERS; ++i, group_i++) {
-    denominators += shared_vecs[group_i];
+    denominators += shared_sum[group_i];
  }

  // Determine if there are any padding elements in the final texel of the
@@ -184,13 +185,13 @@ void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) {
      max_elements.x = max(intex[i], max_elements.x);
    }
  }
-  shared_vecs[smi] = max_elements;
+  shared_max[smi] = max_elements;
  barrier();
  // Iterate over the partial maximums to obtain the overall maximum
  group_i = tid.y * NWORKERS;
-  max_elements = shared_vecs[group_i++];
+  max_elements = shared_max[group_i++];
  for (int i = 1; i < NWORKERS; ++i, group_i++) {
-    max_elements = max(max_elements, shared_vecs[group_i]);
+    max_elements = max(max_elements, shared_max[group_i]);
  }
  // Each element of the texel is itself a partial maximum; iterate over the
  // texel to find the actual maximum
@@ -214,13 +215,13 @@
      denominators.x += exp(intex[i] - max_element);
    }
  }
-  shared_vecs[smi] = denominators;
+  shared_sum[smi] = denominators;
  barrier();
  // Iterate over the partial sums to obtain the overall sum
  group_i = tid.y * NWORKERS;
-  denominators = shared_vecs[group_i++];
+  denominators = shared_sum[group_i++];
  for (int i = 1; i < NWORKERS; ++i, group_i++) {
-    denominators += shared_vecs[group_i];
+    denominators += shared_sum[group_i];
  }
  // Reduce over the accumulated texel to find the overall sum
  float denominator = 0;
