Skip to content

Commit 8267cc2

Browse files
committed
vulkan: Use spec constants for conv2d s/d/p and kernel W/H
Also add some additional unrolling, which seems to help.
1 parent e7da30b commit 8267cc2

File tree

2 files changed

+115
-71
lines changed

2 files changed

+115
-71
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 76 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -351,6 +351,12 @@ enum vk_conv_shapes {
351351
CONV_SHAPE_COUNT,
352352
};
353353

354+
// Workgroup denominators for each conv2d tile shape, one { x, y, z } row per
// vk_conv_shapes value (CONV_SHAPE_128x128, CONV_SHAPE_64x32, CONV_SHAPE_32x256).
// Element [0] supplies the BS_K tile size and element [1] the BS_NPQ tile size
// when the shaders are loaded; the same two values are used to count tiles when
// a pipeline shape is selected. Rows are indexed directly by the enum value, so
// they must stay in the same order as the enum.
//
// This is a read-only lookup table: keep it constexpr with internal linkage so
// it cannot be modified at runtime or clash with symbols in other translation
// units.
static constexpr uint32_t conv_shapes_wg_denoms[][3] = {
    { 128, 128, 1 },
    {  64,  32, 1 },
    {  32, 256, 1 },
};
359+
354360
enum dmmv_wg_sizes {
355361
DMMV_WG_SIZE_SUBGROUP,
356362
DMMV_WG_SIZE_LARGE,
@@ -379,6 +385,18 @@ struct vk_fa_pipeline_state {
379385
}
380386
};
381387

388+
// Key identifying one specialization of a conv2d / conv_transpose_2d pipeline.
// The eight values stored here are the ones appended to the shader's
// specialization constants when a pipeline is created, so every distinct
// combination maps to its own pipeline object. The struct is used as the key
// of a std::map, which is why it provides a strict weak ordering (operator<).
struct vk_conv2d_pipeline_state {
    vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH)
        : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {}

    // Presumably stride (s0/s1), padding (p0/p1) and dilation (d0/d1) along
    // each axis, plus kernel width/height (KW/KH) -- confirm against the ggml
    // op_params layout.
    uint32_t s0, s1, p0, p1, d0, d1, KW, KH;

private:
    // Single source of truth for the field order used in comparisons, so the
    // left- and right-hand sides of operator< can never get out of sync.
    // Defined before operator< because its return type is deduced.
    auto key() const {
        return std::tie(s0, s1, p0, p1, d0, d1, KW, KH);
    }

public:
    // Lexicographic comparison over (s0, s1, p0, p1, d0, d1, KW, KH).
    bool operator<(const vk_conv2d_pipeline_state &b) const {
        return key() < b.key();
    }
};
399+
382400
enum shader_reduction_mode {
383401
SHADER_REDUCTION_MODE_SHMEM,
384402
SHADER_REDUCTION_MODE_HYBRID,
@@ -668,10 +686,10 @@ struct vk_device_struct {
668686
vk_pipeline pipeline_ssm_conv_f32;
669687
vk_pipeline pipeline_opt_step_adamw_f32;
670688
vk_pipeline pipeline_opt_step_sgd_f32;
671-
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
672-
vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
673-
vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
674-
vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
689+
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f32[CONV_SHAPE_COUNT];
690+
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
691+
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
692+
std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
675693
vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
676694
vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
677695

@@ -1244,17 +1262,13 @@ struct vk_op_conv2d_push_constants {
12441262
uint32_t nb2;
12451263
uint32_t nb3;
12461264

1247-
// init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH
1248-
uint32_t KWmp; uint32_t KWL;
1249-
uint32_t KWKHmp; uint32_t KWKHL;
1265+
// init_fastdiv_values constants for dividing by OW, OW*OH
12501266
uint32_t OWmp; uint32_t OWL;
12511267
uint32_t OWOHmp; uint32_t OWOHL;
12521268
};
12531269

12541270
template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
1255-
// Compute magic values to divide by KW, KW*KH, OW, OW*OH
1256-
init_fastdiv_values(p.KW, p.KWmp, p.KWL);
1257-
init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
1271+
// Compute magic values to divide by OW, OW*OH
12581272
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
12591273
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
12601274
}
@@ -1290,23 +1304,15 @@ struct vk_op_conv_transpose_2d_push_constants {
12901304
uint32_t nb2;
12911305
uint32_t nb3;
12921306

1293-
// init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1
1294-
uint32_t KWmp; uint32_t KWL;
1295-
uint32_t KWKHmp; uint32_t KWKHL;
1307+
// init_fastdiv_values constants for dividing by OW, OW*OH
12961308
uint32_t OWmp; uint32_t OWL;
12971309
uint32_t OWOHmp; uint32_t OWOHL;
1298-
uint32_t s0mp; uint32_t s0L;
1299-
uint32_t s1mp; uint32_t s1L;
13001310
};
13011311

13021312
template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) {
1303-
// Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1
1304-
init_fastdiv_values(p.KW, p.KWmp, p.KWL);
1305-
init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL);
1313+
// Compute magic values to divide by OW, OW*OH
13061314
init_fastdiv_values(p.OW, p.OWmp, p.OWL);
13071315
init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
1308-
init_fastdiv_values(p.s0, p.s0mp, p.s0L);
1309-
init_fastdiv_values(p.s1, p.s1mp, p.s1L);
13101316
}
13111317

