
Commit 2f1c4ef

Merge pull request Azure-Samples#70 from Azure-Samples/embedcolumn
Add embedding with ollama
2 parents d1e990c + c124a84 commit 2f1c4ef

15 files changed: +234140 −954 lines

.env.sample

Lines changed: 5 additions & 1 deletion
@@ -18,13 +18,17 @@ AZURE_OPENAI_CHAT_MODEL=gpt-35-turbo
 AZURE_OPENAI_EMBED_DEPLOYMENT=text-embedding-ada-002
 AZURE_OPENAI_EMBED_MODEL=text-embedding-ada-002
 AZURE_OPENAI_EMBED_MODEL_DIMENSIONS=1536
+AZURE_OPENAI_EMBEDDING_COLUMN=embedding_ada002
 # Only needed when using key-based Azure authentication:
 AZURE_OPENAI_KEY=
 # Needed for OpenAI.com:
 OPENAICOM_KEY=YOUR-OPENAI-API-KEY
 OPENAICOM_CHAT_MODEL=gpt-3.5-turbo
 OPENAICOM_EMBED_MODEL=text-embedding-ada-002
 OPENAICOM_EMBED_MODEL_DIMENSIONS=1536
+OPENAICOM_EMBEDDING_COLUMN=embedding_ada002
 # Needed for Ollama:
 OLLAMA_ENDPOINT=http://host.docker.internal:11434/v1
-OLLAMA_CHAT_MODEL=phi3:3.8b
+OLLAMA_CHAT_MODEL=llama3.1
+OLLAMA_EMBED_MODEL=nomic-embed-text
+OLLAMA_EMBEDDING_COLUMN=embedding_nomic

.vscode/settings.json

Lines changed: 6 additions & 1 deletion
@@ -22,5 +22,10 @@
                 "ssl": true
             }
         }
-    ]
+    ],
+    "python.testing.pytestArgs": [
+        "tests"
+    ],
+    "python.testing.unittestEnabled": false,
+    "python.testing.pytestEnabled": true
 }

src/backend/fastapi_app/dependencies.py

Lines changed: 12 additions & 2 deletions
@@ -28,9 +28,10 @@ class FastAPIAppContext(BaseModel):
 
     openai_chat_model: str
     openai_embed_model: str
-    openai_embed_dimensions: int
+    openai_embed_dimensions: int | None
    openai_chat_deployment: str | None
    openai_embed_deployment: str | None
+    embedding_column: str
 
 
 async def common_parameters():
@@ -43,16 +44,24 @@ async def common_parameters():
         openai_embed_deployment = os.getenv("AZURE_OPENAI_EMBED_DEPLOYMENT", "text-embedding-ada-002")
         openai_embed_model = os.getenv("AZURE_OPENAI_EMBED_MODEL", "text-embedding-ada-002")
         openai_embed_dimensions = int(os.getenv("AZURE_OPENAI_EMBED_DIMENSIONS", 1536))
+        embedding_column = os.getenv("AZURE_OPENAI_EMBEDDING_COLUMN", "embedding_ada002")
+    elif OPENAI_EMBED_HOST == "ollama":
+        openai_embed_deployment = None
+        openai_embed_model = os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text")
+        openai_embed_dimensions = None
+        embedding_column = os.getenv("OLLAMA_EMBEDDING_COLUMN", "embedding_nomic")
     else:
-        openai_embed_deployment = "text-embedding-ada-002"
+        openai_embed_deployment = None
         openai_embed_model = os.getenv("OPENAICOM_EMBED_MODEL", "text-embedding-ada-002")
         openai_embed_dimensions = int(os.getenv("OPENAICOM_EMBED_DIMENSIONS", 1536))
+        embedding_column = os.getenv("OPENAICOM_EMBEDDING_COLUMN", "embedding_ada002")
     if OPENAI_CHAT_HOST == "azure":
         openai_chat_deployment = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT", "gpt-35-turbo")
         openai_chat_model = os.getenv("AZURE_OPENAI_CHAT_MODEL", "gpt-35-turbo")
     elif OPENAI_CHAT_HOST == "ollama":
         openai_chat_deployment = None
         openai_chat_model = os.getenv("OLLAMA_CHAT_MODEL", "phi3:3.8b")
+        openai_embed_model = os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text")
     else:
         openai_chat_deployment = None
         openai_chat_model = os.getenv("OPENAICOM_CHAT_MODEL", "gpt-3.5-turbo")
@@ -62,6 +71,7 @@ async def common_parameters():
         openai_embed_dimensions=openai_embed_dimensions,
         openai_chat_deployment=openai_chat_deployment,
         openai_embed_deployment=openai_embed_deployment,
+        embedding_column=embedding_column,
     )
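Not part of the diff, just an orientation sketch: with the new branch, OPENAI_EMBED_HOST=ollama should produce a context pointing at the nomic embedding column and model, with no dimensions value. This assumes common_parameters is importable as above and that OPENAI_EMBED_HOST is read from the environment at import time.

import asyncio
import os

# Hypothetical smoke test, not included in this commit.
os.environ["OPENAI_EMBED_HOST"] = "ollama"

from fastapi_app.dependencies import common_parameters

async def check() -> None:
    context = await common_parameters()
    assert context.openai_embed_model == "nomic-embed-text"    # OLLAMA_EMBED_MODEL default
    assert context.openai_embed_dimensions is None             # Ollama path sends no dimensions
    assert context.embedding_column == "embedding_nomic"       # OLLAMA_EMBEDDING_COLUMN default

asyncio.run(check())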

src/backend/fastapi_app/embeddings.py

Lines changed: 7 additions & 2 deletions
@@ -10,7 +10,7 @@ async def compute_text_embedding(
     openai_client: AsyncOpenAI | AsyncAzureOpenAI,
     embed_model: str,
     embed_deployment: str | None = None,
-    embedding_dimensions: int = 1536,
+    embedding_dimensions: int | None = None,
 ) -> list[float]:
     SUPPORTED_DIMENSIONS_MODEL = {
         "text-embedding-ada-002": False,
@@ -21,7 +21,12 @@ async def compute_text_embedding(
     class ExtraArgs(TypedDict, total=False):
         dimensions: int
 
-    dimensions_args: ExtraArgs = {"dimensions": embedding_dimensions} if SUPPORTED_DIMENSIONS_MODEL[embed_model] else {}
+    dimensions_args: ExtraArgs = {}
+    if SUPPORTED_DIMENSIONS_MODEL.get(embed_model):
+        if embedding_dimensions is None:
+            raise ValueError(f"Model {embed_model} requires embedding dimensions")
+        else:
+            dimensions_args = {"dimensions": embedding_dimensions}
 
     embedding = await openai_client.embeddings.create(
         # Azure OpenAI takes the deployment name as the model name
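A hedged usage sketch of the new None-by-default dimensions handling. The query-text parameter is not visible in this hunk and is assumed to be q; client stands in for any AsyncOpenAI-compatible client.

# Illustration only, not part of the commit.
# nomic-embed-text is absent from SUPPORTED_DIMENSIONS_MODEL, so no
# dimensions argument is sent and embedding_dimensions=None is fine:
vector = await compute_text_embedding(
    q="vintage camping lantern",
    openai_client=client,
    embed_model="nomic-embed-text",
    embedding_dimensions=None,
)

# A model flagged True in SUPPORTED_DIMENSIONS_MODEL (e.g. text-embedding-3-small)
# now raises ValueError if embedding_dimensions is left as None.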

src/backend/fastapi_app/openai_clients.py

Lines changed: 7 additions & 1 deletion
@@ -76,7 +76,13 @@ async def create_openai_embed_client(
             azure_deployment=azure_deployment,
             azure_ad_token_provider=token_provider,
         )
-
+    elif OPENAI_EMBED_HOST == "ollama":
+        logger.info("Authenticating to OpenAI using Ollama...")
+        openai_embed_client = openai.AsyncOpenAI(
+            base_url=os.getenv("OLLAMA_ENDPOINT"),
+            api_key="nokeyneeded",
+        )
     else:
+        logger.info("Authenticating to OpenAI using OpenAI.com API key...")
         openai_embed_client = openai.AsyncOpenAI(api_key=os.getenv("OPENAICOM_KEY"))
     return openai_embed_client
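For context (not in the diff): Ollama exposes an OpenAI-compatible /v1 API, which is why the stock AsyncOpenAI client works here with a placeholder key. A minimal standalone sketch, assuming a local Ollama server with nomic-embed-text pulled:

import asyncio
import openai

async def main() -> None:
    # Same client setup as above, but pointed at a local Ollama endpoint.
    client = openai.AsyncOpenAI(base_url="http://localhost:11434/v1", api_key="nokeyneeded")
    response = await client.embeddings.create(model="nomic-embed-text", input="hello world")
    print(len(response.data[0].embedding))  # nomic-embed-text produces 768-dimensional vectors

asyncio.run(main())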

src/backend/fastapi_app/postgres_engine.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ def get_password_from_azure_credential():
 
     engine = create_async_engine(
         DATABASE_URI,
-        echo=False,
+        echo=True,
     )
 
     @event.listens_for(engine.sync_engine, "do_connect")

src/backend/fastapi_app/postgres_models.py

Lines changed: 18 additions & 7 deletions
@@ -20,14 +20,17 @@ class Item(Base):
     name: Mapped[str] = mapped_column()
     description: Mapped[str] = mapped_column()
     price: Mapped[float] = mapped_column()
-    embedding: Mapped[Vector] = mapped_column(Vector(1536))  # ada-002
+    embedding_ada002: Mapped[Vector] = mapped_column(Vector(1536))  # ada-002
+    embedding_nomic: Mapped[Vector] = mapped_column(Vector(768))  # nomic-embed-text
 
     def to_dict(self, include_embedding: bool = False):
         model_dict = asdict(self)
         if include_embedding:
-            model_dict["embedding"] = model_dict["embedding"].tolist()
+            model_dict["embedding_ada002"] = model_dict.get("embedding_ada002", [])
+            model_dict["embedding_nomic"] = model_dict.get("embedding_nomic", [])
         else:
-            del model_dict["embedding"]
+            del model_dict["embedding_ada002"]
+            del model_dict["embedding_nomic"]
         return model_dict
 
     def to_str_for_rag(self):
@@ -38,10 +41,18 @@ def to_str_for_embedding(self):
 
 
 # Define HNSW index to support vector similarity search through the vector_cosine_ops access method (cosine distance).
-index = Index(
-    "hnsw_index_for_innerproduct_item_embedding",
-    Item.embedding,
+index_ada002 = Index(
+    "hnsw_index_for_innerproduct_item_embedding_ada002",
+    Item.embedding_ada002,
     postgresql_using="hnsw",
     postgresql_with={"m": 16, "ef_construction": 64},
-    postgresql_ops={"embedding": "vector_ip_ops"},
+    postgresql_ops={"embedding_ada002": "vector_ip_ops"},
+)
+
+index_nomic = Index(
+    "hnsw_index_for_innerproduct_item_embedding_nomic",
+    Item.embedding_nomic,
+    postgresql_using="hnsw",
+    postgresql_with={"m": 16, "ef_construction": 64},
+    postgresql_ops={"embedding_nomic": "vector_ip_ops"},
 )
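A minimal sketch of how this schema would be materialized; the async-engine setup is assumed here (mirroring postgres_engine.py), and pgvector must already be installed in the target database.

from sqlalchemy.ext.asyncio import create_async_engine

from fastapi_app.postgres_models import Base

async def create_schema(database_uri: str) -> None:
    engine = create_async_engine(database_uri)
    async with engine.begin() as conn:
        # Requires: CREATE EXTENSION IF NOT EXISTS vector;
        # create_all builds the items table with both vector columns and,
        # since index_ada002/index_nomic are bound to Item's columns,
        # the two HNSW indexes along with it.
        await conn.run_sync(Base.metadata.create_all)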

src/backend/fastapi_app/postgres_searcher.py

Lines changed: 6 additions & 8 deletions
@@ -14,13 +14,15 @@ def __init__(
         openai_embed_client: AsyncOpenAI | AsyncAzureOpenAI,
         embed_deployment: str | None,  # Not needed for non-Azure OpenAI or for retrieval_mode="text"
         embed_model: str,
-        embed_dimensions: int,
+        embed_dimensions: int | None,
+        embedding_column: str,
     ):
         self.db_session = db_session
         self.openai_embed_client = openai_embed_client
         self.embed_model = embed_model
         self.embed_deployment = embed_deployment
         self.embed_dimensions = embed_dimensions
+        self.embedding_column = embedding_column
 
     def build_filter_clause(self, filters) -> tuple[str, str]:
         if filters is None:
@@ -36,19 +38,15 @@ def build_filter_clause(self, filters) -> tuple[str, str]:
         return "", ""
 
     async def search(
-        self,
-        query_text: str | None,
-        query_vector: list[float] | list,
-        top: int = 5,
-        filters: list[dict] | None = None,
+        self, query_text: str | None, query_vector: list[float] | list, top: int = 5, filters: list[dict] | None = None
     ):
         filter_clause_where, filter_clause_and = self.build_filter_clause(filters)
 
         vector_query = f"""
-            SELECT id, RANK () OVER (ORDER BY embedding <=> :embedding) AS rank
+            SELECT id, RANK () OVER (ORDER BY {self.embedding_column} <=> :embedding) AS rank
                 FROM items
                 {filter_clause_where}
-                ORDER BY embedding <=> :embedding
+                ORDER BY {self.embedding_column} <=> :embedding
                 LIMIT 20
             """
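For illustration (not from the commit): the column name is interpolated into the SQL text while the query vector stays a bound parameter, and the column name comes from server-side configuration rather than user input. With the Ollama configuration and no filters, the vector query renders roughly as:

# Hypothetical rendering of vector_query with embedding_column="embedding_nomic"
# and the empty filter clause collapsed; :embedding still carries the query vector.
rendered_vector_query = """
    SELECT id, RANK () OVER (ORDER BY embedding_nomic <=> :embedding) AS rank
        FROM items
        ORDER BY embedding_nomic <=> :embedding
        LIMIT 20
    """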

src/backend/fastapi_app/routes/api_routes.py

Lines changed: 9 additions & 3 deletions
@@ -45,15 +45,18 @@ async def item_handler(database_session: DBSession, id: int) -> ItemPublic:
 
 
 @router.get("/similar", response_model=list[ItemWithDistance])
-async def similar_handler(database_session: DBSession, id: int, n: int = 5) -> list[ItemWithDistance]:
+async def similar_handler(
+    context: CommonDeps, database_session: DBSession, id: int, n: int = 5
+) -> list[ItemWithDistance]:
     """A similarity API to find items similar to items with given ID."""
     item = (await database_session.scalars(select(Item).where(Item.id == id))).first()
     if not item:
         raise HTTPException(detail=f"Item with ID {id} not found.", status_code=404)
+
     closest = await database_session.execute(
-        select(Item, Item.embedding.l2_distance(item.embedding))
+        select(Item, Item.embedding_ada002.l2_distance(item.embedding_ada002))
         .filter(Item.id != id)
-        .order_by(Item.embedding.l2_distance(item.embedding))
+        .order_by(Item.embedding_ada002.l2_distance(item.embedding_ada002))
         .limit(n)
     )
     return [
@@ -78,6 +81,7 @@ async def search_handler(
         embed_deployment=context.openai_embed_deployment,
         embed_model=context.openai_embed_model,
         embed_dimensions=context.openai_embed_dimensions,
+        embedding_column=context.embedding_column,
     )
     results = await searcher.search_and_embed(
         query, top=top, enable_vector_search=enable_vector_search, enable_text_search=enable_text_search
@@ -99,6 +103,7 @@ async def chat_handler(
         embed_deployment=context.openai_embed_deployment,
         embed_model=context.openai_embed_model,
         embed_dimensions=context.openai_embed_dimensions,
+        embedding_column=context.embedding_column,
     )
     rag_flow: SimpleRAGChat | AdvancedRAGChat
     if chat_request.context.overrides.use_advanced_flow:
@@ -139,6 +144,7 @@ async def chat_stream_handler(
         embed_deployment=context.openai_embed_deployment,
         embed_model=context.openai_embed_model,
         embed_dimensions=context.openai_embed_dimensions,
+        embedding_column=context.embedding_column,
     )
 
     rag_flow: SimpleRAGChat | AdvancedRAGChat
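A side note on the /similar handler above (illustration, not part of the commit): Item.embedding_ada002.l2_distance(...) is the pgvector SQLAlchemy comparator and compiles to the <-> (L2 distance) operator, so the handler's query is equivalent to:

from sqlalchemy import select

from fastapi_app.postgres_models import Item

def similar_items_query(item: Item, n: int = 5):
    # Nearest neighbours to an already-loaded item, excluding the item itself;
    # l2_distance() compiles to pgvector's <-> operator.
    distance = Item.embedding_ada002.l2_distance(item.embedding_ada002)
    return select(Item, distance).filter(Item.id != item.id).order_by(distance).limit(n)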

src/backend/fastapi_app/seed_data.json

Lines changed: 233915 additions & 908 deletions
Large diffs are not rendered by default.
