From 4073532a7b9d6ee1dc30221562b3e9e22d9ec689 Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Fri, 8 Mar 2024 12:36:39 -0600 Subject: [PATCH 1/2] switch to llama_get_embeddings_seq --- llama_cpp/llama.py | 2 +- llama_cpp/llama_cpp.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 7187b4a17..aabbb7e71 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -814,7 +814,7 @@ def decode_batch(n_seq: int): # store embeddings for i in range(n_seq): - embedding: List[float] = llama_cpp.llama_get_embeddings_ith( + embedding: List[float] = llama_cpp.llama_get_embeddings_seq( self._ctx.ctx, i )[:n_embd] if normalize: diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 92b96766c..c23d0e203 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -1803,6 +1803,22 @@ def llama_get_embeddings_ith( ... +# // Get the embeddings for sequence seq_id +# // shape: [n_embd] (1-dimensional) +# LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); +@ctypes_function( + "llama_get_embeddings_seq", + [llama_context_p_ctypes, ctypes.c_int32], + ctypes.POINTER(ctypes.c_float), +) +def llama_get_embeddings_seq( + ctx: llama_context_p, i: Union[ctypes.c_int32, int], / +) -> CtypesArray[ctypes.c_float]: + """Get the embeddings for sequence seq_id + shape: [n_embd] (1-dimensional)""" + ... + + # // Get the embeddings for a sequence id # // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE # // shape: [n_embd] (1-dimensional) From 3a23829fc2e869a805fa71b94bea488cf79c5a9f Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Fri, 8 Mar 2024 20:46:40 -0500 Subject: [PATCH 2/2] Remove duplicate definition of llama_get_embeddings_seq Co-authored-by: Andrei --- llama_cpp/llama_cpp.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index c23d0e203..92b96766c 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -1803,22 +1803,6 @@ def llama_get_embeddings_ith( ... -# // Get the embeddings for sequence seq_id -# // shape: [n_embd] (1-dimensional) -# LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); -@ctypes_function( - "llama_get_embeddings_seq", - [llama_context_p_ctypes, ctypes.c_int32], - ctypes.POINTER(ctypes.c_float), -) -def llama_get_embeddings_seq( - ctx: llama_context_p, i: Union[ctypes.c_int32, int], / -) -> CtypesArray[ctypes.c_float]: - """Get the embeddings for sequence seq_id - shape: [n_embd] (1-dimensional)""" - ... - - # // Get the embeddings for a sequence id # // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE # // shape: [n_embd] (1-dimensional)