@@ -480,6 +480,10 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr<Json::Value> json_body) {
   if (!params.use_mmap) {
     LOG_DEBUG << "Disabled mmap";
   }
+  params.n_predict = json_body->get("n_predict", -1).asInt();
+  params.prompt = json_body->get("prompt", "").asString();
+  params.conversation = json_body->get("conversation", false).asBool();
+  params.special = json_body->get("special", false).asBool();
 
   server_map_[model_id].caching_enabled =
       json_body->get("caching_enabled", true).asBool();
@@ -599,6 +603,24 @@ void LlamaEngine::HandleInferenceImpl(
599603 data[" temperature" ] = completion.temperature ;
600604 data[" frequency_penalty" ] = completion.frequency_penalty ;
601605 data[" presence_penalty" ] = completion.presence_penalty ;
606+ data[" seed" ] = completion.seed ;
607+ data[" dynatemp_range" ] = completion.dynatemp_range ;
608+ data[" dynatemp_exponent" ] = completion.dynatemp_exponent ;
609+ data[" top_k" ] = completion.top_k ;
610+ data[" min_p" ] = completion.min_p ;
611+ data[" tfs_z" ] = completion.tfs_z ;
612+ data[" typical_p" ] = completion.typ_p ;
613+ data[" repeat_last_n" ] = completion.repeat_last_n ;
614+ data[" repeat_penalty" ] = completion.penalty_repeat ;
615+ data[" mirostat" ] = completion.mirostat ;
616+ data[" mirostat_tau" ] = completion.mirostat_tau ;
617+ data[" mirostat_eta" ] = completion.mirostat_eta ;
618+ data[" penalize_nl" ] = completion.penalize_nl ;
619+ data[" ignore_eos" ] = completion.ignore_eos ;
620+ data[" n_probs" ] = completion.n_probs ;
621+ data[" min_keep" ] = completion.min_keep ;
622+ data[" grammar" ] = completion.grammar ;
623+ int n_probs = completion.n_probs ;
602624 const Json::Value& messages = completion.messages ;
603625
604626 if (!si.grammar_file_content .empty ()) {
@@ -717,12 +739,17 @@ void LlamaEngine::HandleInferenceImpl(
     auto state = CreateInferenceState(si.ctx);
 
     // Queued task
-    si.q->runTaskInQueue([cb = std::move(callback), state, data, request_id]() {
+    si.q->runTaskInQueue([cb = std::move(callback), state, data, request_id, n_probs]() {
       state->task_id = state->llama.RequestCompletion(data, false, false, -1);
       while (state->llama.model_loaded_external) {
         TaskResult result = state->llama.NextResult(state->task_id);
         if (!result.error) {
-          std::string to_send = result.result_json["content"];
+          std::string to_send;
+          if (n_probs > 0) {
+            to_send = result.result_json["completion_probabilities"].dump();
+          } else {
+            to_send = result.result_json["content"];
+          }
           // trim the leading space if it is the first token
           if (std::exchange(state->is_first_token, false)) {
             llama_utils::ltrim(to_send);
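In the streaming path above, the queued lambda now captures n_probs by value, so the choice survives until the task actually runs, and each chunk carries either the serialized completion_probabilities array or the plain content string. A reduced sketch of that selection logic, assuming result_json is an nlohmann::json (which the .dump() call suggests); the result shape built in main() is illustrative only:

#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

// Mirrors the branch inside the queued task: pick either the token text or the
// serialized per-token probabilities as the payload for this streamed chunk.
std::string SelectChunkPayload(const nlohmann::json& result_json, int n_probs) {
  std::string to_send;
  if (n_probs > 0) {
    // Caller asked for probabilities: forward the whole array as JSON text.
    to_send = result_json["completion_probabilities"].dump();
  } else {
    // Default: stream the generated text fragment.
    to_send = result_json["content"].get<std::string>();
  }
  return to_send;
}

int main() {
  nlohmann::json r;  // hypothetical per-token result; field layout is assumed
  r["content"] = "Hello";
  r["completion_probabilities"][0]["content"] = "Hello";
  r["completion_probabilities"][0]["probs"][0]["tok_str"] = "Hello";
  r["completion_probabilities"][0]["probs"][0]["prob"] = 0.91;

  std::cout << SelectChunkPayload(r, 0) << '\n';  // Hello
  std::cout << SelectChunkPayload(r, 5) << '\n';  // JSON array of probabilities
  return 0;
}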