13121318
struct vk_op_conv2d_dw_push_constants {
@@ -3828,22 +3834,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
38283834
switch (s) {
38293835
default:
38303836
case CONV_SHAPE_128x128:
3831-
conv2d_BS_K = 128;
3832-
conv2d_BS_NPQ = 128;
3837+
conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_128x128][0];
3838+
conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_128x128][1];
38333839
conv2d_BS_CRS = 16;
38343840
if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) {
38353841
conv2d_UNROLL = false;
38363842
}
38373843
break;
38383844
case CONV_SHAPE_64x32:
3839-
conv2d_BS_K = 64;
3840-
conv2d_BS_NPQ = 32;
3845+
conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_64x32][0];
3846+
conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_64x32][1];
38413847
conv2d_BS_CRS = 32;
38423848
conv2d_TS_K = 4;
38433849
break;
38443850
case CONV_SHAPE_32x256:
3845-
conv2d_BS_K = 32;
3846-
conv2d_BS_NPQ = 256;
3851+
conv2d_BS_K = conv_shapes_wg_denoms[CONV_SHAPE_32x256][0];
3852+
conv2d_BS_NPQ = conv_shapes_wg_denoms[CONV_SHAPE_32x256][1];
38473853
conv2d_BS_CRS = 16;
38483854
break;
38493855
}
@@ -3877,10 +3883,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
38773883
std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
38783884

38793885
#define CREATE_CONV(name, type_suffix, spv_suffix) \
3880-
ggml_vk_create_pipeline( \
3881-
device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \
3882-
name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
3883-
sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
3886+
for (auto &c : device->pipeline_##name##type_suffix[s]) { \
3887+
const vk_conv2d_pipeline_state &state = c.first; \
3888+
std::vector<uint32_t> spec_constants_cpy = spec_constants; \
3889+
spec_constants_cpy.push_back(state.s0); \
3890+
spec_constants_cpy.push_back(state.s1); \
3891+
spec_constants_cpy.push_back(state.p0); \
3892+
spec_constants_cpy.push_back(state.p1); \
3893+
spec_constants_cpy.push_back(state.d0); \
3894+
spec_constants_cpy.push_back(state.d1); \
3895+
spec_constants_cpy.push_back(state.KW); \
3896+
spec_constants_cpy.push_back(state.KH); \
3897+
ggml_vk_create_pipeline( \
3898+
device, c.second, #name #type_suffix, \
3899+
name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
3900+
sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \
3901+
}
38843902
#define CREATE_CONVS(spv_suffix) \
38853903
CREATE_CONV(conv2d, _f32, spv_suffix) \
38863904
CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
@@ -8551,7 +8569,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
85518569

85528570
uint32_t tiles[CONV_SHAPE_COUNT];
85538571
for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) {
8554-
tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[1]);
8572+
tiles[i] = CEIL_DIV(elements[0], conv_shapes_wg_denoms[i][0]) * CEIL_DIV(elements[1], conv_shapes_wg_denoms[i][1]);
85558573
}
85568574

