1414from sqlalchemy .ext .asyncio import AsyncEngine
1515
1616from .engine import PGEngine
17+ from .hybrid_search_config import HybridSearchConfig
1718from .indexes import (
1819 DEFAULT_DISTANCE_STRATEGY ,
1920 DEFAULT_INDEX_NAME_SUFFIX ,
@@ -77,6 +78,7 @@ def __init__(
7778 fetch_k : int = 20 ,
7879 lambda_mult : float = 0.5 ,
7980 index_query_options : Optional [QueryOptions ] = None ,
81+ hybrid_search_config : Optional [HybridSearchConfig ] = None ,
8082 ):
8183 """AsyncPGVectorStore constructor.
8284 Args:
@@ -95,6 +97,7 @@ def __init__(
9597 fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
9698 lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
9799 index_query_options (QueryOptions): Index query option.
100+ hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.
98101
99102
100103 Raises:
@@ -119,6 +122,7 @@ def __init__(
119122 self .fetch_k = fetch_k
120123 self .lambda_mult = lambda_mult
121124 self .index_query_options = index_query_options
125+ self .hybrid_search_config = hybrid_search_config
122126
123127 @classmethod
124128 async def create (
@@ -139,6 +143,7 @@ async def create(
139143 fetch_k : int = 20 ,
140144 lambda_mult : float = 0.5 ,
141145 index_query_options : Optional [QueryOptions ] = None ,
146+ hybrid_search_config : Optional [HybridSearchConfig ] = None ,
142147 ) -> AsyncPGVectorStore :
143148 """Create an AsyncPGVectorStore instance.
144149
@@ -158,6 +163,7 @@ async def create(
158163 fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
159164 lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
160165 index_query_options (QueryOptions): Index query option.
166+ hybrid_search_config (HybridSearchConfig): Hybrid search configuration. Defaults to None.
161167
162168 Returns:
163169 AsyncPGVectorStore
@@ -193,6 +199,15 @@ async def create(
193199 raise ValueError (
194200 f"Content column, { content_column } , is type, { content_type } . It must be a type of character string."
195201 )
202+ if hybrid_search_config :
203+ tsv_column_name = (
204+ hybrid_search_config .tsv_column
205+ if hybrid_search_config .tsv_column
206+ else content_column + "_tsv"
207+ )
208+ if tsv_column_name not in columns or columns [tsv_column_name ] != "tsvector" :
209+ # mark tsv_column as empty because there is no TSV column in table
210+ hybrid_search_config .tsv_column = ""
196211 if embedding_column not in columns :
197212 raise ValueError (f"Embedding column, { embedding_column } , does not exist." )
198213 if columns [embedding_column ] != "USER-DEFINED" :
@@ -236,6 +251,7 @@ async def create(
236251 fetch_k = fetch_k ,
237252 lambda_mult = lambda_mult ,
238253 index_query_options = index_query_options ,
254+ hybrid_search_config = hybrid_search_config ,
239255 )
240256
241257 @property
@@ -273,7 +289,12 @@ async def aadd_embeddings(
273289 if len (self .metadata_columns ) > 0
274290 else ""
275291 )
276- insert_stmt = f'INSERT INTO "{ self .schema_name } "."{ self .table_name } "("{ self .id_column } ", "{ self .content_column } ", "{ self .embedding_column } "{ metadata_col_names } '
292+ hybrid_search_column = (
293+ f', "{ self .hybrid_search_config .tsv_column } "'
294+ if self .hybrid_search_config and self .hybrid_search_config .tsv_column
295+ else ""
296+ )
297+ insert_stmt = f'INSERT INTO "{ self .schema_name } "."{ self .table_name } "("{ self .id_column } ", "{ self .content_column } ", "{ self .embedding_column } "{ hybrid_search_column } { metadata_col_names } '
277298 values = {
278299 "id" : id ,
279300 "content" : content ,
@@ -284,6 +305,14 @@ async def aadd_embeddings(
284305 if not embedding and can_inline_embed :
285306 values_stmt = f"VALUES (:id, :content, { self .embedding_service .embed_query_inline (content )} " # type: ignore
286307
308+ if self .hybrid_search_config and self .hybrid_search_config .tsv_column :
309+ lang = (
310+ f"'{ self .hybrid_search_config .tsv_lang } ',"
311+ if self .hybrid_search_config .tsv_lang
312+ else ""
313+ )
314+ values_stmt += f", to_tsvector({ lang } :tsv_content)"
315+ values ["tsv_content" ] = content
287316 # Add metadata
288317 extra = copy .deepcopy (metadata )
289318 for metadata_column in self .metadata_columns :
@@ -308,6 +337,9 @@ async def aadd_embeddings(
308337
309338 upsert_stmt = f' ON CONFLICT ("{ self .id_column } ") DO UPDATE SET "{ self .content_column } " = EXCLUDED."{ self .content_column } ", "{ self .embedding_column } " = EXCLUDED."{ self .embedding_column } "'
310339
340+ if self .hybrid_search_config and self .hybrid_search_config .tsv_column :
341+ upsert_stmt += f', "{ self .hybrid_search_config .tsv_column } " = EXCLUDED."{ self .hybrid_search_config .tsv_column } "'
342+
311343 if self .metadata_json_column :
312344 upsert_stmt += f', "{ self .metadata_json_column } " = EXCLUDED."{ self .metadata_json_column } "'
313345
@@ -408,6 +440,7 @@ async def afrom_texts( # type: ignore[override]
408440 fetch_k : int = 20 ,
409441 lambda_mult : float = 0.5 ,
410442 index_query_options : Optional [QueryOptions ] = None ,
443+ hybrid_search_config : Optional [HybridSearchConfig ] = None ,
411444 ** kwargs : Any ,
412445 ) -> AsyncPGVectorStore :
413446 """Create an AsyncPGVectorStore instance from texts.
@@ -453,6 +486,7 @@ async def afrom_texts( # type: ignore[override]
453486 fetch_k = fetch_k ,
454487 lambda_mult = lambda_mult ,
455488 index_query_options = index_query_options ,
489+ hybrid_search_config = hybrid_search_config ,
456490 )
457491 await vs .aadd_texts (texts , metadatas = metadatas , ids = ids , ** kwargs )
458492 return vs
@@ -478,6 +512,7 @@ async def afrom_documents( # type: ignore[override]
478512 fetch_k : int = 20 ,
479513 lambda_mult : float = 0.5 ,
480514 index_query_options : Optional [QueryOptions ] = None ,
515+ hybrid_search_config : Optional [HybridSearchConfig ] = None ,
481516 ** kwargs : Any ,
482517 ) -> AsyncPGVectorStore :
483518 """Create an AsyncPGVectorStore instance from documents.
@@ -524,6 +559,7 @@ async def afrom_documents( # type: ignore[override]
524559 fetch_k = fetch_k ,
525560 lambda_mult = lambda_mult ,
526561 index_query_options = index_query_options ,
562+ hybrid_search_config = hybrid_search_config ,
527563 )
528564 texts = [doc .page_content for doc in documents ]
529565 metadatas = [doc .metadata for doc in documents ]
@@ -538,16 +574,30 @@ async def __query_collection(
538574 filter : Optional [dict ] = None ,
539575 ** kwargs : Any ,
540576 ) -> Sequence [RowMapping ]:
541- """Perform similarity search query on database."""
542- k = k if k else self .k
577+ """
578+ Perform similarity search (or hybrid search) query on database.
579+ Queries might be slow if the hybrid search column does not exist.
580+ For best hybrid search performance, consider creating a TSV column
581+ and adding GIN index.
582+ """
583+ if not k :
584+ k = (
585+ max (
586+ self .k ,
587+ self .hybrid_search_config .primary_top_k ,
588+ self .hybrid_search_config .secondary_top_k ,
589+ )
590+ if self .hybrid_search_config
591+ else self .k
592+ )
543593 operator = self .distance_strategy .operator
544594 search_function = self .distance_strategy .search_function
545595
546- columns = self . metadata_columns + [
596+ columns = [
547597 self .id_column ,
548598 self .content_column ,
549599 self .embedding_column ,
550- ]
600+ ] + self . metadata_columns
551601 if self .metadata_json_column :
552602 columns .append (self .metadata_json_column )
553603
@@ -557,16 +607,17 @@ async def __query_collection(
557607 filter_dict = None
558608 if filter and isinstance (filter , dict ):
559609 safe_filter , filter_dict = self ._create_filter_clause (filter )
560- param_filter = f"WHERE { safe_filter } " if safe_filter else ""
610+
561611 inline_embed_func = getattr (self .embedding_service , "embed_query_inline" , None )
562612 if not embedding and callable (inline_embed_func ) and "query" in kwargs :
563613 query_embedding = self .embedding_service .embed_query_inline (kwargs ["query" ]) # type: ignore
564614 embedding_data_string = f"{ query_embedding } "
565615 else :
566616 query_embedding = f"{ [float (dimension ) for dimension in embedding ]} "
567617 embedding_data_string = ":query_embedding"
568- stmt = f"""SELECT { column_names } , { search_function } ("{ self .embedding_column } ", { embedding_data_string } ) as distance
569- FROM "{ self .schema_name } "."{ self .table_name } " { param_filter } ORDER BY "{ self .embedding_column } " { operator } { embedding_data_string } LIMIT :k;
618+ where_filters = f"WHERE { safe_filter } " if safe_filter else ""
619+ dense_query_stmt = f"""SELECT { column_names } , { search_function } ("{ self .embedding_column } ", { embedding_data_string } ) as distance
620+ FROM "{ self .schema_name } "."{ self .table_name } " { where_filters } ORDER BY "{ self .embedding_column } " { operator } { embedding_data_string } LIMIT :k;
570621 """
571622 param_dict = {"query_embedding" : query_embedding , "k" : k }
572623 if filter_dict :
@@ -577,15 +628,51 @@ async def __query_collection(
577628 for query_option in self .index_query_options .to_parameter ():
578629 query_options_stmt = f"SET LOCAL { query_option } ;"
579630 await conn .execute (text (query_options_stmt ))
580- result = await conn .execute (text (stmt ), param_dict )
631+ result = await conn .execute (text (dense_query_stmt ), param_dict )
581632 result_map = result .mappings ()
582- results = result_map .fetchall ()
633+ dense_results = result_map .fetchall ()
583634 else :
584635 async with self .engine .connect () as conn :
585- result = await conn .execute (text (stmt ), param_dict )
636+ result = await conn .execute (text (dense_query_stmt ), param_dict )
637+ result_map = result .mappings ()
638+ dense_results = result_map .fetchall ()
639+
640+ hybrid_search_config = kwargs .get (
641+ "hybrid_search_config" , self .hybrid_search_config
642+ )
643+ fts_query = (
644+ hybrid_search_config .fts_query
645+ if hybrid_search_config and hybrid_search_config .fts_query
646+ else kwargs .get ("fts_query" , "" )
647+ )
648+ if hybrid_search_config and fts_query :
649+ hybrid_search_config .fusion_function_parameters ["fetch_top_k" ] = k
650+ # do the sparse query
651+ lang = (
652+ f"'{ hybrid_search_config .tsv_lang } ',"
653+ if hybrid_search_config .tsv_lang
654+ else ""
655+ )
656+ query_tsv = f"plainto_tsquery({ lang } :fts_query)"
657+ param_dict ["fts_query" ] = fts_query
658+ if hybrid_search_config .tsv_column :
659+ content_tsv = f'"{ hybrid_search_config .tsv_column } "'
660+ else :
661+ content_tsv = f'to_tsvector({ lang } "{ self .content_column } ")'
662+ and_filters = f"AND ({ safe_filter } )" if safe_filter else ""
663+ sparse_query_stmt = f'SELECT { column_names } , ts_rank_cd({ content_tsv } , { query_tsv } ) as distance FROM "{ self .schema_name } "."{ self .table_name } " WHERE { content_tsv } @@ { query_tsv } { and_filters } ORDER BY distance desc LIMIT { hybrid_search_config .secondary_top_k } ;'
664+ async with self .engine .connect () as conn :
665+ result = await conn .execute (text (sparse_query_stmt ), param_dict )
586666 result_map = result .mappings ()
587- results = result_map .fetchall ()
588- return results
667+ sparse_results = result_map .fetchall ()
668+
669+ combined_results = hybrid_search_config .fusion_function (
670+ dense_results ,
671+ sparse_results ,
672+ ** hybrid_search_config .fusion_function_parameters ,
673+ )
674+ return combined_results
675+ return dense_results
589676
590677 async def asimilarity_search (
591678 self ,
@@ -603,6 +690,14 @@ async def asimilarity_search(
603690 )
604691 kwargs ["query" ] = query
605692
693+ # add fts_query to hybrid_search_config
694+ hybrid_search_config = kwargs .get (
695+ "hybrid_search_config" , self .hybrid_search_config
696+ )
697+ if hybrid_search_config and not hybrid_search_config .fts_query :
698+ hybrid_search_config .fts_query = query
699+ kwargs ["hybrid_search_config" ] = hybrid_search_config
700+
606701 return await self .asimilarity_search_by_vector (
607702 embedding = embedding , k = k , filter = filter , ** kwargs
608703 )
@@ -634,6 +729,14 @@ async def asimilarity_search_with_score(
634729 )
635730 kwargs ["query" ] = query
636731
732+ # add fts_query to hybrid_search_config
733+ hybrid_search_config = kwargs .get (
734+ "hybrid_search_config" , self .hybrid_search_config
735+ )
736+ if hybrid_search_config and not hybrid_search_config .fts_query :
737+ hybrid_search_config .fts_query = query
738+ kwargs ["hybrid_search_config" ] = hybrid_search_config
739+
637740 docs = await self .asimilarity_search_with_score_by_vector (
638741 embedding = embedding , k = k , filter = filter , ** kwargs
639742 )
@@ -778,6 +881,41 @@ async def amax_marginal_relevance_search_with_score_by_vector(
778881
779882 return [r for i , r in enumerate (documents_with_scores ) if i in mmr_selected ]
780883
884+ async def aapply_hybrid_search_index (
885+ self ,
886+ concurrently : bool = False ,
887+ ) -> None :
888+ """Creates a TSV index in the vector store table if possible."""
889+ if (
890+ not self .hybrid_search_config
891+ or not self .hybrid_search_config .index_type
892+ or not self .hybrid_search_config .index_name
893+ ):
894+ # no index needs to be created
895+ raise ValueError ("Hybrid Search Config cannot create index." )
896+
897+ lang = (
898+ f"'{ self .hybrid_search_config .tsv_lang } ',"
899+ if self .hybrid_search_config .tsv_lang
900+ else ""
901+ )
902+ tsv_column_name = (
903+ self .hybrid_search_config .tsv_column
904+ if self .hybrid_search_config .tsv_column
905+ else f"to_tsvector({ lang } { self .content_column } )"
906+ )
907+ tsv_index_query = f'CREATE INDEX { "CONCURRENTLY" if concurrently else "" } { self .hybrid_search_config .index_name } ON "{ self .schema_name } "."{ self .table_name } " USING { self .hybrid_search_config .index_type } ({ tsv_column_name } );'
908+ if concurrently :
909+ async with self .engine .connect () as conn :
910+ autocommit_conn = await conn .execution_options (
911+ isolation_level = "AUTOCOMMIT"
912+ )
913+ await autocommit_conn .execute (text (tsv_index_query ))
914+ else :
915+ async with self .engine .connect () as conn :
916+ await conn .execute (text (tsv_index_query ))
917+ await conn .commit ()
918+
781919 async def aapply_vector_index (
782920 self ,
783921 index : BaseIndex ,
@@ -806,6 +944,7 @@ async def aapply_vector_index(
806944 index .name = self .table_name + DEFAULT_INDEX_NAME_SUFFIX
807945 name = index .name
808946 stmt = f'CREATE INDEX { "CONCURRENTLY" if concurrently else "" } "{ name } " ON "{ self .schema_name } "."{ self .table_name } " USING { index .index_type } ({ self .embedding_column } { function } ) { params } { filter } ;'
947+
809948 if concurrently :
810949 async with self .engine .connect () as conn :
811950 autocommit_conn = await conn .execution_options (
0 commit comments