77#import < Metal/Metal.h>
88#import < MetalPerformanceShaders/MetalPerformanceShaders.h>
99
10+ #undef MIN
11+ #undef MAX
12+ #define MIN (a, b ) ((a) < (b) ? (a) : (b))
13+ #define MAX (a, b ) ((a) > (b) ? (a) : (b))
14+
1015#ifdef GGML_METAL_NDEBUG
1116#define metal_printf (...)
1217#else
1520
1621#define UNUSED (x ) (void )(x)
1722
23+ #define GGML_MAX_CONCUR (2 *GGML_MAX_NODES)
24+
1825struct ggml_metal_buffer {
1926 const char * name;
2027
3643 int n_buffers;
3744 struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
3845
39- int concur_list[GGML_MAX_NODES ];
46+ int concur_list[GGML_MAX_CONCUR ];
4047 int concur_list_len;
4148
4249 // custom kernels
@@ -370,44 +377,56 @@ void ggml_metal_graph_find_concurrency(
370377 struct ggml_metal_context * ctx,
371378 struct ggml_cgraph * gf) {
372379 int search_depth = gf->n_nodes ; // we only find concurrency in this range to avoid wasting too much time
373- int nodes_unused[GGML_MAX_NODES ];
380+ int nodes_unused[GGML_MAX_CONCUR ];
374381
375- for (int i = 0 ; i < GGML_MAX_NODES ; i++) {ctx->concur_list [i] = 0 ;}
376- for (int i = 0 ; i < gf->n_nodes ; i++) {nodes_unused[i] = 1 ;}
382+ for (int i = 0 ; i < GGML_MAX_CONCUR ; i++) { ctx->concur_list [i] = 0 ; }
383+ for (int i = 0 ; i < gf->n_nodes ; i++) { nodes_unused[i] = 1 ; }
377384 ctx->concur_list_len = 0 ;
378385
379- int n_left = gf->n_nodes ;
380- int n_start = 0 ; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
381- int level_pos = 0 ; // at ctx->concur_list, the last layer (level) ends at level_pos
386+ int n_left = gf->n_nodes ;
387+ int n_start = 0 ; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
388+ int level_pos = 0 ; // at ctx->concur_list, the last layer (level) ends at level_pos
382389
383390 while (n_left > 0 ) {
384391 // number of nodes at a layer (that can be issued concurrently)
385392 int concurrency = 0 ;
386393 for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes ) ? gf->n_nodes : n_start + search_depth); i++) {
387394 if (nodes_unused[i]) {
388395 // if the requirements for gf->nodes[i] are satisfied
389- int exe_flag=1 ;
396+ int exe_flag = 1 ;
397+
390398 // scan all srcs
391399 for (int src_ind = 0 ; src_ind < GGML_MAX_SRC; src_ind++) {
392400 struct ggml_tensor * src_cur = gf->nodes [i]->src [src_ind];
393401 if (src_cur) {
394402 // if is leaf nodes it's satisfied.
395- if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL ) {continue ;}
403+ // TODO: ggml_is_leaf()
404+ if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL ) {
405+ continue ;
406+ }
396407
397408 // otherwise this src should be the output from previous nodes.
398409 int is_found = 0 ;
410+
399411 // scan 2*search_depth back because we inserted barrier.
400- for (int j = ((level_pos - 2 *search_depth) < 0 ? 0 : (level_pos - 2 *search_depth)); j < level_pos; j++) {
401- if (gf->nodes [ctx->concur_list[j]] == src_cur) {is_found = 1 ; break ;}
412+ // for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
413+ for (int j = MAX (0 , level_pos - 2 *search_depth); j < level_pos; j++) {
414+ if (ctx->concur_list [j] >= 0 && gf->nodes [ctx->concur_list[j]] == src_cur) {
415+ is_found = 1 ;
416+ break ;
417+ }
418+ }
419+ if (is_found == 0 ) {
420+ exe_flag = 0 ;
421+ break ;
402422 }
403- if (is_found == 0 ) {exe_flag = 0 ; break ;}
404423 }
405424 }
406425 if (exe_flag) {
407426 // check if nodes[i]'s data will be overwritten by a node before nodes[i].
408427 // if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
409428 int64_t data_start = (int64_t ) gf->nodes [i]->data ;
410- int64_t length = (int64_t ) ggml_nbytes (gf->nodes [i]);
429+ int64_t length = (int64_t ) ggml_nbytes (gf->nodes [i]);
411430 for (int j = n_start; j < i; j++) {
412431 if (nodes_unused[j] && gf->nodes [j]->op != GGML_OP_RESHAPE \
413432 && gf->nodes [j]->op != GGML_OP_VIEW \
@@ -416,9 +435,9 @@ void ggml_metal_graph_find_concurrency(
416435 if (((int64_t )gf->nodes [j]->data ) >= data_start + length || \
417436 ((int64_t )gf->nodes [j]->data ) + (int64_t ) ggml_nbytes (gf->nodes [j]) <= data_start) {
418437 continue ;
419- } else {
420- exe_flag = 0 ;
421438 }
439+
440+ exe_flag = 0 ;
422441 }
423442 }
424443 }
@@ -435,11 +454,13 @@ void ggml_metal_graph_find_concurrency(
435454 ctx->concur_list [level_pos + concurrency] = -1 ;
436455 ctx->concur_list_len ++;
437456 // jump all sorted nodes at nodes_bak
438- while (!nodes_unused[n_start]) {n_start++;}
457+ while (!nodes_unused[n_start]) {
458+ n_start++;
459+ }
439460 level_pos += concurrency + 1 ;
440461 }
441462
442- if (ctx->concur_list_len > GGML_MAX_NODES ) {
463+ if (ctx->concur_list_len > GGML_MAX_CONCUR ) {
443464 fprintf (stderr, " %s : too many elements for metal ctx->concur_list!\n " , __func__);
444465 }
445466}
@@ -453,7 +474,7 @@ void ggml_metal_graph_compute(
453474 // else fallback to serial dispatch
454475 MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor ;
455476
456- const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES ;
477+ const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR ;
457478
458479 const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes ;
459480 edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial ;
0 commit comments