@@ -60,6 +60,9 @@ static const size_t MB = 1024*1024;
 // TODO: dynamically determine these sizes
 // needs modifications in ggml
 
+typedef void (*offload_func_t)(struct ggml_tensor * tensor);
+void llama_nop(struct ggml_tensor * tensor) {} // do nothing by default
+
 static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0()
 {
     static std::map<e_model, size_t> k_sizes = {
@@ -1300,10 +1303,11 @@ static bool llama_eval_internal(
     const int i_gpu_start = n_layer - n_gpu_layers;
 
     for (int il = 0; il < n_layer; ++il) {
-        ggml_backend backend_offload = GGML_BACKEND_CPU;
+        offload_func_t offload_func = llama_nop;
+
 #ifdef GGML_USE_CUBLAS
         if (il >= i_gpu_start) {
-            backend_offload = GGML_BACKEND_GPU;
+            offload_func = ggml_cuda_assign_buffers; // sets the output backend to GPU
        }
 #endif // GGML_USE_CUBLAS
 
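The two hunks above are the core of the change: instead of flipping a stateful default backend on the context, each layer picks a per-tensor callback (the no-op `llama_nop`, or `ggml_cuda_assign_buffers` when the layer is offloaded), and every tensor created in the layer is passed through it. A minimal standalone sketch of the pattern, with a stub tensor struct and a hypothetical `mark_gpu` standing in for the real CUDA call:

    #include <stdio.h>

    struct ggml_tensor { int backend; }; // stub: 0 = CPU, 1 = GPU

    typedef void (*offload_func_t)(struct ggml_tensor * tensor);

    void llama_nop(struct ggml_tensor * tensor) { (void) tensor; } // do nothing by default

    // stand-in for ggml_cuda_assign_buffers: tags the tensor as GPU-resident
    static void mark_gpu(struct ggml_tensor * tensor) { tensor->backend = 1; }

    int main(void) {
        const int n_layer = 8, n_gpu_layers = 3;
        const int i_gpu_start = n_layer - n_gpu_layers;

        for (int il = 0; il < n_layer; ++il) {
            offload_func_t offload_func = llama_nop; // CPU by default
            if (il >= i_gpu_start) {
                offload_func = mark_gpu;             // last n_gpu_layers layers go to the GPU
            }
            struct ggml_tensor t = { 0 };
            offload_func(&t);                        // called once per tensor created in the layer
            printf("layer %d -> %s\n", il, t.backend ? "GPU" : "CPU");
        }
        return 0;
    }

Because the callback is a no-op unless CUDA offloading applies, the call sites below need no `#ifdef` branching of their own.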
@@ -1313,40 +1317,31 @@ static bool llama_eval_internal(
 
         // norm
         {
-            ggml_set_default_backend(ctx0, backend_offload);
             cur = ggml_rms_norm(ctx0, inpL);
+            offload_func(cur);
             ggml_set_name(cur, "rms_norm_0");
 
             // cur = cur*attention_norm(broadcasted)
             cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
+            offload_func(cur);
             ggml_set_name(cur, "attention_norm_0");
         }
 
         // self-attention
         {
             // compute Q and K and RoPE them
             struct ggml_tensor * tmpq = ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N);
+            offload_func(tmpq);
             ggml_set_name(tmpq, "tmpq");
             struct ggml_tensor * tmpk = ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N);
+            offload_func(tmpk);
             ggml_set_name(tmpk, "tmpk");
-            ggml_set_default_backend(ctx0, GGML_BACKEND_CPU);
 
-#ifdef GGML_USE_CUBLAS
-            struct ggml_tensor * Kcur;
-            struct ggml_tensor * Qcur;
-            if (backend_offload == GGML_BACKEND_GPU) {
-                Kcur = ggml_rope(ctx0, tmpk, n_past, n_rot, 0);
-                Qcur = ggml_rope(ctx0, tmpq, n_past, n_rot, 0);
-            } else {
-                Kcur = ggml_rope_inplace(ctx0, tmpk, n_past, n_rot, 0);
-                Qcur = ggml_rope_inplace(ctx0, tmpq, n_past, n_rot, 0);
-            }
-#else
             struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, tmpk, n_past, n_rot, 0);
+            ggml_set_name(Kcur, "Kcur");
+
             struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, tmpq, n_past, n_rot, 0);
-#endif // GGML_USE_CUBLAS
             ggml_set_name(Qcur, "Qcur");
-            ggml_set_name(Kcur, "Kcur");
 
             // store key and value to memory
             {
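For reference, a small worked example of the shapes produced by the reshape and RoPE above, assuming LLaMA-7B hyperparameters (n_embd = 4096, n_head = 32, so n_rot = n_embd/n_head = 128) and a hypothetical batch of N = 8 tokens:

    #include <stdio.h>

    int main(void) {
        const int n_embd = 4096, n_head = 32, N = 8;
        const int n_rot  = n_embd/n_head; // 128 rotated dims per head

        // ggml_mul_mat(wq, cur) yields an [n_embd, N] matrix; the reshape
        // splits the embedding axis into per-head vectors:
        printf("tmpq: [%d, %d] -> [%d, %d, %d]\n",
               n_embd, N, n_embd/n_head, n_head, N);
        printf("RoPE rotates n_rot = %d dims in each head\n", n_rot);
        return 0;
    }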
@@ -1430,62 +1425,70 @@ static bool llama_eval_internal(
                     ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
             ggml_set_name(cur, "KQV_merged_contiguous");
 
-            ggml_set_default_backend(ctx0, backend_offload);
             // projection (no bias)
             cur = ggml_mul_mat(ctx0,
                     model.layers[il].wo,
                     cur);
+            offload_func(cur);
             ggml_set_name(cur, "result_wo");
         }
 
         lctx.use_buf(ctx0, 1);
         // ggml_cuda_set_scratch(1);
 
         struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
+        offload_func(inpFF);
         ggml_set_name(inpFF, "inpFF");
 
         // feed-forward network
         {
             // norm
             {
                 cur = ggml_rms_norm(ctx0, inpFF);
+                offload_func(cur);
                 ggml_set_name(cur, "rms_norm_1");
 
                 // cur = cur*ffn_norm(broadcasted)
                 cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
+                offload_func(cur);
                 ggml_set_name(cur, "ffn_norm");
             }
 
             struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
                     model.layers[il].w3,
                     cur);
-            ggml_set_name(cur, "result_w3");
+            offload_func(tmp);
+            ggml_set_name(tmp, "result_w3");
 
             cur = ggml_mul_mat(ctx0,
                     model.layers[il].w1,
                     cur);
+            offload_func(cur);
             ggml_set_name(cur, "result_w1");
 
             // SILU activation
             cur = ggml_silu(ctx0, cur);
+            offload_func(cur);
             ggml_set_name(cur, "silu");
 
             cur = ggml_mul(ctx0, cur, tmp);
+            offload_func(cur);
             ggml_set_name(cur, "silu_x_result_w3");
 
             cur = ggml_mul_mat(ctx0,
                     model.layers[il].w2,
                     cur);
+            offload_func(cur);
             ggml_set_name(cur, "result_w2");
         }
 
         cur = ggml_add(ctx0, cur, inpFF);
+        offload_func(cur);
         ggml_set_name(cur, "inpFF_+_result_w2");
 
         // input for next layer
         inpL = cur;
 
-        ggml_set_default_backend(ctx0, GGML_BACKEND_CPU);
     }
 
     lctx.use_buf(ctx0, 0);
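The alternating `lctx.use_buf(ctx0, 1)` / `lctx.use_buf(ctx0, 0)` calls, together with the `MEM_REQ_SCRATCH0`/`MEM_REQ_SCRATCH1` tables from the first hunk, implement a two-scratch-buffer scheme: temporaries for one phase of the graph are bump-allocated out of one fixed region while results that must outlive that phase sit in the other. A rough sketch of the idea, with illustrative names (not the ggml API):

    #include <stddef.h>

    // Illustrative two-scratch-buffer allocator: allocation ping-pongs
    // between two fixed regions, so values produced in phase 0 can still
    // be read while phase 1 reuses the other region for its temporaries.
    typedef struct {
        unsigned char * data;
        size_t          size;
        size_t          offs; // bump-allocator offset
    } scratch_buf;

    static scratch_buf bufs[2];
    static int cur_buf = -1; // -1: no scratch, allocate from the main pool

    static void use_buf(int i) {
        cur_buf = i;
        if (i >= 0) {
            bufs[i].offs = 0; // entering a new phase reclaims the region wholesale
        }
    }

    static void * scratch_alloc(size_t n) {
        if (cur_buf < 0) {
            return NULL; // fall through to the normal allocator (elided)
        }
        scratch_buf * b = &bufs[cur_buf];
        if (b->offs + n > b->size) {
            return NULL; // would overflow: the MEM_REQ_SCRATCH tables size this
        }
        void * p = b->data + b->offs;
        b->offs += n;
        return p;
    }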
@@ -1494,28 +1497,32 @@ static bool llama_eval_internal(
     // used at the end to optionally extract the embeddings
     struct ggml_tensor * embeddings = NULL;
 
+    offload_func_t offload_func = llama_nop;
+
 #ifdef GGML_USE_CUBLAS
-    if (n_gpu_layers > n_layer) {
-        ggml_set_default_backend(ctx0, GGML_BACKEND_GPU);
-    }
+    if (n_gpu_layers > n_layer) {
+        offload_func = ggml_cuda_assign_buffers; // sets the output backend to GPU
+    }
 #endif // GGML_USE_CUBLAS
 
     // norm
     {
         cur = ggml_rms_norm(ctx0, inpL);
+        offload_func(cur);
         ggml_set_name(cur, "rms_norm_inpL");
 
         cur = ggml_rms_norm(ctx0, cur);
+        offload_func(cur);
         ggml_set_name(cur, "rms_norm_after");
 
         // cur = cur*norm(broadcasted)
         cur = ggml_mul(ctx0, cur, model.norm);
+        offload_func(cur);
         ggml_set_name(cur, "result_norm");
 
         embeddings = cur;
     }
 
-    ggml_set_default_backend(ctx0, GGML_BACKEND_CPU);
 
     // lm_head
     cur = ggml_mul_mat(ctx0, model.output, cur);
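Note the two offload thresholds this diff establishes: a repeating layer `il` is offloaded when `il >= n_layer - n_gpu_layers`, and the final norm is offloaded only when `n_gpu_layers > n_layer`, i.e. when more GPU layers are requested than the model has. A small sketch enumerating the effect for a hypothetical 32-layer model:

    #include <stdio.h>

    int main(void) {
        const int n_layer = 32;
        for (int n_gpu_layers = 31; n_gpu_layers <= 33; ++n_gpu_layers) {
            const int i_gpu_start = n_layer - n_gpu_layers;
            const int first = i_gpu_start < 0 ? 0 : i_gpu_start;
            printf("n_gpu_layers=%d: layers %d..%d on GPU, final norm on %s\n",
                   n_gpu_layers, first, n_layer - 1,
                   n_gpu_layers > n_layer ? "GPU" : "CPU");
        }
        return 0;
    }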