@@ -339,6 +339,7 @@ struct hash_node {
339339};
340340
// Placement of a single tensor as recorded when the graph is reserved.
341341struct tensor_alloc {
342+ int buffer_id ; // index into galloc->bufts / galloc->buffers; -1 when the tensor is missing, a view, or already has data
342343 size_t offset ; // offset copied from the tensor's hash node; SIZE_MAX when the tensor is missing, a view, or already has data
343344 size_t size_max ; // 0 = pre-allocated, unused, or view
344345};
@@ -349,7 +350,6 @@ struct leaf_alloc {
349350};
350351
// Recorded placement for one graph node: the node's own tensor plus each of its sources.
351352struct node_alloc {
352- int buffer_id ;
353353 struct tensor_alloc dst ; // placement of the node itself
354354 struct tensor_alloc src [GGML_MAX_SRC ]; // placement of node->src[j]; slots for NULL sources hold the -1 / SIZE_MAX / 0 markers
355355};
@@ -511,17 +511,18 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
511511 }
512512}
513513
514- static void ggml_gallocr_free_node (ggml_gallocr_t galloc , struct ggml_tensor * node , int buffer_id ) {
514+ static void ggml_gallocr_free_node (ggml_gallocr_t galloc , struct ggml_tensor * node ) {
515515 // graph outputs are never freed
516516 if (node -> flags & GGML_TENSOR_FLAG_OUTPUT ) {
517517 AT_PRINTF ("not freeing output %s\n" , node -> name );
518518 return ;
519519 }
520520
521- struct ggml_dyn_tallocr * alloc = galloc -> buf_tallocs [buffer_id ];
522- ggml_backend_buffer_type_t buft = galloc -> bufts [buffer_id ];
523521 struct hash_node * hn = ggml_gallocr_hash_get (galloc , node );
524522 size_t offset = hn -> offset ;
523+ int buffer_id = hn -> buffer_id ;
524+ struct ggml_dyn_tallocr * alloc = galloc -> buf_tallocs [buffer_id ];
525+ ggml_backend_buffer_type_t buft = galloc -> bufts [buffer_id ];
525526 size_t size = ggml_backend_buft_get_alloc_size (buft , node );
526527 ggml_dyn_tallocr_free_tensor (alloc , offset , size , node );
527528 hn -> allocated = false;
@@ -626,11 +627,11 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
626627 AT_PRINTF ("view_src %s: %d children, %d views\n" ,
627628 view_src -> name , view_src_hn -> n_children , view_src_hn -> n_views );
628629 if (view_src_hn -> n_views == 0 && view_src_hn -> n_children == 0 && view_src_hn -> allocated ) {
629- ggml_gallocr_free_node (galloc , view_src , buffer_id );
630+ ggml_gallocr_free_node (galloc , view_src );
630631 }
631632 }
632633 else if (p_hn -> allocated ) {
633- ggml_gallocr_free_node (galloc , parent , buffer_id );
634+ ggml_gallocr_free_node (galloc , parent );
634635 }
635636 }
636637 AT_PRINTF ("\n" );
@@ -674,22 +675,26 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
674675 for (int i = 0 ; i < graph -> n_nodes ; i ++ ) {
675676 struct ggml_tensor * node = graph -> nodes [i ];
676677 struct node_alloc * node_alloc = & galloc -> node_allocs [i ];
677- node_alloc -> buffer_id = get_node_buffer_id (node_buffer_ids , i );
678+ // node_alloc->buffer_id = get_node_buffer_id(node_buffer_ids, i);
678679 if (node -> view_src || node -> data ) {
680+ node_alloc -> dst .buffer_id = -1 ;
679681 node_alloc -> dst .offset = SIZE_MAX ;
680682 node_alloc -> dst .size_max = 0 ;
681683 } else {
682684 struct hash_node * hn = ggml_gallocr_hash_get (galloc , node );
683- node_alloc -> dst .offset = hn -> offset ;
684- node_alloc -> dst .size_max = ggml_backend_buft_get_alloc_size (galloc -> bufts [hn -> buffer_id ], node );
685+ node_alloc -> dst .buffer_id = hn -> buffer_id ;
686+ node_alloc -> dst .offset = hn -> offset ;
687+ node_alloc -> dst .size_max = ggml_backend_buft_get_alloc_size (galloc -> bufts [hn -> buffer_id ], node );
685688 }
686689 for (int j = 0 ; j < GGML_MAX_SRC ; j ++ ) {
687690 struct ggml_tensor * src = node -> src [j ];
688691 if (!src || src -> view_src || src -> data ) {
692+ node_alloc -> src [j ].buffer_id = -1 ;
689693 node_alloc -> src [j ].offset = SIZE_MAX ;
690694 node_alloc -> src [j ].size_max = 0 ;
691695 } else {
692696 struct hash_node * hn = ggml_gallocr_hash_get (galloc , src );
697+ node_alloc -> src [j ].buffer_id = hn -> buffer_id ;
693698 node_alloc -> src [j ].offset = hn -> offset ;
694699 node_alloc -> src [j ].size_max = ggml_backend_buft_get_alloc_size (galloc -> bufts [hn -> buffer_id ], src );
695700 }
@@ -706,9 +711,11 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
706711 struct hash_node * hn = ggml_gallocr_hash_get (galloc , leaf );
707712 galloc -> leaf_allocs [i ].buffer_id = hn -> buffer_id ;
708713 if (leaf -> view_src || leaf -> data ) {
714+ galloc -> leaf_allocs [i ].leaf .buffer_id = -1 ;
709715 galloc -> leaf_allocs [i ].leaf .offset = SIZE_MAX ;
710716 galloc -> leaf_allocs [i ].leaf .size_max = 0 ;
711717 } else {
718+ galloc -> leaf_allocs [i ].leaf .buffer_id = hn -> buffer_id ;
712719 galloc -> leaf_allocs [i ].leaf .offset = hn -> offset ;
713720 galloc -> leaf_allocs [i ].leaf .size_max = ggml_backend_buft_get_alloc_size (galloc -> bufts [hn -> buffer_id ], leaf );
714721 }
@@ -740,7 +747,8 @@ bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) {
740747 return ggml_gallocr_reserve_n (galloc , graph , NULL , NULL );
741748}
742749
743- static void ggml_gallocr_init_tensor (ggml_gallocr_t galloc , struct ggml_tensor * tensor , int buffer_id , struct tensor_alloc * tensor_alloc ) {
750+ static void ggml_gallocr_init_tensor (ggml_gallocr_t galloc , struct ggml_tensor * tensor , struct tensor_alloc * tensor_alloc ) {
751+ int buffer_id = tensor_alloc -> buffer_id ;
744752 assert (tensor -> data || tensor -> view_src || ggml_backend_buffer_get_alloc_size (galloc -> buffers [buffer_id ], tensor ) <= tensor_alloc -> size_max );
745753
746754 if (tensor -> view_src != NULL ) {
@@ -768,8 +776,8 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor *
768776 }
769777}
770778
771- static bool ggml_gallocr_node_needs_realloc (ggml_gallocr_t galloc , struct ggml_tensor * node , struct node_alloc * nalloc , struct tensor_alloc * talloc ) {
772- ggml_backend_buffer_type_t buft = galloc -> bufts [nalloc -> buffer_id ];
779+ static bool ggml_gallocr_node_needs_realloc (ggml_gallocr_t galloc , struct ggml_tensor * node , struct tensor_alloc * talloc ) {
780+ ggml_backend_buffer_type_t buft = talloc -> buffer_id != -1 ? galloc -> bufts [talloc -> buffer_id ] : NULL ;
773781 size_t node_size = (node -> data || node -> view_src ) ? 0 : ggml_backend_buft_get_alloc_size (buft , node );
774782 return talloc -> size_max >= node_size ;
775783}
@@ -793,7 +801,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph
793801 struct ggml_tensor * node = graph -> nodes [i ];
794802 struct node_alloc * node_alloc = & galloc -> node_allocs [i ];
795803
796- if (!ggml_gallocr_node_needs_realloc (galloc , node , node_alloc , & node_alloc -> dst )) {
804+ if (!ggml_gallocr_node_needs_realloc (galloc , node , & node_alloc -> dst )) {
797805#ifndef NDEBUG
798806 fprintf (stderr , "%s: node %s is not valid\n" , __func__ , node -> name );
799807#endif
@@ -805,7 +813,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph
805813 if (src == NULL ) {
806814 continue ;
807815 }
808- if (!ggml_gallocr_node_needs_realloc (galloc , src , node_alloc , & node_alloc -> src [j ])) {
816+ if (!ggml_gallocr_node_needs_realloc (galloc , src , & node_alloc -> src [j ])) {
809817#ifndef NDEBUG
810818 fprintf (stderr , "%s: src %d (%s) of node %s is not valid\n" , __func__ , j , src -> name , node -> name );
811819#endif
@@ -846,7 +854,7 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph)
846854 for (int i = 0 ; i < graph -> n_leafs ; i ++ ) {
847855 struct ggml_tensor * leaf = graph -> leafs [i ];
848856 struct leaf_alloc * leaf_alloc = & galloc -> leaf_allocs [i ];
849- ggml_gallocr_init_tensor (galloc , leaf , leaf_alloc -> buffer_id , & leaf_alloc -> leaf );
857+ ggml_gallocr_init_tensor (galloc , leaf , & leaf_alloc -> leaf );
850858 }
851859 // nodes
852860 for (int i = 0 ; i < graph -> n_nodes ; i ++ ) {
@@ -857,9 +865,9 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph)
857865 if (src == NULL ) {
858866 continue ;
859867 }
860- ggml_gallocr_init_tensor (galloc , src , node_alloc -> buffer_id , & node_alloc -> src [j ]);
868+ ggml_gallocr_init_tensor (galloc , src , & node_alloc -> src [j ]);
861869 }
862- ggml_gallocr_init_tensor (galloc , node , node_alloc -> buffer_id , & node_alloc -> dst );
870+ ggml_gallocr_init_tensor (galloc , node , & node_alloc -> dst );
863871 }
864872
865873 return true;
0 commit comments