@@ -47,8 +47,8 @@ void llama_ngram_cache_update(llama_ngram_cache & ngram_cache, int ngram_min, in
47
47
}
48
48
49
49
// Helper function to get a token from the combined, speculative sequence of inp and draft.
50
- static llama_token get_token (const std::vector< llama_token> & inp , const std::vector<llama_token> & draft, const size_t i) {
51
- return i < inp. size () ? inp [i] : draft[1 + i - inp. size () ];
50
+ static llama_token get_token (const llama_token * inp_data , const int inp_size, const std::vector<llama_token> & draft, const int i) {
51
+ return i < inp_size ? inp_data [i] : draft[1 + i - inp_size ];
52
52
}
53
53
54
54
// If sample size or percentage are below these thresholds the draft is aborted early:
@@ -139,11 +139,10 @@ static llama_token try_draft(
139
139
}
140
140
141
141
void llama_ngram_cache_draft (
142
- std::vector< llama_token> & inp , std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
142
+ llama_token * inp_data, int inp_size , std::vector<llama_token> & draft, int n_draft, int ngram_min, int ngram_max,
143
143
llama_ngram_cache & nc_context, llama_ngram_cache & nc_dynamic, llama_ngram_cache & nc_static
144
144
) {
145
145
GGML_ASSERT (draft.size () == 1 );
146
- const int inp_size = inp.size ();
147
146
148
147
if (inp_size < LLAMA_NGRAM_STATIC) {
149
148
return ;
@@ -155,7 +154,7 @@ void llama_ngram_cache_draft(
155
154
const int ngram_start_static = inp_size-LLAMA_NGRAM_STATIC + draft.size ()-1 ;
156
155
llama_ngram ngram_static;
157
156
for (int j = ngram_start_static; j < ngram_start_static + LLAMA_NGRAM_STATIC; ++j) {
158
- ngram_static.tokens [j-ngram_start_static] = get_token (inp , draft, j);
157
+ ngram_static.tokens [j-ngram_start_static] = get_token (inp_data, inp_size , draft, j);
159
158
}
160
159
llama_ngram_cache::iterator part_static_it = nc_static.find (ngram_static);
161
160
llama_ngram_cache_part part_static;
@@ -169,7 +168,7 @@ void llama_ngram_cache_draft(
169
168
const int ngram_start_cd = inp_size-ngram_size_cd + draft.size ()-1 ;
170
169
llama_ngram ngram_cd;
171
170
for (int j = ngram_start_cd; j < ngram_start_cd + ngram_size_cd; ++j) {
172
- ngram_cd.tokens [j-ngram_start_cd] = get_token (inp , draft, j);
171
+ ngram_cd.tokens [j-ngram_start_cd] = get_token (inp_data, inp_size , draft, j);
173
172
}
174
173
ngrams_cd.push_back (ngram_cd);
175
174
}
0 commit comments