85578575
// We can't query number of shader cores on Intel, use 32 as a placeholder
@@ -8566,19 +8584,42 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
85668584
shape = CONV_SHAPE_64x32;
85678585
}
85688586

8587+
uint32_t KW = static_cast<uint32_t>(src0->ne[0]);
8588+
uint32_t KH = static_cast<uint32_t>(src0->ne[1]);
8589+
uint32_t s0 = static_cast<uint32_t>(dst->op_params[0]);
8590+
uint32_t s1 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[1]) : static_cast<uint32_t>(dst->op_params[0]);
8591+
uint32_t p0 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[2]) : 0;
8592+
uint32_t p1 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[3]) : 0;
8593+
uint32_t d0 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[4]) : 1;
8594+
uint32_t d1 = op == GGML_OP_CONV_2D ? static_cast<uint32_t>(dst->op_params[5]) : 1;
8595+
8596+
vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH);
8597+
8598+
std::map<vk_conv2d_pipeline_state, vk_pipeline> *pipelines = nullptr;
85698599
if (op == GGML_OP_CONV_2D) {
85708600
if (src0->type == GGML_TYPE_F32) {
8571-
return ctx->device->pipeline_conv2d_f32[shape];
8601+
pipelines = &ctx->device->pipeline_conv2d_f32[shape];
85728602
} else if (src0->type == GGML_TYPE_F16) {
8573-
return ctx->device->pipeline_conv2d_f16_f32[shape];
8603+
pipelines = &ctx->device->pipeline_conv2d_f16_f32[shape];
85748604
}
85758605
} else if (op == GGML_OP_CONV_TRANSPOSE_2D) {
85768606
if (src0->type == GGML_TYPE_F32) {
8577-
return ctx->device->pipeline_conv_transpose_2d_f32[shape];
8607+
pipelines = &ctx->device->pipeline_conv_transpose_2d_f32[shape];
85788608
} else if (src0->type == GGML_TYPE_F16) {
8579-
return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
8609+
pipelines = &ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
85808610
}
85818611
}
8612+
8613+
vk_pipeline pipeline = nullptr;
8614+
8615+
auto it = pipelines->find(conv2d_pipeline_state);
8616+
if (it != pipelines->end()) {
8617+
pipeline = it->second;
8618+
} else {
8619+
(*pipelines)[conv2d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
8620+
}
8621+
8622+
return pipeline;
85828623
}
85838624
return nullptr;
85848625
case GGML_OP_CONV_2D_DW:

ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,8 @@ layout(push_constant) uniform parameter {
6262
uint32_t nb3;
6363

6464
// fastdiv helper values
65-
uint32_t KWmp; uint32_t KWL;
66-
uint32_t KWKHmp; uint32_t KWKHL;
6765
uint32_t OWmp; uint32_t OWL;
6866
uint32_t OWOHmp; uint32_t OWOHL;
69-
#ifdef TRANSPOSE
70-
uint32_t s0mp; uint32_t s0L;
71-
uint32_t s1mp; uint32_t s1L;
72-
#endif
7367
}
7468

7569
p;
@@ -84,6 +78,15 @@ layout(constant_id = 4) const uint TS_K = 8;
8478
layout(constant_id = 5) const uint use_collectives = 1;
8579
layout(constant_id = 6) const uint SHMEM_PAD = 4;
8680

81+
layout(constant_id = 7) const uint s0 = 1;
82+
layout(constant_id = 8) const uint s1 = 1;
83+
layout(constant_id = 9) const uint p0 = 0;
84+
layout(constant_id = 10) const uint p1 = 0;
85+
layout(constant_id = 11) const uint d0 = 1;
86+
layout(constant_id = 12) const uint d1 = 1;
87+
layout(constant_id = 13) const uint KW = 1;
88+
layout(constant_id = 14) const uint KH = 1;
89+
8790
uint32_t tid = gl_LocalInvocationID.x;
8891
const uint32_t WG_SIZE = gl_WorkGroupSize.x;
8992

@@ -92,7 +95,7 @@ uint splitWork(uint work_size, uint block_size) {
9295
}
9396

