@@ -563,7 +563,9 @@ static bool ggml_metal_heap_resize(struct ggml_metal_heap * heap, size_t size) {
563563 return true ;
564564}
565565
566- static id <MTLBuffer > ggml_metal_heap_alloc (struct ggml_metal_heap * heap, size_t size, size_t alignment) {
566+ static id <MTLBuffer > ggml_metal_heap_alloc (struct ggml_metal_heap * heap, size_t size) {
567+ const size_t alignment = 1024 *1024 ;
568+
567569 const size_t size_aligned = GGML_PAD (size, alignment);
568570
569571 heap->need += size_aligned;
@@ -1583,7 +1585,8 @@ static bool ggml_metal_encode_node(
15831585 ggml_backend_t backend,
15841586 int idx,
15851587 id <MTLComputeCommandEncoder > encoder,
1586- struct ggml_metal_heap * heap) {
1588+ struct ggml_metal_heap * heap,
1589+ bool no_compute) {
15871590 struct ggml_backend_metal_context * ctx = backend->context ;
15881591 struct ggml_backend_metal_device_context * ctx_dev = backend->device ->context ;
15891592
@@ -1621,6 +1624,28 @@ static bool ggml_metal_encode_node(
16211624 GGML_ABORT (" unsupported op" );
16221625 }
16231626
1627+ id <MTLBuffer > h_src0 = nil ;
1628+ switch (dst->op ) {
1629+ case GGML_OP_SOFT_MAX:
1630+ {
1631+ h_src0 = ggml_metal_heap_alloc (heap, ggml_nbytes (src0));
1632+ if (!h_src0) {
1633+ // GGML_LOG_ERROR("%s: failed to allocate buffer, idx = %4d, size = %8zu, need = %8zu, max available = %9zu, heap size = %9zu, heap used = %zu\n",
1634+ // __func__, idx, ggml_nbytes(src0), heap->need, [heap->obj maxAvailableSizeWithAlignment:0], [heap->obj size], [heap->obj usedSize]);
1635+ return false ;
1636+ } else {
1637+ // GGML_LOG_ERROR("%s: allocated %zu\n", __func__, ggml_nbytes(src0));
1638+ }
1639+ } break ;
1640+ default :
1641+ {
1642+ } break ;
1643+ }
1644+
1645+ if (no_compute) {
1646+ return true ;
1647+ }
1648+
16241649 const int64_t ne00 = src0 ? src0->ne [0 ] : 0 ;
16251650 const int64_t ne01 = src0 ? src0->ne [1 ] : 0 ;
16261651 const int64_t ne02 = src0 ? src0->ne [2 ] : 0 ;
@@ -2278,23 +2303,14 @@ static bool ggml_metal_encode_node(
22782303 /* .nb3 =*/ nb03,
22792304 };
22802305
2281- id <MTLBuffer > id_src0h = ggml_metal_heap_alloc (heap, ggml_nbytes (src0), 64 *1024 );
2282- if (!id_src0h) {
2283- // GGML_LOG_ERROR("%s: failed to allocate buffer, idx = %4d, size = %8zu, need = %8zu, max available = %9zu, heap size = %9zu, heap used = %zu\n",
2284- // __func__, idx, ggml_nbytes(src0), heap->need, [heap->obj maxAvailableSizeWithAlignment:0], [heap->obj size], [heap->obj usedSize]);
2285- return true ;
2286- } else {
2287- // GGML_LOG_ERROR("%s: allocated %zu\n", __func__, ggml_nbytes(src0));
2288- }
2289-
22902306 if (src0->type == GGML_TYPE_F16) {
22912307 [encoder setComputePipelineState: ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
22922308 } else {
22932309 [encoder setComputePipelineState: ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
22942310 }
22952311 [encoder setBytes: &args_cpy length: sizeof (args_cpy) atIndex: 0 ];
22962312 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
2297- [encoder setBuffer: id_src0h offset: 0 atIndex: 2 ];
2313+ [encoder setBuffer: h_src0 offset: 0 atIndex: 2 ];
22982314
22992315 GGML_ASSERT (ne00 % ggml_blck_size (src0->type ) == 0 );
23002316 int nth_cpy = MIN (1024 , ne00 / ggml_blck_size (src0->type ));
@@ -2315,11 +2331,11 @@ static bool ggml_metal_encode_node(
23152331 };
23162332
23172333 [encoder setComputePipelineState: pipeline];
2318- [encoder setBuffer: id_src0h offset: 0 atIndex: 0 ];
2334+ [encoder setBuffer: h_src0 offset: 0 atIndex: 0 ];
23192335 if (id_src1) {
23202336 [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
23212337 } else {
2322- [encoder setBuffer: id_src0h offset: 0 atIndex: 1 ];
2338+ [encoder setBuffer: h_src0 offset: 0 atIndex: 1 ];
23232339 }
23242340 [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
23252341 [encoder setBytes: &args length: sizeof (args) atIndex: 3 ];
@@ -4732,6 +4748,12 @@ static enum ggml_status ggml_metal_graph_compute(
47324748 }
47334749 }
47344750
4751+ for (int i = 0 ; i <= n_cb; ++i) {
4752+ struct ggml_metal_heap * heap = ctx->cmd_bufs [i].heap ;
4753+
4754+ [heap->obj setPurgeableState: MTLPurgeableStateNonVolatile ];
4755+ }
4756+
47354757 // the main thread commits the first few commands immediately
47364758 // cmd_buf[n_cb]
47374759 {
@@ -4824,6 +4846,7 @@ static enum ggml_status ggml_metal_graph_compute(
48244846
48254847 if (heap->fail == 0 ) {
48264848 ggml_metal_heap_reset (ctx->cmd_bufs [i].heap );
4849+ [heap->obj setPurgeableState: MTLPurgeableStateEmpty ];
48274850
48284851 continue ;
48294852 }
@@ -5234,19 +5257,21 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
52345257
52355258 const bool should_capture = ctx->capture_next_compute ;
52365259
5260+ bool no_compute = false ;
5261+
52375262 for (int idx = node_start; idx < node_end; ++idx) {
52385263 if (should_capture) {
52395264 [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (ggml_graph_node (ctx->gf, idx)) encoding: NSUTF8StringEncoding]];
52405265 }
52415266
5242- const bool res = ggml_metal_encode_node (backend, idx, encoder, heap);
5267+ const bool res = ggml_metal_encode_node (backend, idx, encoder, heap, no_compute );
52435268
52445269 if (should_capture) {
52455270 [encoder popDebugGroup ];
52465271 }
52475272
52485273 if (!res) {
5249- break ;
5274+ no_compute = true ;
52505275 }
52515276 }
52525277
0 commit comments