Skip to content

Newest Embedding model #8

Open
Open
@chengzicong20040913

Description

@chengzicong20040913
class EmbeddingModel:
    """Thin wrapper that produces text embeddings through one of two clients.

    When ``model`` is the literal string ``'OpenAI'`` the work is delegated to
    an ``OpenAIEmbeddings`` instance; for any other model name an ``OpenAI``
    client pointed at ``openai_api_base`` is used and the embeddings endpoint
    is called directly.

    Args:
        model: Model name. The value ``'OpenAI'`` selects the
            ``OpenAIEmbeddings`` client; anything else is sent as the
            ``model`` field of the embeddings request.
        openai_api_key: API key handed to the underlying client.
        openai_api_base: Base URL handed to the underlying client (also used
            to build the ``/embeddings`` URL for the async HTTP path).
    """

    # Number of texts sent per embeddings request in ``embed_documents``.
    _BATCH_SIZE = 10
    # Each document is clipped to this many characters before being sent.
    _MAX_CHARS = 16384
    # Value of the ``dimensions`` field sent with every direct request.
    _DIMENSIONS = 1024

    def __init__(self, model: str,openai_api_key: str = None, openai_api_base: str = None):
        self.model_name = model
        self.key = openai_api_key
        self.api_base = openai_api_base
        self.client = None
        if self.model_name == 'OpenAI':
            self.client = OpenAIEmbeddings(
                model=self.model_name,
                openai_api_key=self.key,
                openai_api_base=self.api_base
            )
        else:
            self.client = OpenAI(
                api_key=self.key,
                base_url=self.api_base,
            )

    @staticmethod
    def _ordered_embeddings(items):
        """Return the ``embedding`` of each item, ordered by its ``index``."""
        ordered = sorted(items, key=lambda item: int(item["index"]))
        return [item["embedding"] for item in ordered]

    def _create_embeddings(self, inputs):
        """Embed ``inputs`` with one synchronous request, in input order."""
        completion = self.client.embeddings.create(
            model=self.model_name,
            input=inputs,
            dimensions=self._DIMENSIONS,
            encoding_format="float"
        )
        # Round-trip through JSON so the response is handled as plain dicts.
        output = json.loads(completion.model_dump_json())
        return self._ordered_embeddings(output["data"])

    def embed_query(self, text: str):
        """Return the embedding vector for a single text."""
        if self.model_name == 'OpenAI':
            return self.client.embed_query(text)
        return self._create_embeddings([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding per text, in the same order as ``texts``.

        On the direct-client path the texts are sent in batches of
        ``_BATCH_SIZE`` and each text is clipped to ``_MAX_CHARS`` characters
        first. An empty ``texts`` yields an empty list.
        """
        # Previously an empty input fell through and returned None.
        if not texts:
            return []
        if self.model_name == 'OpenAI':
            return self.client.embed_documents(texts)

        all_embeddings = []
        for start in range(0, len(texts), self._BATCH_SIZE):
            # Clip every entry of the batch to the maximum character count.
            chunk = [
                doc[:self._MAX_CHARS]
                for doc in texts[start:start + self._BATCH_SIZE]
            ]
            all_embeddings.extend(self._create_embeddings(chunk))
        return all_embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously return the embedding vector for a single text.

        On the direct path this POSTs to ``{api_base}/embeddings`` with a
        fresh ``aiohttp`` session per call.
        """
        if self.model_name == 'OpenAI':
            # Assumes OpenAIEmbeddings exposes an async counterpart.
            return await self.client.aembed_query(text)

        headers = {
            "Authorization": f"Bearer {self.key}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": self.model_name,
            "input": [text],
            "dimensions": self._DIMENSIONS,
            "encoding_format": "float"
        }
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.api_base}/embeddings",
                headers=headers,
                json=payload
            ) as response:
                output = await response.json()
                return self._ordered_embeddings(output["data"])[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously return one embedding per text, in input order.

        On the direct path every text becomes its own ``aembed_query`` call
        and all calls are awaited together.
        """
        if self.model_name == 'OpenAI':
            return await self.client.aembed_documents(texts)
        # Issue all single-text requests concurrently; gather keeps order.
        tasks = [self.aembed_query(text) for text in texts]
        return await asyncio.gather(*tasks)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions