diff --git a/convert_lora_to_gguf.py b/convert_lora_to_gguf.py
index a67c0536a4128..befe8ab9cc838 100755
--- a/convert_lora_to_gguf.py
+++ b/convert_lora_to_gguf.py
@@ -12,7 +12,7 @@
 from math import prod
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
 
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoTokenizer
 
 import torch
@@ -26,6 +26,8 @@
 # reuse model definitions from convert_hf_to_gguf.py
 from convert_hf_to_gguf import LazyTorchTensor, ModelBase
 
+from gguf.constants import GGUFValueType
+
 logger = logging.getLogger("lora-to-gguf")
@@ -369,7 +371,31 @@ def set_type(self):
         self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
 
     def set_gguf_parameters(self):
+        logger.debug("GGUF KV: %s = %d", gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
         self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
+        alora_invocation_tokens = lparams.get("alora_invocation_tokens")
+        invocation_string = lparams.get("invocation_string")
+        if invocation_string and not alora_invocation_tokens:
+            logger.debug("Tokenizing invocation_string -> alora_invocation_tokens")
+            base_model_path_or_id = hparams.get("_name_or_path")
+            try:
+                tokenizer = AutoTokenizer.from_pretrained(base_model_path_or_id)
+            except ValueError:
+                logger.error("Unable to load tokenizer from %s", base_model_path_or_id)
+                raise
+            # NOTE: There's an off-by-one with the older aLoRAs where
+            # the invocation string includes the "<|start_of_turn|>"
+            # token, but the adapters themselves were trained to
+            # activate _after_ that first token, so we drop it here.
+            alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:]
+        if alora_invocation_tokens:
+            logger.debug("GGUF KV: %s = %s", gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS, alora_invocation_tokens)
+            self.gguf_writer.add_key_value(
+                gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS,
+                alora_invocation_tokens,
+                GGUFValueType.ARRAY,
+                GGUFValueType.UINT32,
+            )
 
     def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         # Never add extra tensors (e.g. rope_freqs) for LoRA adapters
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index c922bacf6eb97..b8ac394580b1f 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -231,10 +231,11 @@ class Tokenizer:
         MIDDLE_ID = "tokenizer.ggml.middle_token_id"
 
     class Adapter:
-        TYPE               = "adapter.type"
-        LORA_ALPHA         = "adapter.lora.alpha"
-        LORA_TASK_NAME     = "adapter.lora.task_name"
-        LORA_PROMPT_PREFIX = "adapter.lora.prompt_prefix"
+        TYPE                    = "adapter.type"
+        LORA_ALPHA              = "adapter.lora.alpha"
+        LORA_TASK_NAME          = "adapter.lora.task_name"
+        LORA_PROMPT_PREFIX      = "adapter.lora.prompt_prefix"
+        ALORA_INVOCATION_TOKENS = "adapter.alora.invocation_tokens"
 
     class IMatrix:
         CHUNK_COUNT = "imatrix.chunk_count"
diff --git a/include/llama.h b/include/llama.h
index 11f8a363a5733..453190e852b51 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -583,6 +583,10 @@ extern "C" {
     // Note: loaded adapters will be free when the associated model is deleted
     LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
 
+    // Get the invocation tokens if the current lora is an alora
+    LLAMA_API uint64_t            llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter);
+    LLAMA_API const llama_token * llama_adapter_get_alora_invocation_tokens  (const struct llama_adapter_lora * adapter);
+
    // The following functions operate on a llama_context, hence the naming: llama_verb_...
    // Add a loaded LoRA adapter to given context
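
For illustration only, here is a minimal sketch (not part of the patch) of how a client could query the two new accessors after loading an adapter. The model and adapter paths are placeholders and error handling is kept to a minimum; all other calls are existing llama.h API.

#include <cstdio>
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("base-model.gguf", mparams);        // placeholder path
    llama_adapter_lora * adapter = llama_adapter_lora_init(model, "alora-adapter.gguf"); // placeholder path
    if (!model || !adapter) {
        return 1;
    }

    // a plain LoRA reports 0 invocation tokens; an aLoRA reports its activation sequence
    const uint64_t n_tok = llama_adapter_get_alora_n_invocation_tokens(adapter);
    if (n_tok == 0) {
        printf("adapter is a plain LoRA (no invocation sequence)\n");
    } else {
        const llama_token * toks = llama_adapter_get_alora_invocation_tokens(adapter);
        printf("aLoRA invocation sequence (%llu tokens):", (unsigned long long) n_tok);
        for (uint64_t i = 0; i < n_tok; ++i) {
            printf(" %d", toks[i]);
        }
        printf("\n");
    }

    llama_adapter_lora_free(adapter);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}
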
diff --git a/src/llama-adapter.cpp b/src/llama-adapter.cpp
index 772ce1b448088..d8eef75a7ad70 100644
--- a/src/llama-adapter.cpp
+++ b/src/llama-adapter.cpp
@@ -6,6 +6,7 @@
 
 #include <map>
 #include <cassert>
+#include <algorithm>
 #include <stdexcept>
 
 // vec
@@ -215,6 +216,26 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
         }
 
         adapter.alpha = get_kv_f32(llm_kv(LLM_KV_ADAPTER_LORA_ALPHA));
+
+        // parse alora invocation sequence vector
+        const auto & key = llm_kv(LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS);
+        const int kid = gguf_find_key(ctx_gguf.get(), key.c_str());
+        if (kid >= 0) {
+            if (gguf_get_kv_type(ctx_gguf.get(), kid) != GGUF_TYPE_ARRAY) {
+                throw std::runtime_error("invalid gguf type for " + key);
+            }
+            const auto arr_type = gguf_get_arr_type(ctx_gguf.get(), kid);
+            if (arr_type != GGUF_TYPE_UINT32) {
+                throw std::runtime_error("invalid gguf element type for " + key);
+            }
+            const size_t seq_len = gguf_get_arr_n(ctx_gguf.get(), kid);
+            const void * data    = gguf_get_arr_data(ctx_gguf.get(), kid);
+            adapter.alora_invocation_tokens.resize(seq_len);
+            std::copy(
+                (const llama_token *) data,
+                (const llama_token *) data + seq_len,
+                adapter.alora_invocation_tokens.begin());
+        }
     }
 
     int n_tensors = gguf_get_n_tensors(ctx_gguf.get());
@@ -450,3 +471,15 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter,
 void llama_adapter_lora_free(llama_adapter_lora * adapter) {
     delete adapter;
 }
+
+uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter) {
+    if (!adapter) {
+        return 0;
+    }
+    return adapter->alora_invocation_tokens.size();
+}
+
+const llama_token * llama_adapter_get_alora_invocation_tokens(const llama_adapter_lora * adapter) {
+    GGML_ASSERT(adapter);
+    return adapter->alora_invocation_tokens.data();
+}
diff --git a/src/llama-adapter.h b/src/llama-adapter.h
index 9084e7cab08fd..4f65247c0feb1 100644
--- a/src/llama-adapter.h
+++ b/src/llama-adapter.h
@@ -70,6 +70,9 @@ struct llama_adapter_lora {
     // gguf metadata
     std::unordered_map<std::string, std::string> gguf_kv;
 
+    // activated lora (aLoRA)
+    std::vector<llama_token> alora_invocation_tokens;
+
     llama_adapter_lora() = default;
     ~llama_adapter_lora() = default;
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index d92bf2dd8515f..77b2fecf18fb8 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -237,10 +237,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" },
     { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" },
 
-    { LLM_KV_ADAPTER_TYPE,               "adapter.type"               },
-    { LLM_KV_ADAPTER_LORA_ALPHA,         "adapter.lora.alpha"         },
-    { LLM_KV_ADAPTER_LORA_TASK_NAME,     "adapter.lora.task_name"     },
-    { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" },
+    { LLM_KV_ADAPTER_TYPE,                    "adapter.type"                    },
+    { LLM_KV_ADAPTER_LORA_ALPHA,              "adapter.lora.alpha"              },
+    { LLM_KV_ADAPTER_LORA_TASK_NAME,          "adapter.lora.task_name"          },
+    { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX,      "adapter.lora.prompt_prefix"      },
+    { LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" },
 
     // deprecated
     { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 41f377923f50f..21ab47bd7af2a 100644
--- a/src/llama-arch.h
+++ b/src/llama-arch.h
@@ -235,6 +235,7 @@ enum llm_kv {
     LLM_KV_ADAPTER_LORA_ALPHA,
     LLM_KV_ADAPTER_LORA_TASK_NAME,
     LLM_KV_ADAPTER_LORA_PROMPT_PREFIX,
+    LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS,
 
     LLM_KV_POSNET_EMBEDDING_LENGTH,
     LLM_KV_POSNET_BLOCK_COUNT,
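
As a cross-check of the loader code above, this hedged sketch (not part of the patch) reads the new "adapter.alora.invocation_tokens" array straight from an adapter GGUF using ggml's public gguf API, mirroring the same find-key / type-check / array-read steps. The file path is a placeholder.

#include <cstdint>
#include <cstdio>
#include "gguf.h" // ggml's gguf header

int main() {
    gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr };
    gguf_context * ctx = gguf_init_from_file("alora-adapter.gguf", params); // placeholder path
    if (!ctx) {
        return 1;
    }

    const int64_t kid = gguf_find_key(ctx, "adapter.alora.invocation_tokens");
    if (kid >= 0 &&
        gguf_get_kv_type(ctx, kid) == GGUF_TYPE_ARRAY &&
        gguf_get_arr_type(ctx, kid) == GGUF_TYPE_UINT32) {
        // the key is stored as an array of UINT32 token ids
        const size_t n = gguf_get_arr_n(ctx, kid);
        const uint32_t * toks = (const uint32_t *) gguf_get_arr_data(ctx, kid);
        for (size_t i = 0; i < n; ++i) {
            printf("invocation token[%zu] = %u\n", i, toks[i]);
        }
    } else {
        printf("no alora invocation tokens found (plain LoRA)\n");
    }

    gguf_free(ctx);
    return 0;
}
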
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index de7c9931177cf..73fc43bada543 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -117,7 +117,7 @@ struct slot_params {
     int32_t n_keep    =  0; // number of tokens to keep from initial prompt
     int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
     int32_t n_predict = -1; // new tokens to predict
-    int32_t n_indent  =  0; // mininum line indentation for the generated text in number of whitespace characters
+    int32_t n_indent  =  0; // minimum line indentation for the generated text in number of whitespace characters
 
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@@ -1382,6 +1382,7 @@ struct server_slot {
     common_speculative * spec = nullptr;
 
     std::vector<common_adapter_lora_info> lora;
+    int32_t alora_invocation_start = -1;
 
     // the index relative to completion multi-task request
     size_t index = 0;
@@ -1476,6 +1477,9 @@ struct server_slot {
        // clear speculative decoding stats
        n_draft_total    = 0;
        n_draft_accepted = 0;
+
+        // clear alora start
+        alora_invocation_start = -1;
    }
 
    bool need_embd() const {
@@ -2367,11 +2371,65 @@ struct server_context {
        slot.prompt_tokens = std::move(task.prompt_tokens);
 
        if (!are_lora_equal(slot.params.lora, slot.lora)) {
-            // if lora is changed, we cannot reuse cached tokens
-            slot.cache_tokens.clear();
+            // if lora has changed, check to see if the cache should be cleared
+            if (lora_should_clear_cache(slot.lora, slot.params.lora)) {
+                SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), slot.params.lora.size());
+                slot.cache_tokens.clear();
+            } else {
+                SLT_INF(slot, "keeping cache for alora. %zu target loras\n", slot.params.lora.size());
+            }
            slot.lora = slot.params.lora;
        }
 
+        // if using alora, make sure it's only a single one requested and active
+        size_t alora_invocation_start = slot.prompt_tokens.size();
+        if (lora_all_alora(slot.lora)) {
+
+            const auto & enabled_ids = lora_get_enabled_ids(slot.lora);
+            // TODO: This will error out if a user requests two aloras, but only
+            // provides the activation string for one. We could instead search
+            // for all requested alora activation strings and then either keep
+            // only the last one, or reject if multiple are found.
+            if (enabled_ids.size() != 1) {
+                send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST);
+                return false;
+            }
+            const auto & lora = slot.lora[enabled_ids[0]].ptr;
+
+            // get the pointer and count for the invocation tokens
+            const uint64_t      n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora);
+            const llama_token * invocation_tokens   = llama_adapter_get_alora_invocation_tokens  (lora);
+
+            // scan backwards through the prompt tokens to find the last
+            // occurrence of the invocation sequence
+            int match_idx = static_cast<int>(n_invocation_tokens) - 1;
+            for (int i = slot.prompt_tokens.size() - 1; i >= 0; --i) {
+                // the token in this position matches the next token to find in
+                // the invocation sequence
+                if (slot.prompt_tokens[i] == invocation_tokens[match_idx]) {
+                    // if it's a full match, we've found the start
+                    if (match_idx == 0) {
+                        alora_invocation_start = i;
+                        break;
+                    }
+                    // otherwise, check the next token in the sequence
+                    --match_idx;
+                } else {
+                    // no match in this position, so start looking over again
+                    match_idx = static_cast<int>(n_invocation_tokens) - 1;
+                }
+            }
+
+            // if the activation string is not found, disable the alora
+            if (alora_invocation_start == slot.prompt_tokens.size()) {
+                SLT_DBG(slot, "alora %zu requested, but not found. deactivating\n", enabled_ids[0]);
+                slot.lora[enabled_ids[0]].scale = 0.0f;
+            } else {
+                SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start);
+                slot.alora_invocation_start = alora_invocation_start;
+            }
+        }
+
        if (!slot.prompt_tokens.validate(ctx)) {
            send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
            return false;
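
The backwards scan above can be read in isolation. The following standalone restatement (illustrative names, not part of the patch) uses the same matching logic and the same sentinel convention: the prompt size is returned when the invocation sequence is absent.

#include <cassert>
#include <cstdint>
#include <vector>

using llama_token = int32_t; // stand-in for the real typedef in llama.h

// returns the prompt index where the last occurrence of the invocation
// sequence starts, or prompt.size() if it is not found
static size_t find_alora_invocation_start(
        const std::vector<llama_token> & prompt,
        const std::vector<llama_token> & invocation) {
    size_t start = prompt.size();
    if (invocation.empty()) {
        return start;
    }
    // match the invocation sequence from its last token, walking the prompt backwards
    int match_idx = (int) invocation.size() - 1;
    for (int i = (int) prompt.size() - 1; i >= 0; --i) {
        if (prompt[i] == invocation[(size_t) match_idx]) {
            if (match_idx == 0) {
                start = (size_t) i; // full match: this is where the sequence begins
                break;
            }
            --match_idx;
        } else {
            match_idx = (int) invocation.size() - 1; // partial match broken: start over
        }
    }
    return start;
}

int main() {
    // the sequence [7, 8, 9] occurs twice; the last occurrence starts at index 5
    assert(find_alora_invocation_start({1, 7, 8, 9, 2, 7, 8, 9, 3}, {7, 8, 9}) == 5);
    // not present -> sentinel (prompt size)
    assert(find_alora_invocation_start({1, 2, 3}, {7, 8, 9}) == 3);
    return 0;
}
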
@@ -3247,6 +3305,8 @@
        int32_t n_ubatch = llama_n_ubatch(ctx);
 
        // next, batch any pending prompts without exceeding n_batch
+        float alora_scale = -1.0f;
+        size_t alora_disabled_id = 0;
        if (params_base.cont_batching || batch.n_tokens == 0) {
            for (auto & slot : slots) {
                // check if we can batch this slot with the previous one
@@ -3367,6 +3427,12 @@ struct server_context {
                    // reuse any previously computed tokens that are common with the new prompt
                    slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
 
+                    // if there is an alora invoked, don't cache after the invocation start
+                    if (slot.alora_invocation_start >= 0) {
+                        SLT_DBG(slot, "only caching to alora invocation start (n_past=%d, alora_invocation_start=%d)\n", slot.n_past, slot.alora_invocation_start);
+                        slot.n_past = std::min(slot.n_past, slot.alora_invocation_start - 1);
+                    }
+
                    // reuse chunks from the cached prompt by shifting their KV cache in the new position
                    if (params_base.n_cache_reuse > 0) {
                        size_t head_c = slot.n_past; // cache
@@ -3539,6 +3605,20 @@ struct server_context {
                        slot.n_prompt_tokens_processed += n_pos;
                    }
 
+                    // If using an alora, there may be uncached tokens that come
+                    // before the invocation sequence. When this happens, the
+                    // tokens before the invocation sequence need to be
+                    // processed without the adapter in a separate batch, then
+                    // the adapter needs to be enabled for the remaining tokens.
+                    if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.n_past) {
+                        SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+                        const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
+                        GGML_ASSERT(enabled_loras.size() == 1);
+                        alora_scale = slot.lora[enabled_loras[0]].scale;
+                        slot.lora[enabled_loras[0]].scale = 0.0f;
+                        alora_disabled_id = enabled_loras[0];
+                    }
+
                    // add prompt tokens for processing in the current batch
                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
                        // get next token to process
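
A simplified restatement of the bookkeeping above (struct and function names are illustrative, not part of the patch): the cached prefix is clamped to just before the invocation start, and any remaining pre-invocation tokens are decoded with the adapter scaled to zero before it is re-enabled for the rest of the prompt.

#include <algorithm>
#include <cassert>

struct alora_plan {
    int n_past;           // tokens reused from the cache, never past the invocation start
    int n_pre_invocation; // tokens that must be decoded with the adapter scale set to 0
};

static alora_plan plan_alora_batches(int n_past, int alora_invocation_start) {
    // mirror of the server logic: never reuse cache at or beyond the invocation start
    n_past = std::min(n_past, alora_invocation_start - 1);
    // tokens at indices [n_past, alora_invocation_start - 1) go into the adapter-off batch;
    // tokens from index alora_invocation_start - 1 onward are decoded with the adapter enabled
    return { n_past, std::max(0, (alora_invocation_start - 1) - n_past) };
}

int main() {
    // invocation starts at prompt index 10; 4 tokens already cached
    alora_plan p = plan_alora_batches(4, 10);
    assert(p.n_past == 4 && p.n_pre_invocation == 5);
    // fully cached prompt: the cache is clamped back to just before the invocation
    p = plan_alora_batches(32, 10);
    assert(p.n_past == 9 && p.n_pre_invocation == 0);
    return 0;
}
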
@@ -3547,6 +3627,14 @@
                            break; // end of text chunk
                        }
 
+                        // if this is an alora request with pre-invocation
+                        // tokens that are not cached, we need to stop filling
+                        // this batch at those pre-invocation tokens.
+                        if (alora_scale > 0 && slot.n_past == slot.alora_invocation_start - 1) {
+                            SLT_DBG(slot, "stop prompt batch filling at (n_past = %d, alora_invocation_start = %d)\n", slot.n_past, slot.alora_invocation_start);
+                            break;
+                        }
+
                        // embedding requires all tokens in the batch to be output
                        const bool need_embd = server_task_type_need_embd(slot.task_type);
@@ -3605,6 +3693,13 @@ struct server_context {
                // apply lora, only need to do it once per batch
                common_set_adapter_lora(ctx, slot_batched->lora);
 
+                // if the lora is temporarily disabled for an alora, re-enable it
+                // for next time
+                if (alora_scale > 0.0f) {
+                    SRV_DBG("re-enabling alora with scale %f\n", alora_scale);
+                    slot_batched->lora[alora_disabled_id].scale = alora_scale;
+                }
+
                llama_set_embeddings(ctx, slot_batched->need_embd());
            }
@@ -4990,13 +5085,26 @@ int main(int argc, char ** argv) {
        const auto & loras = ctx_server.params_base.lora_adapters;
        for (size_t i = 0; i < loras.size(); ++i) {
            auto & lora = loras[i];
-            result.push_back({
+            json entry = {
                {"id", i},
                {"path", lora.path},
                {"scale", lora.scale},
                {"task_name", lora.task_name},
                {"prompt_prefix", lora.prompt_prefix},
-            });
+            };
+            std::string alora_invocation_string = "";
+            const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr);
+            std::vector<llama_token> alora_invocation_tokens;
+            if (n_alora_tokens) {
+                const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr);
+                for (uint64_t i = 0; i < n_alora_tokens; ++i) {
+                    alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]);
+                    alora_invocation_tokens.push_back(alora_tokens[i]);
+                }
+                entry["alora_invocation_string"] = alora_invocation_string;
+                entry["alora_invocation_tokens"] = alora_invocation_tokens;
+            }
+            result.push_back(std::move(entry));
        }
        res_ok(res, result);
        res.status = 200; // HTTP OK
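
With this change, a GET /lora response might look like the following for one plain LoRA and one aLoRA; the new fields only appear when the adapter carries an invocation sequence. All ids, paths, scales, strings and token values below are hypothetical.

[
  {
    "id": 0,
    "path": "plain-lora.gguf",
    "scale": 1.0,
    "task_name": "",
    "prompt_prefix": ""
  },
  {
    "id": 1,
    "path": "alora-adapter.gguf",
    "scale": 1.0,
    "task_name": "certainty",
    "prompt_prefix": "",
    "alora_invocation_string": "<|invoke_alora|>",
    "alora_invocation_tokens": [49152, 20943, 49153]
  }
]
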
diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp
index 8636cf511d636..85fe25e0085fb 100644
--- a/tools/server/utils.hpp
+++ b/tools/server/utils.hpp
@@ -1002,6 +1002,47 @@ static bool are_lora_equal(
     return true;
 }
 
+// get the ids of all enabled loras
+static std::vector<size_t> lora_get_enabled_ids(const std::vector<common_adapter_lora_info> & loras) {
+    std::vector<size_t> enabled_ids;
+    for (size_t i = 0; i < loras.size(); ++i) {
+        if (loras[i].scale > 0) {
+            enabled_ids.push_back(i);
+        }
+    }
+    return enabled_ids;
+}
+
+// check whether the given lora set has only aloras activated (empty => false)
+static bool lora_all_alora(const std::vector<common_adapter_lora_info> & loras) {
+    bool found_alora = false;
+    for (const auto & lora : loras) {
+        if (lora.scale != 0) {
+            if (llama_adapter_get_alora_n_invocation_tokens(lora.ptr) == 0) {
+                return false;
+            }
+            found_alora = true;
+        }
+    }
+    return found_alora;
+}
+
+// if the two sets of loras are different, they require a cache clear unless the
+// change is only from aloras to aloras.
+static bool lora_should_clear_cache(
+        const std::vector<common_adapter_lora_info> & current,
+        const std::vector<common_adapter_lora_info> & next) {
+
+    // This should always be called after determining that the two sets are
+    // _not_ equal. This assert is therefore some slightly wasted work and
+    // should be safe to remove as long as this method is called correctly.
+    GGML_ASSERT(!are_lora_equal(current, next));
+
+    return (
+        !(lora_get_enabled_ids(current).empty() || lora_all_alora(current)) ||
+        !lora_all_alora(next));
+}
+
 // parse lora config from JSON request, returned a copy of lora_base with updated scale
 static std::vector<common_adapter_lora_info> parse_lora_request(
     const std::vector<common_adapter_lora_info> & lora_base,
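
The cache-clearing rule can be summarized as: keep the cache only when nothing other than aloras was active before the change and only aloras are active after it. Below is a distilled, self-contained restatement of that boolean expression with plain booleans standing in for the helper calls (illustrative only, not part of the patch).

#include <cassert>

// cur_has_enabled : lora_get_enabled_ids(current) is non-empty
// cur_all_alora   : lora_all_alora(current)
// next_all_alora  : lora_all_alora(next)
static bool should_clear_cache(bool cur_has_enabled, bool cur_all_alora, bool next_all_alora) {
    return !(!cur_has_enabled || cur_all_alora) || !next_all_alora;
}

int main() {
    assert(!should_clear_cache(false, false, true));  // nothing -> alora : keep cache
    assert(!should_clear_cache(true,  true,  true));  // alora   -> alora : keep cache
    assert( should_clear_cache(true,  false, true));  // plain   -> alora : clear
    assert( should_clear_cache(true,  true,  false)); // alora   -> plain : clear
    return 0;
}
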