Open
Description
class EmbeddingModel:
    """Base class for embedding models.

    The backend is selected by ``model``:

    * ``model == 'OpenAI'``: every call is delegated to an ``OpenAIEmbeddings``
      client (its ``embed_query`` / ``embed_documents`` / ``aembed_query`` /
      ``aembed_documents`` methods).
    * any other ``model``: the name is sent as the model id to an
      OpenAI-compatible ``/embeddings`` endpoint. The sync methods go through
      the ``OpenAI`` client; the async methods post raw JSON with ``aiohttp``.
    """

    # Maximum number of texts sent in a single embeddings request.
    BATCH_SIZE = 10
    # Each document is clipped to this many characters before being sent
    # (queries are sent unclipped). An old comment said 8192, but the code
    # has always used 16384.
    MAX_DOC_CHARS = 16384

    def __init__(self, model: str, openai_api_key: str = None,
                 openai_api_base: str = None, dimensions: int = 1024):
        """Create the backend client.

        Args:
            model: ``'OpenAI'`` to use ``OpenAIEmbeddings``; otherwise the model
                id served by the OpenAI-compatible endpoint.
            openai_api_key: API key passed to the client / sent as a Bearer token.
            openai_api_base: Base URL of the endpoint (e.g. ``https://host/v1``).
            dimensions: Vector size requested from the compatible endpoint.
                Not used by the ``'OpenAI'`` backend.
        """
        self.model_name = model
        self.key = openai_api_key
        self.api_base = openai_api_base
        self.dimensions = dimensions
        if self.model_name == 'OpenAI':
            # NOTE(review): this passes the literal string 'OpenAI' as the model
            # id to OpenAIEmbeddings; confirm the target endpoint accepts it.
            self.client = OpenAIEmbeddings(
                model=self.model_name,
                openai_api_key=self.key,
                openai_api_base=self.api_base,
            )
        else:
            self.client = OpenAI(
                api_key=self.key,
                base_url=self.api_base,
            )

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _sorted_embeddings(data) -> List[List[float]]:
        """Return the ``embedding`` of each response ``data`` item, ordered by ``index``.

        The response items carry an explicit ``index``; sorting on it keeps the
        output aligned with the order of the inputs that were sent.
        """
        ordered = sorted(data, key=lambda item: int(item["index"]))
        return [item["embedding"] for item in ordered]

    def _clip_batches(self, texts: List[str]) -> List[List[str]]:
        """Split ``texts`` into batches of ``BATCH_SIZE``, clipping each to ``MAX_DOC_CHARS``."""
        return [
            [doc[:self.MAX_DOC_CHARS] for doc in texts[start:start + self.BATCH_SIZE]]
            for start in range(0, len(texts), self.BATCH_SIZE)
        ]

    def _embed_batch(self, inputs: List[str]) -> List[List[float]]:
        """Embed ``inputs`` with one synchronous request through the ``OpenAI`` client."""
        completion = self.client.embeddings.create(
            model=self.model_name,
            input=inputs,
            dimensions=self.dimensions,
            encoding_format="float",
        )
        # model_dump() yields the dict directly; no JSON round-trip needed.
        return self._sorted_embeddings(completion.model_dump()["data"])

    def _embeddings_url(self) -> str:
        """Build the ``/embeddings`` URL used by the async (aiohttp) path.

        Raises:
            ValueError: if no ``openai_api_base`` was configured.
        """
        if not self.api_base:
            raise ValueError(
                "openai_api_base is required for async embedding requests")
        return f"{self.api_base.rstrip('/')}/embeddings"

    async def _aembed_batch(self, session, inputs: List[str]) -> List[List[float]]:
        """Embed ``inputs`` with one POST to the compatible endpoint over ``session``."""
        headers = {
            "Authorization": f"Bearer {self.key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model_name,
            "input": inputs,
            "dimensions": self.dimensions,
            "encoding_format": "float",
        }
        async with session.post(
            self._embeddings_url(),
            headers=headers,
            json=payload,
        ) as response:
            # Report HTTP errors directly instead of failing later with KeyError('data').
            response.raise_for_status()
            output = await response.json()
        return self._sorted_embeddings(output["data"])

    # --------------------------------------------------------------- public API

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a single ``text`` (sent unclipped)."""
        if self.model_name == 'OpenAI':
            return self.client.embed_query(text)
        return self._embed_batch([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding per item of ``texts``, in input order.

        For the compatible endpoint, texts are sent in batches of ``BATCH_SIZE``
        and each is clipped to ``MAX_DOC_CHARS`` characters. An empty input
        returns ``[]`` without contacting either backend.
        """
        if not texts:
            return []
        if self.model_name == 'OpenAI':
            return self.client.embed_documents(texts)
        all_embeddings: List[List[float]] = []
        for batch in self._clip_batches(texts):
            all_embeddings.extend(self._embed_batch(batch))
        return all_embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously return the embedding of a single ``text``."""
        if self.model_name == 'OpenAI':
            # Assumes OpenAIEmbeddings provides an async aembed_query.
            return await self.client.aembed_query(text)
        async with aiohttp.ClientSession() as session:
            return (await self._aembed_batch(session, [text]))[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously return one embedding per item of ``texts``, in input order.

        Uses the same batching and clipping as ``embed_documents``. Batches are
        sent concurrently with ``asyncio.gather`` over a single shared session.
        """
        if not texts:
            return []
        if self.model_name == 'OpenAI':
            return await self.client.aembed_documents(texts)
        async with aiohttp.ClientSession() as session:
            per_batch = await asyncio.gather(
                *(self._aembed_batch(session, batch)
                  for batch in self._clip_batches(texts))
            )
        # gather preserves argument order, so flattening keeps input order.
        return [vector for batch in per_batch for vector in batch]
Metadata
Assignees
Labels
No labels