Skip to content

Commit 076f0cb

Browse files
feat: adds hybrid search for async VS interface [3/N]
2 parents 9611164 + 08a4ff6 commit 076f0cb

File tree

6 files changed

+723
-23
lines changed

6 files changed

+723
-23
lines changed

langchain_postgres/v2/async_vectorstore.py

Lines changed: 152 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sqlalchemy.ext.asyncio import AsyncEngine
1515

1616
from .engine import PGEngine
17+
from .hybrid_search_config import HybridSearchConfig
1718
from .indexes import (
1819
DEFAULT_DISTANCE_STRATEGY,
1920
DEFAULT_INDEX_NAME_SUFFIX,
@@ -77,6 +78,7 @@ def __init__(
7778
fetch_k: int = 20,
7879
lambda_mult: float = 0.5,
7980
index_query_options: Optional[QueryOptions] = None,
81+
hybrid_search_config: Optional[HybridSearchConfig] = None,
8082
):
8183
"""AsyncPGVectorStore constructor.
8284
Args:
@@ -95,6 +97,7 @@ def __init__(
9597
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
9698
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
9799
index_query_options (QueryOptions): Index query option.
100+
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.
98101
99102
100103
Raises:
@@ -119,6 +122,7 @@ def __init__(
119122
self.fetch_k = fetch_k
120123
self.lambda_mult = lambda_mult
121124
self.index_query_options = index_query_options
125+
self.hybrid_search_config = hybrid_search_config
122126

123127
@classmethod
124128
async def create(
@@ -139,6 +143,7 @@ async def create(
139143
fetch_k: int = 20,
140144
lambda_mult: float = 0.5,
141145
index_query_options: Optional[QueryOptions] = None,
146+
hybrid_search_config: Optional[HybridSearchConfig] = None,
142147
) -> AsyncPGVectorStore:
143148
"""Create an AsyncPGVectorStore instance.
144149
@@ -158,6 +163,7 @@ async def create(
158163
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
159164
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
160165
index_query_options (QueryOptions): Index query option.
166+
hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.
161167
162168
Returns:
163169
AsyncPGVectorStore
@@ -193,6 +199,15 @@ async def create(
193199
raise ValueError(
194200
f"Content column, {content_column}, is type, {content_type}. It must be a type of character string."
195201
)
202+
if hybrid_search_config:
203+
tsv_column_name = (
204+
hybrid_search_config.tsv_column
205+
if hybrid_search_config.tsv_column
206+
else content_column + "_tsv"
207+
)
208+
if tsv_column_name not in columns or columns[tsv_column_name] != "tsvector":
209+
# mark tsv_column as empty because there is no TSV column in table
210+
hybrid_search_config.tsv_column = ""
196211
if embedding_column not in columns:
197212
raise ValueError(f"Embedding column, {embedding_column}, does not exist.")
198213
if columns[embedding_column] != "USER-DEFINED":
@@ -236,6 +251,7 @@ async def create(
236251
fetch_k=fetch_k,
237252
lambda_mult=lambda_mult,
238253
index_query_options=index_query_options,
254+
hybrid_search_config=hybrid_search_config,
239255
)
240256

241257
@property
@@ -273,7 +289,12 @@ async def aadd_embeddings(
273289
if len(self.metadata_columns) > 0
274290
else ""
275291
)
276-
insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{metadata_col_names}'
292+
hybrid_search_column = (
293+
f', "{self.hybrid_search_config.tsv_column}"'
294+
if self.hybrid_search_config and self.hybrid_search_config.tsv_column
295+
else ""
296+
)
297+
insert_stmt = f'INSERT INTO "{self.schema_name}"."{self.table_name}"("{self.id_column}", "{self.content_column}", "{self.embedding_column}"{hybrid_search_column}{metadata_col_names}'
277298
values = {
278299
"id": id,
279300
"content": content,
@@ -284,6 +305,14 @@ async def aadd_embeddings(
284305
if not embedding and can_inline_embed:
285306
values_stmt = f"VALUES (:id, :content, {self.embedding_service.embed_query_inline(content)}" # type: ignore
286307

308+
if self.hybrid_search_config and self.hybrid_search_config.tsv_column:
309+
lang = (
310+
f"'{self.hybrid_search_config.tsv_lang}',"
311+
if self.hybrid_search_config.tsv_lang
312+
else ""
313+
)
314+
values_stmt += f", to_tsvector({lang} :tsv_content)"
315+
values["tsv_content"] = content
287316
# Add metadata
288317
extra = copy.deepcopy(metadata)
289318
for metadata_column in self.metadata_columns:
@@ -308,6 +337,9 @@ async def aadd_embeddings(
308337

309338
upsert_stmt = f' ON CONFLICT ("{self.id_column}") DO UPDATE SET "{self.content_column}" = EXCLUDED."{self.content_column}", "{self.embedding_column}" = EXCLUDED."{self.embedding_column}"'
310339

340+
if self.hybrid_search_config and self.hybrid_search_config.tsv_column:
341+
upsert_stmt += f', "{self.hybrid_search_config.tsv_column}" = EXCLUDED."{self.hybrid_search_config.tsv_column}"'
342+
311343
if self.metadata_json_column:
312344
upsert_stmt += f', "{self.metadata_json_column}" = EXCLUDED."{self.metadata_json_column}"'
313345

@@ -408,6 +440,7 @@ async def afrom_texts( # type: ignore[override]
408440
fetch_k: int = 20,
409441
lambda_mult: float = 0.5,
410442
index_query_options: Optional[QueryOptions] = None,
443+
hybrid_search_config: Optional[HybridSearchConfig] = None,
411444
**kwargs: Any,
412445
) -> AsyncPGVectorStore:
413446
"""Create an AsyncPGVectorStore instance from texts.
@@ -453,6 +486,7 @@ async def afrom_texts( # type: ignore[override]
453486
fetch_k=fetch_k,
454487
lambda_mult=lambda_mult,
455488
index_query_options=index_query_options,
489+
hybrid_search_config=hybrid_search_config,
456490
)
457491
await vs.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
458492
return vs
@@ -478,6 +512,7 @@ async def afrom_documents( # type: ignore[override]
478512
fetch_k: int = 20,
479513
lambda_mult: float = 0.5,
480514
index_query_options: Optional[QueryOptions] = None,
515+
hybrid_search_config: Optional[HybridSearchConfig] = None,
481516
**kwargs: Any,
482517
) -> AsyncPGVectorStore:
483518
"""Create an AsyncPGVectorStore instance from documents.
@@ -524,6 +559,7 @@ async def afrom_documents( # type: ignore[override]
524559
fetch_k=fetch_k,
525560
lambda_mult=lambda_mult,
526561
index_query_options=index_query_options,
562+
hybrid_search_config=hybrid_search_config,
527563
)
528564
texts = [doc.page_content for doc in documents]
529565
metadatas = [doc.metadata for doc in documents]
@@ -538,16 +574,30 @@ async def __query_collection(
538574
filter: Optional[dict] = None,
539575
**kwargs: Any,
540576
) -> Sequence[RowMapping]:
541-
"""Perform similarity search query on database."""
542-
k = k if k else self.k
577+
"""
578+
Perform similarity search (or hybrid search) query on database.
579+
Queries might be slow if the hybrid search column does not exist.
580+
For best hybrid search performance, consider creating a TSV column
581+
and adding GIN index.
582+
"""
583+
if not k:
584+
k = (
585+
max(
586+
self.k,
587+
self.hybrid_search_config.primary_top_k,
588+
self.hybrid_search_config.secondary_top_k,
589+
)
590+
if self.hybrid_search_config
591+
else self.k
592+
)
543593
operator = self.distance_strategy.operator
544594
search_function = self.distance_strategy.search_function
545595

546-
columns = self.metadata_columns + [
596+
columns = [
547597
self.id_column,
548598
self.content_column,
549599
self.embedding_column,
550-
]
600+
] + self.metadata_columns
551601
if self.metadata_json_column:
552602
columns.append(self.metadata_json_column)
553603

@@ -557,16 +607,17 @@ async def __query_collection(
557607
filter_dict = None
558608
if filter and isinstance(filter, dict):
559609
safe_filter, filter_dict = self._create_filter_clause(filter)
560-
param_filter = f"WHERE {safe_filter}" if safe_filter else ""
610+
561611
inline_embed_func = getattr(self.embedding_service, "embed_query_inline", None)
562612
if not embedding and callable(inline_embed_func) and "query" in kwargs:
563613
query_embedding = self.embedding_service.embed_query_inline(kwargs["query"]) # type: ignore
564614
embedding_data_string = f"{query_embedding}"
565615
else:
566616
query_embedding = f"{[float(dimension) for dimension in embedding]}"
567617
embedding_data_string = ":query_embedding"
568-
stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance
569-
FROM "{self.schema_name}"."{self.table_name}" {param_filter} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k;
618+
where_filters = f"WHERE {safe_filter}" if safe_filter else ""
619+
dense_query_stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance
620+
FROM "{self.schema_name}"."{self.table_name}" {where_filters} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k;
570621
"""
571622
param_dict = {"query_embedding": query_embedding, "k": k}
572623
if filter_dict:
@@ -577,15 +628,51 @@ async def __query_collection(
577628
for query_option in self.index_query_options.to_parameter():
578629
query_options_stmt = f"SET LOCAL {query_option};"
579630
await conn.execute(text(query_options_stmt))
580-
result = await conn.execute(text(stmt), param_dict)
631+
result = await conn.execute(text(dense_query_stmt), param_dict)
581632
result_map = result.mappings()
582-
results = result_map.fetchall()
633+
dense_results = result_map.fetchall()
583634
else:
584635
async with self.engine.connect() as conn:
585-
result = await conn.execute(text(stmt), param_dict)
636+
result = await conn.execute(text(dense_query_stmt), param_dict)
637+
result_map = result.mappings()
638+
dense_results = result_map.fetchall()
639+
640+
hybrid_search_config = kwargs.get(
641+
"hybrid_search_config", self.hybrid_search_config
642+
)
643+
fts_query = (
644+
hybrid_search_config.fts_query
645+
if hybrid_search_config and hybrid_search_config.fts_query
646+
else kwargs.get("fts_query", "")
647+
)
648+
if hybrid_search_config and fts_query:
649+
hybrid_search_config.fusion_function_parameters["fetch_top_k"] = k
650+
# do the sparse query
651+
lang = (
652+
f"'{hybrid_search_config.tsv_lang}',"
653+
if hybrid_search_config.tsv_lang
654+
else ""
655+
)
656+
query_tsv = f"plainto_tsquery({lang} :fts_query)"
657+
param_dict["fts_query"] = fts_query
658+
if hybrid_search_config.tsv_column:
659+
content_tsv = f'"{hybrid_search_config.tsv_column}"'
660+
else:
661+
content_tsv = f'to_tsvector({lang} "{self.content_column}")'
662+
and_filters = f"AND ({safe_filter})" if safe_filter else ""
663+
sparse_query_stmt = f'SELECT {column_names}, ts_rank_cd({content_tsv}, {query_tsv}) as distance FROM "{self.schema_name}"."{self.table_name}" WHERE {content_tsv} @@ {query_tsv} {and_filters} ORDER BY distance desc LIMIT {hybrid_search_config.secondary_top_k};'
664+
async with self.engine.connect() as conn:
665+
result = await conn.execute(text(sparse_query_stmt), param_dict)
586666
result_map = result.mappings()
587-
results = result_map.fetchall()
588-
return results
667+
sparse_results = result_map.fetchall()
668+
669+
combined_results = hybrid_search_config.fusion_function(
670+
dense_results,
671+
sparse_results,
672+
**hybrid_search_config.fusion_function_parameters,
673+
)
674+
return combined_results
675+
return dense_results
589676

590677
async def asimilarity_search(
591678
self,
@@ -603,6 +690,14 @@ async def asimilarity_search(
603690
)
604691
kwargs["query"] = query
605692

693+
# add fts_query to hybrid_search_config
694+
hybrid_search_config = kwargs.get(
695+
"hybrid_search_config", self.hybrid_search_config
696+
)
697+
if hybrid_search_config and not hybrid_search_config.fts_query:
698+
hybrid_search_config.fts_query = query
699+
kwargs["hybrid_search_config"] = hybrid_search_config
700+
606701
return await self.asimilarity_search_by_vector(
607702
embedding=embedding, k=k, filter=filter, **kwargs
608703
)
@@ -634,6 +729,14 @@ async def asimilarity_search_with_score(
634729
)
635730
kwargs["query"] = query
636731

732+
# add fts_query to hybrid_search_config
733+
hybrid_search_config = kwargs.get(
734+
"hybrid_search_config", self.hybrid_search_config
735+
)
736+
if hybrid_search_config and not hybrid_search_config.fts_query:
737+
hybrid_search_config.fts_query = query
738+
kwargs["hybrid_search_config"] = hybrid_search_config
739+
637740
docs = await self.asimilarity_search_with_score_by_vector(
638741
embedding=embedding, k=k, filter=filter, **kwargs
639742
)
@@ -778,6 +881,41 @@ async def amax_marginal_relevance_search_with_score_by_vector(
778881

779882
return [r for i, r in enumerate(documents_with_scores) if i in mmr_selected]
780883

884+
async def aapply_hybrid_search_index(
885+
self,
886+
concurrently: bool = False,
887+
) -> None:
888+
"""Creates a TSV index in the vector store table if possible."""
889+
if (
890+
not self.hybrid_search_config
891+
or not self.hybrid_search_config.index_type
892+
or not self.hybrid_search_config.index_name
893+
):
894+
# no index needs to be created
895+
raise ValueError("Hybrid Search Config cannot create index.")
896+
897+
lang = (
898+
f"'{self.hybrid_search_config.tsv_lang}',"
899+
if self.hybrid_search_config.tsv_lang
900+
else ""
901+
)
902+
tsv_column_name = (
903+
self.hybrid_search_config.tsv_column
904+
if self.hybrid_search_config.tsv_column
905+
else f"to_tsvector({lang} {self.content_column})"
906+
)
907+
tsv_index_query = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} {self.hybrid_search_config.index_name} ON "{self.schema_name}"."{self.table_name}" USING {self.hybrid_search_config.index_type}({tsv_column_name});'
908+
if concurrently:
909+
async with self.engine.connect() as conn:
910+
autocommit_conn = await conn.execution_options(
911+
isolation_level="AUTOCOMMIT"
912+
)
913+
await autocommit_conn.execute(text(tsv_index_query))
914+
else:
915+
async with self.engine.connect() as conn:
916+
await conn.execute(text(tsv_index_query))
917+
await conn.commit()
918+
781919
async def aapply_vector_index(
782920
self,
783921
index: BaseIndex,
@@ -806,6 +944,7 @@ async def aapply_vector_index(
806944
index.name = self.table_name + DEFAULT_INDEX_NAME_SUFFIX
807945
name = index.name
808946
stmt = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} "{name}" ON "{self.schema_name}"."{self.table_name}" USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};'
947+
809948
if concurrently:
810949
async with self.engine.connect() as conn:
811950
autocommit_conn = await conn.execution_options(

langchain_postgres/v2/engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ async def _adrop_table(
384384
schema_name: str = "public",
385385
) -> None:
386386
"""Drop the vector store table"""
387-
query = f'DROP TABLE "{schema_name}"."{table_name}";'
387+
query = f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}";'
388388
async with self._pool.connect() as conn:
389389
await conn.execute(text(query))
390390
await conn.commit()

0 commit comments

Comments
 (0)