9497
uint32_t K = p.Cout;
95-
uint32_t CRS = p.Cin * p.KH * p.KW;
98+
uint32_t CRS = p.Cin * KH * KW;
9699
uint32_t NPQ = p.N * p.OH * p.OW;
97100

98101
uint32_t n_elems_out = K * NPQ;
@@ -187,7 +190,7 @@ void main() {
187190
}
188191
#endif
189192
/* Advance block in CRS dim */
190-
for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
193+
UNROLL for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
191194
uint32_t CRS_idx_a;
192195
uint32_t Cin_idx_a;
193196
uint32_t KH_idx_a;
@@ -200,32 +203,32 @@ void main() {
200203
uint32_t cached_KW_idx;
201204
if (use_collectives == 1) {
202205
cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
203-
cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
204-
uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
205-
cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
206-
cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
206+
cached_Cin_idx = cached_CRS_idx / (KW * KH);
207+
uint32_t cached_CRS_remainder = cached_CRS_idx % (KW * KH);
208+
cached_KH_idx = cached_CRS_remainder / KW;
209+
cached_KW_idx = cached_CRS_remainder % KW;
207210

208211
CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
209212
Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
210213
KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
211214
KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
212215
} else {
213216
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
214-
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
215-
uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
216-
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
217-
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
217+
Cin_idx_a = CRS_idx_a / (KW * KH);
218+
uint32_t CRS_remainder = CRS_idx_a % (KW * KH);
219+
KH_idx_a = CRS_remainder / KW;
220+
KW_idx_a = CRS_remainder % KW;
218221
}
219222
#else
220223
CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
221-
Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH);
222-
CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
223-
KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
224-
KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
224+
Cin_idx_a = CRS_idx_a / (KW * KH);
225+
CRS_remainder = CRS_idx_a % (KW * KH);
226+
KH_idx_a = CRS_remainder / KW;
227+
KW_idx_a = CRS_remainder % KW;
225228
#endif
226229

227230
/* Load kernel to A_block: (BS_K x BS_CRS)*/
228-
for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
231+
UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
229232
uint32_t B_ly = r_offset + Ar;
230233
uint32_t B_lx = Ac;
231234
uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
@@ -262,35 +265,35 @@ void main() {
262265
KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
263266
} else {
264267
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
265-
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
266-
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
267-
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
268-
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
268+
Cin_idx_b = CRS_idx_b / (KW * KH);
269+
uint32_t CRS_remainder = CRS_idx_b % (KW * KH);
270+
KH_idx_b = CRS_remainder / KW;
271+
KW_idx_b = CRS_remainder % KW;
269272
}
270273
#else
271274
CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
272-
Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH);
273-
uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
274-
KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW;
275-
KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
275+
Cin_idx_b = CRS_idx_b / (KW * KH);
276+
uint32_t CRS_remainder = CRS_idx_b % (KW * KH);
277+
KH_idx_b = CRS_remainder / KW;
278+
KW_idx_b = CRS_remainder % KW;
276279
#endif
277280

278281
#ifdef TRANSPOSE
279-
uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1;
280-
uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0;
281-
uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L);
282-
uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L);
282+
uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * d1 + p1;
283+
uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * d0 + p0;
284+
uint32_t H_idx = H_idx_x_s1 / s1;
285+
uint32_t W_idx = W_idx_x_s0 / s0;
283286
#else
284-
uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
285-
uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
287+
uint32_t H_idx = OH_idx * s1 + KH_idx_b * d1 - p1;
288+
uint32_t W_idx = OW_idx * s0 + KW_idx_b * d0 - p0;
286289
#endif
287290
uint32_t src_idx =
288291
min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
289292
float val = src_data[src_idx];
290293
if (CRS_idx_b >= CRS || NPQ_idx >= NPQ
291294
|| H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case)
292295
#ifdef TRANSPOSE
293-
|| (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)
296+
|| (H_idx_x_s1 - H_idx * s1 != 0) || (W_idx_x_s0 - W_idx * s0 != 0)
294297
#endif
295298
) {
296299
val = 0.0;

0 commit comments

Comments
 (0)