diff --git a/langchain_postgres/v2/async_vectorstore.py b/langchain_postgres/v2/async_vectorstore.py index 8382b3e..c83930e 100644 --- a/langchain_postgres/v2/async_vectorstore.py +++ b/langchain_postgres/v2/async_vectorstore.py @@ -580,16 +580,16 @@ async def __query_collection( For best hybrid search performance, consider creating a TSV column and adding GIN index. """ - if not k: - k = ( - max( - self.k, - self.hybrid_search_config.primary_top_k, - self.hybrid_search_config.secondary_top_k, - ) - if self.hybrid_search_config - else self.k - ) + hybrid_search_config = kwargs.get( + "hybrid_search_config", self.hybrid_search_config + ) + + final_k = k if k is not None else self.k + + dense_limit = final_k + if hybrid_search_config: + dense_limit = hybrid_search_config.primary_top_k + operator = self.distance_strategy.operator search_function = self.distance_strategy.search_function @@ -617,9 +617,9 @@ async def __query_collection( embedding_data_string = ":query_embedding" where_filters = f"WHERE {safe_filter}" if safe_filter else "" dense_query_stmt = f"""SELECT {column_names}, {search_function}("{self.embedding_column}", {embedding_data_string}) as distance - FROM "{self.schema_name}"."{self.table_name}" {where_filters} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :k; + FROM "{self.schema_name}"."{self.table_name}" {where_filters} ORDER BY "{self.embedding_column}" {operator} {embedding_data_string} LIMIT :dense_limit; """ - param_dict = {"query_embedding": query_embedding, "k": k} + param_dict = {"query_embedding": query_embedding, "dense_limit": dense_limit} if filter_dict: param_dict.update(filter_dict) if self.index_query_options: @@ -637,16 +637,13 @@ async def __query_collection( result_map = result.mappings() dense_results = result_map.fetchall() - hybrid_search_config = kwargs.get( - "hybrid_search_config", self.hybrid_search_config - ) fts_query = ( hybrid_search_config.fts_query if hybrid_search_config and hybrid_search_config.fts_query else kwargs.get("fts_query", "") ) if hybrid_search_config and fts_query: - hybrid_search_config.fusion_function_parameters["fetch_top_k"] = k + hybrid_search_config.fusion_function_parameters["fetch_top_k"] = final_k # do the sparse query lang = ( f"'{hybrid_search_config.tsv_lang}',"