@@ -21,6 +21,9 @@ int main(int argc, char ** argv) {
        return 1;
    }

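+    // llama_decode() accepts at most n_batch tokens per call, so allow batches up to the full context size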
+    params.n_batch = params.n_ctx;
+
    common_init();

    int is_pp_shared = params.is_pp_shared;
@@ -61,48 +63,21 @@ int main(int argc, char ** argv) {

    llama_batch batch = llama_batch_init(n_kv_max, 0, 1);

-    // decode in batches of ctx_params.n_batch tokens
-    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
-
-            llama_batch batch_view = {
-                n_tokens,
-                batch.token    + i,
-                nullptr,
-                batch.pos      + i,
-                batch.n_seq_id + i,
-                batch.seq_id   + i,
-                batch.logits   + i,
-            };
-
-            const int ret = llama_decode(ctx, batch_view);
-            if (ret != 0) {
-                LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
-                return false;
-            }
-
-            llama_synchronize(ctx);
-        }
-
-        return true;
-    };
-
    // warm up
    {
        for (int i = 0; i < 16; ++i) {
            common_batch_add(batch, 0, i, { 0 }, false);
        }

-        if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
-            LOG_ERR("%s: llama_decode() failed\n", __func__);
+        if (const auto ret = llama_decode(ctx, batch)) {
+            LOG_ERR("%s: llama_decode() failed, ret = %d\n", __func__, ret);
            return 1;
        }
    }

    if (!params.batched_bench_output_jsonl) {
        LOG("\n");
-        LOG("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_batch, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+        LOG("%s: n_kv_max = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, params.n_ubatch, params.flash_attn, params.is_pp_shared, params.n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
        LOG("\n");
        LOG("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
        LOG("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
@@ -134,9 +109,12 @@ int main(int argc, char ** argv) {

                llama_kv_self_clear(ctx);

-                if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
-                    LOG_ERR("%s: llama_decode() failed\n", __func__);
-                    return 1;
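+                // the batch can be empty here (e.g. when pp == 0), and llama_decode() rejects empty batches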
+                if (batch.n_tokens > 0) {
+                    if (const auto ret = llama_decode(ctx, batch)) {
+                        LOG_ERR("%s: llama_decode() failed, ret = %d\n", __func__, ret);
+                        return 1;
+                    }
                }

                if (is_pp_shared) {
@@ -156,8 +133,9 @@ int main(int argc, char ** argv) {
                        common_batch_add(batch, 0, pp + i, { j }, true);
                    }

-                    if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
-                        LOG_ERR("%s: llama_decode() failed\n", __func__);
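+                    // each TG step decodes pl tokens, one per parallel sequence, in a single call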
+                    if (const auto ret = llama_decode(ctx, batch)) {
+                        LOG_ERR("%s: llama_decode() failed, ret = %d\n", __func__, ret);
                        return 1;
                    }
                }