Skip to content

Commit 9f97676

Browse files
feat: Added Hybrid Search Config and Tests [1/N] (#211)
1 parent 194b3bb commit 9f97676

File tree

9 files changed

+1195
-32
lines changed

9 files changed

+1195
-32
lines changed

langchain_postgres/v2/async_vectorstore.py

Lines changed: 152 additions & 13 deletions
Large diffs are not rendered by default.

langchain_postgres/v2/engine.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from sqlalchemy.engine import URL
1010
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
1111

12+
from .hybrid_search_config import HybridSearchConfig
13+
1214
T = TypeVar("T")
1315

1416

@@ -156,6 +158,7 @@ async def _ainit_vectorstore_table(
156158
id_column: Union[str, Column, ColumnDict] = "langchain_id",
157159
overwrite_existing: bool = False,
158160
store_metadata: bool = True,
161+
hybrid_search_config: Optional[HybridSearchConfig] = None,
159162
) -> None:
160163
"""
161164
Create a table for saving of vectors to be used with PGVectorStore.
@@ -178,6 +181,8 @@ async def _ainit_vectorstore_table(
178181
overwrite_existing (bool): Whether to drop existing table. Default: False.
179182
store_metadata (bool): Whether to store metadata in the table.
180183
Default: True.
184+
hybrid_search_config (HybridSearchConfig): Hybrid search configuration.
185+
Default: None.
181186
182187
Raises:
183188
:class:`DuplicateTableError <asyncpg.exceptions.DuplicateTableError>`: if table already exists.
@@ -186,6 +191,7 @@ async def _ainit_vectorstore_table(
186191

187192
schema_name = self._escape_postgres_identifier(schema_name)
188193
table_name = self._escape_postgres_identifier(table_name)
194+
hybrid_search_default_column_name = content_column + "_tsv"
189195
content_column = self._escape_postgres_identifier(content_column)
190196
embedding_column = self._escape_postgres_identifier(embedding_column)
191197
if metadata_columns is None:
@@ -226,10 +232,22 @@ async def _ainit_vectorstore_table(
226232
id_data_type = id_column["data_type"]
227233
id_column_name = id_column["name"]
228234

235+
hybrid_search_column = "" # Default is no TSV column for hybrid search
236+
if hybrid_search_config:
237+
hybrid_search_column_name = (
238+
hybrid_search_config.tsv_column or hybrid_search_default_column_name
239+
)
240+
hybrid_search_column_name = self._escape_postgres_identifier(
241+
hybrid_search_column_name
242+
)
243+
hybrid_search_config.tsv_column = hybrid_search_column_name
244+
hybrid_search_column = f',"{self._escape_postgres_identifier(hybrid_search_column_name)}" TSVECTOR NOT NULL'
245+
229246
query = f"""CREATE TABLE "{schema_name}"."{table_name}"(
230247
"{id_column_name}" {id_data_type} PRIMARY KEY,
231248
"{content_column}" TEXT NOT NULL,
232-
"{embedding_column}" vector({vector_size}) NOT NULL"""
249+
"{embedding_column}" vector({vector_size}) NOT NULL
250+
{hybrid_search_column}"""
233251
for column in metadata_columns:
234252
if isinstance(column, Column):
235253
nullable = "NOT NULL" if not column.nullable else ""
@@ -258,6 +276,7 @@ async def ainit_vectorstore_table(
258276
id_column: Union[str, Column, ColumnDict] = "langchain_id",
259277
overwrite_existing: bool = False,
260278
store_metadata: bool = True,
279+
hybrid_search_config: Optional[HybridSearchConfig] = None,
261280
) -> None:
262281
"""
263282
Create a table for saving of vectors to be used with PGVectorStore.
@@ -280,6 +299,10 @@ async def ainit_vectorstore_table(
280299
overwrite_existing (bool): Whether to drop existing table. Default: False.
281300
store_metadata (bool): Whether to store metadata in the table.
282301
Default: True.
302+
hybrid_search_config (HybridSearchConfig): Hybrid search configuration.
303+
Note that queries might be slow if the hybrid search column does not exist.
304+
For best hybrid search performance, consider creating a TSV column and adding GIN index.
305+
Default: None.
283306
"""
284307
await self._run_as_async(
285308
self._ainit_vectorstore_table(
@@ -293,6 +316,7 @@ async def ainit_vectorstore_table(
293316
id_column=id_column,
294317
overwrite_existing=overwrite_existing,
295318
store_metadata=store_metadata,
319+
hybrid_search_config=hybrid_search_config,
296320
)
297321
)
298322

@@ -309,6 +333,7 @@ def init_vectorstore_table(
309333
id_column: Union[str, Column, ColumnDict] = "langchain_id",
310334
overwrite_existing: bool = False,
311335
store_metadata: bool = True,
336+
hybrid_search_config: Optional[HybridSearchConfig] = None,
312337
) -> None:
313338
"""
314339
Create a table for saving of vectors to be used with PGVectorStore.
@@ -331,6 +356,10 @@ def init_vectorstore_table(
331356
overwrite_existing (bool): Whether to drop existing table. Default: False.
332357
store_metadata (bool): Whether to store metadata in the table.
333358
Default: True.
359+
hybrid_search_config (HybridSearchConfig): Hybrid search configuration.
360+
Note that queries might be slow if the hybrid search column does not exist.
361+
For best hybrid search performance, consider creating a TSV column and adding GIN index.
362+
Default: None.
334363
"""
335364
self._run_as_sync(
336365
self._ainit_vectorstore_table(
@@ -344,6 +373,7 @@ def init_vectorstore_table(
344373
id_column=id_column,
345374
overwrite_existing=overwrite_existing,
346375
store_metadata=store_metadata,
376+
hybrid_search_config=hybrid_search_config,
347377
)
348378
)
349379

@@ -354,7 +384,7 @@ async def _adrop_table(
354384
schema_name: str = "public",
355385
) -> None:
356386
"""Drop the vector store table"""
357-
query = f'DROP TABLE "{schema_name}"."{table_name}";'
387+
query = f'DROP TABLE IF EXISTS "{schema_name}"."{table_name}";'
358388
async with self._pool.connect() as conn:
359389
await conn.execute(text(query))
360390
await conn.commit()
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
from abc import ABC
2+
from dataclasses import dataclass, field
3+
from typing import Any, Callable, Optional, Sequence
4+
5+
from sqlalchemy import RowMapping
6+
7+
8+
def weighted_sum_ranking(
    primary_search_results: Sequence[RowMapping],
    secondary_search_results: Sequence[RowMapping],
    primary_results_weight: float = 0.5,
    secondary_results_weight: float = 0.5,
    fetch_top_k: int = 4,
) -> Sequence[dict[str, Any]]:
    """
    Ranks documents using a weighted sum of scores from two sources.

    Each input row is a mapping whose first value is the document id and
    whose last value is the relevance score (stored under ``"distance"``;
    higher is treated as better).

    Args:
        primary_search_results: Rows from the primary (e.g. vector) search.
        secondary_search_results: Rows from the secondary (e.g. full-text)
            search.
        primary_results_weight: The weight for the primary source's scores.
            Defaults to 0.5.
        secondary_results_weight: The weight for the secondary source's
            scores. Defaults to 0.5.
        fetch_top_k: The number of documents to return after merging.
            Defaults to 4.

    Returns:
        A list of row dicts whose ``"distance"`` key holds the combined
        weighted score, sorted by that score in descending order.
    """
    # doc_id -> row dict carrying the combined weighted score in "distance".
    weighted_scores: dict[str, dict[str, Any]] = {}

    # Seed the scores with the weighted primary results.
    for row in primary_search_results:
        values = list(row.values())
        doc_id = str(values[0])  # first value is the document id
        distance = float(values[-1])  # last value is the score
        row_values = dict(row)
        row_values["distance"] = primary_results_weight * distance
        weighted_scores[doc_id] = row_values

    # Fold in the weighted secondary results, accumulating scores for
    # documents already seen in the primary results. NOTE: the stored row
    # payload is replaced by the secondary row for overlapping documents,
    # matching the original merge order.
    for row in secondary_search_results:
        values = list(row.values())
        doc_id = str(values[0])  # first value is the document id
        distance = float(values[-1])  # last value is the score
        existing = weighted_scores.get(doc_id)
        primary_score = existing["distance"] if existing is not None else 0.0
        row_values = dict(row)
        row_values["distance"] = distance * secondary_results_weight + primary_score
        weighted_scores[doc_id] = row_values

    # Highest combined score first.
    ranked_results = sorted(
        weighted_scores.values(), key=lambda item: item["distance"], reverse=True
    )
    return ranked_results[:fetch_top_k]
65+
66+
67+
def reciprocal_rank_fusion(
    primary_search_results: Sequence[RowMapping],
    secondary_search_results: Sequence[RowMapping],
    rrf_k: float = 60,
    fetch_top_k: int = 4,
) -> Sequence[dict[str, Any]]:
    """
    Ranks documents using Reciprocal Rank Fusion (RRF) of two result sets.

    Each input row is a mapping whose first value is the document id. Rows
    are ranked within each source by their ``"distance"`` value (higher is
    treated as better) and each source contributes ``1 / (rank + rrf_k)``
    to the fused score of its documents.

    Args:
        primary_search_results: Rows from the primary (e.g. vector) search.
        secondary_search_results: Rows from the secondary (e.g. full-text)
            search.
        rrf_k: The RRF smoothing constant k. Defaults to 60.
        fetch_top_k: The number of documents to return after merging.
            Defaults to 4.

    Returns:
        A list of row dicts whose ``"distance"`` key holds the fused RRF
        score, sorted by that score in descending order.
    """
    # doc_id -> row dict carrying the accumulated RRF score in "distance".
    rrf_scores: dict[str, dict[str, Any]] = {}

    def _accumulate(results: Sequence[RowMapping]) -> None:
        # Rank rows by raw score (best first) and add each row's reciprocal
        # rank contribution to any score accumulated so far. The stored row
        # payload is replaced by the latest source's row, matching the
        # original merge order.
        ranked = sorted(results, key=lambda item: item["distance"], reverse=True)
        for rank, row in enumerate(ranked):
            doc_id = str(list(row.values())[0])  # first value is the doc id
            prior = rrf_scores[doc_id]["distance"] if doc_id in rrf_scores else 0.0
            row_values = dict(row)
            row_values["distance"] = prior + 1.0 / (rank + rrf_k)
            rrf_scores[doc_id] = row_values

    _accumulate(primary_search_results)
    _accumulate(secondary_search_results)

    # Highest fused score first.
    ranked_results = sorted(
        rrf_scores.values(), key=lambda item: item["distance"], reverse=True
    )
    return ranked_results[:fetch_top_k]
127+
128+
129+
@dataclass
class HybridSearchConfig(ABC):
    """
    Postgres Vector Store Hybrid Search Config.

    Configures how a semantic (vector) search and a full-text (TSVECTOR)
    search are combined: each search produces a ranked result set, and
    ``fusion_function`` merges the two into the final ranking.

    Queries might be slow if the hybrid search column does not exist.
    For best hybrid search performance, consider creating a TSV column
    and adding GIN index.
    """

    # Name of the TSVECTOR column; empty string means derive a default
    # (content column name + "_tsv") at table-init time.
    tsv_column: Optional[str] = ""
    # Text-search configuration name — presumably passed to
    # to_tsvector/to_tsquery; confirm against the vector store queries.
    tsv_lang: Optional[str] = "pg_catalog.english"
    # Explicit full-text query string; empty string presumably means the
    # search query itself is used — confirm against the vector store.
    fts_query: Optional[str] = ""
    # Merges the primary and secondary result sets into one ranking;
    # defaults to a weighted sum of per-source scores.
    fusion_function: Callable[
        [Sequence[RowMapping], Sequence[RowMapping], Any], Sequence[Any]
    ] = weighted_sum_ranking
    # Extra keyword arguments forwarded to ``fusion_function``.
    fusion_function_parameters: dict[str, Any] = field(default_factory=dict)
    # Number of candidates fetched from each source before fusion.
    primary_top_k: int = 4
    secondary_top_k: int = 4
    # Index created over the TSV column (GIN is the usual full-text index).
    index_name: str = "langchain_tsv_index"
    index_type: str = "GIN"

0 commit comments

Comments
 (0)