Skip to content

Commit e77b89e

Browse files
authored
[ET-VK] Reduced int precision for all int storage in q linear op, and reducing some texture coordinate storage variables to improve performance.
Differential Revision: D64780578 Pull Request resolved: #6465
1 parent fa30e80 commit e77b89e

File tree

1 file changed

+10
-14
lines changed

1 file changed

+10
-14
lines changed

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -91,27 +91,23 @@ void main() {
9191

9292
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
9393

94-
VEC4_T q_8w_linear(const ivec3 out_pos, const int K) {
95-
u16vec3 mat1_pos = u16vec3(0, out_pos.yz);
96-
u16vec3 qmat2_pos = u16vec3(0, out_pos.x * 4, 0);
94+
VEC4_T q_8w_linear(const u16vec3 out_pos, const uint16_t K) {
95+
const uint16_t qmat2_pos_y = out_pos.x * uint16_t(4);
9796

9897
VEC4_T outtex = VEC4_T(0);
9998

10099
const u16vec3 scales_pos = u16vec3(out_pos.x, 0, 0);
101100
const VEC4_T scales = load_texel(t_scales, scales_pos);
102101

103-
for (int i = 0; i < K; i += 4) {
104-
const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos);
102+
for (uint16_t i = uint16_t(0), x = uint16_t(0); i < K; i += uint16_t(4), x++) {
103+
const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.yz));
105104
const VEC4_T sums = VEC4_T(
106-
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos)),
107-
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 1, 0))),
108-
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 2, 0))),
109-
dot(mat1_tex, load_texel(t_qmat2, qmat2_pos + u16vec3(0, 3, 0))));
105+
dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y, 0))),
106+
dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(1), 0))),
107+
dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(2), 0))),
108+
dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(3), 0))));
110109

111110
outtex += sums;
112-
113-
mat1_pos.x++;
114-
qmat2_pos.x++;
115111
}
116112

117113
outtex *= scales;
@@ -120,12 +116,12 @@ VEC4_T q_8w_linear(const ivec3 out_pos, const int K) {
120116
}
121117

122118
void main() {
123-
const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
119+
const u16vec3 out_pos = u16vec3(gl_GlobalInvocationID);
124120
if (any(greaterThanEqual(out_pos, out_limits))) {
125121
return;
126122
}
127123

128-
VEC4_T outtex = q_8w_linear(out_pos, mat1_sizes.x);
124+
VEC4_T outtex = q_8w_linear(out_pos, uint16_t(mat1_sizes.x));
129125
write_texel(t_out, out_pos, outtex);
130126
}
131127

0 commit comments

Comments
 (0)