Skip to content

Commit 08a4ff6

Browse files
feat: adds hybrid search for sync VS interface [4/N]
2 parents 5bf1a4b + e092c82 commit 08a4ff6

File tree

2 files changed

+82
-1
lines changed

2 files changed

+82
-1
lines changed

langchain_postgres/v2/vectorstores.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from .async_vectorstore import AsyncPGVectorStore
1111
from .engine import PGEngine
12+
from .hybrid_search_config import HybridSearchConfig
1213
from .indexes import (
1314
DEFAULT_DISTANCE_STRATEGY,
1415
BaseIndex,
@@ -59,6 +60,7 @@ async def create(
5960
fetch_k: int = 20,
6061
lambda_mult: float = 0.5,
6162
index_query_options: Optional[QueryOptions] = None,
63+
hybrid_search_config: Optional[HybridSearchConfig] = None,
6264
) -> PGVectorStore:
6365
"""Create a PGVectorStore instance.
6466
@@ -78,6 +80,7 @@ async def create(
7880
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
7981
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
8082
index_query_options (QueryOptions): Index query option.
83+
hybrid_search_config (Optional[HybridSearchConfig]): Hybrid search configuration. Defaults to None.
8184
8285
Returns:
8386
PGVectorStore
@@ -98,6 +101,7 @@ async def create(
98101
fetch_k=fetch_k,
99102
lambda_mult=lambda_mult,
100103
index_query_options=index_query_options,
104+
hybrid_search_config=hybrid_search_config,
101105
)
102106
vs = await engine._run_as_async(coro)
103107
return cls(cls.__create_key, engine, vs)
@@ -120,6 +124,7 @@ def create_sync(
120124
fetch_k: int = 20,
121125
lambda_mult: float = 0.5,
122126
index_query_options: Optional[QueryOptions] = None,
127+
hybrid_search_config: Optional[HybridSearchConfig] = None,
123128
) -> PGVectorStore:
124129
"""Create a PGVectorStore instance.
125130
@@ -140,6 +145,7 @@ def create_sync(
140145
fetch_k (int, optional): Number of Documents to fetch to pass to MMR algorithm. Defaults to 20.
141146
lambda_mult (float, optional): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
142147
index_query_options (Optional[QueryOptions], optional): Index query option. Defaults to None.
148+
hybrid_search_config (Optional[HybridSearchConfig], optional): Hybrid search configuration. Defaults to None.
143149
144150
Returns:
145151
PGVectorStore
@@ -160,6 +166,7 @@ def create_sync(
160166
fetch_k=fetch_k,
161167
lambda_mult=lambda_mult,
162168
index_query_options=index_query_options,
169+
hybrid_search_config=hybrid_search_config,
163170
)
164171
vs = engine._run_as_sync(coro)
165172
return cls(cls.__create_key, engine, vs)
@@ -301,6 +308,7 @@ async def afrom_texts( # type: ignore[override]
301308
fetch_k: int = 20,
302309
lambda_mult: float = 0.5,
303310
index_query_options: Optional[QueryOptions] = None,
311+
hybrid_search_config: Optional[HybridSearchConfig] = None,
304312
**kwargs: Any,
305313
) -> PGVectorStore:
306314
"""Create a PGVectorStore instance from texts.
@@ -324,6 +332,7 @@ async def afrom_texts( # type: ignore[override]
324332
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
325333
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
326334
index_query_options (QueryOptions): Index query option.
335+
hybrid_search_config (Optional[HybridSearchConfig]): Hybrid search configuration. Defaults to None.
327336
328337
Raises:
329338
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
@@ -347,6 +356,7 @@ async def afrom_texts( # type: ignore[override]
347356
fetch_k=fetch_k,
348357
lambda_mult=lambda_mult,
349358
index_query_options=index_query_options,
359+
hybrid_search_config=hybrid_search_config,
350360
)
351361
await vs.aadd_texts(texts, metadatas=metadatas, ids=ids)
352362
return vs
@@ -371,6 +381,7 @@ async def afrom_documents( # type: ignore[override]
371381
fetch_k: int = 20,
372382
lambda_mult: float = 0.5,
373383
index_query_options: Optional[QueryOptions] = None,
384+
hybrid_search_config: Optional[HybridSearchConfig] = None,
374385
**kwargs: Any,
375386
) -> PGVectorStore:
376387
"""Create a PGVectorStore instance from documents.
@@ -393,6 +404,7 @@ async def afrom_documents( # type: ignore[override]
393404
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
394405
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
395406
index_query_options (QueryOptions): Index query option.
407+
hybrid_search_config (Optional[HybridSearchConfig]): Hybrid search configuration. Defaults to None.
396408
397409
Raises:
398410
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
@@ -417,6 +429,7 @@ async def afrom_documents( # type: ignore[override]
417429
fetch_k=fetch_k,
418430
lambda_mult=lambda_mult,
419431
index_query_options=index_query_options,
432+
hybrid_search_config=hybrid_search_config,
420433
)
421434
await vs.aadd_documents(documents, ids=ids)
422435
return vs
@@ -442,6 +455,7 @@ def from_texts( # type: ignore[override]
442455
fetch_k: int = 20,
443456
lambda_mult: float = 0.5,
444457
index_query_options: Optional[QueryOptions] = None,
458+
hybrid_search_config: Optional[HybridSearchConfig] = None,
445459
**kwargs: Any,
446460
) -> PGVectorStore:
447461
"""Create a PGVectorStore instance from texts.
@@ -465,6 +479,7 @@ def from_texts( # type: ignore[override]
465479
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
466480
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
467481
index_query_options (QueryOptions): Index query option.
482+
hybrid_search_config (Optional[HybridSearchConfig]): Hybrid search configuration. Defaults to None.
468483
469484
Raises:
470485
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
@@ -488,6 +503,7 @@ def from_texts( # type: ignore[override]
488503
fetch_k=fetch_k,
489504
lambda_mult=lambda_mult,
490505
index_query_options=index_query_options,
506+
hybrid_search_config=hybrid_search_config,
491507
**kwargs,
492508
)
493509
vs.add_texts(texts, metadatas=metadatas, ids=ids)
@@ -513,6 +529,7 @@ def from_documents( # type: ignore[override]
513529
fetch_k: int = 20,
514530
lambda_mult: float = 0.5,
515531
index_query_options: Optional[QueryOptions] = None,
532+
hybrid_search_config: Optional[HybridSearchConfig] = None,
516533
**kwargs: Any,
517534
) -> PGVectorStore:
518535
"""Create a PGVectorStore instance from documents.
@@ -535,6 +552,7 @@ def from_documents( # type: ignore[override]
535552
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
536553
lambda_mult (float): Number between 0 and 1 that determines the degree of diversity among the results with 0 corresponding to maximum diversity and 1 to minimum diversity. Defaults to 0.5.
537554
index_query_options (QueryOptions): Index query option.
555+
hybrid_search_config (Optional[HybridSearchConfig]): Hybrid search configuration. Defaults to None.
538556
539557
Raises:
540558
:class:`InvalidTextRepresentationError <asyncpg.exceptions.InvalidTextRepresentationError>`: if the `ids` data type does not match that of the `id_column`.
@@ -558,6 +576,7 @@ def from_documents( # type: ignore[override]
558576
fetch_k=fetch_k,
559577
lambda_mult=lambda_mult,
560578
index_query_options=index_query_options,
579+
hybrid_search_config=hybrid_search_config,
561580
**kwargs,
562581
)
563582
vs.add_documents(documents, ids=ids)

tests/unit_tests/v2/test_pg_vectorstore_search.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
from sqlalchemy import text
1010

1111
from langchain_postgres import Column, PGEngine, PGVectorStore
12+
from langchain_postgres.v2.hybrid_search_config import (
13+
HybridSearchConfig,
14+
reciprocal_rank_fusion,
15+
weighted_sum_ranking,
16+
)
1217
from langchain_postgres.v2.indexes import DistanceStrategy, HNSWQueryOptions
1318
from tests.unit_tests.fixtures.metadata_filtering_data import (
1419
FILTERING_TEST_CASES,
@@ -261,6 +266,37 @@ async def test_vectorstore_with_metadata_filters(
261266
)
262267
assert [doc.metadata["code"] for doc in docs] == expected_ids, test_filter
263268

269+
async def test_asimilarity_hybrid_search(self, vs: PGVectorStore):
270+
results = await vs.asimilarity_search(
271+
"foo", k=1, hybrid_search_config=HybridSearchConfig()
272+
)
273+
assert len(results) == 1
274+
assert results == [Document(page_content="foo", id=ids[0])]
275+
276+
results = await vs.asimilarity_search(
277+
"bar",
278+
k=1,
279+
hybrid_search_config=HybridSearchConfig(),
280+
)
281+
assert results[0] == Document(page_content="bar", id=ids[1])
282+
283+
results = await vs.asimilarity_search(
284+
"foo",
285+
k=1,
286+
filter={"content": {"$ne": "baz"}},
287+
hybrid_search_config=HybridSearchConfig(
288+
fusion_function=weighted_sum_ranking,
289+
fusion_function_parameters={
290+
"primary_results_weight": 0.1,
291+
"secondary_results_weight": 0.9,
292+
"fetch_top_k": 10,
293+
},
294+
primary_top_k=1,
295+
secondary_top_k=1,
296+
),
297+
)
298+
assert results == [Document(page_content="foo", id=ids[0])]
299+
264300

265301
@pytest.mark.enable_socket
266302
class TestVectorStoreSearchSync:
@@ -398,4 +434,30 @@ def test_metadata_filter_negative_tests(
398434
self, vs_custom_filter_sync: PGVectorStore, test_filter: dict
399435
) -> None:
400436
with pytest.raises((ValueError, NotImplementedError)):
401-
vs_custom_filter_sync.similarity_search("meow", k=5, filter=test_filter)
437+
docs = vs_custom_filter_sync.similarity_search(
438+
"meow", k=5, filter=test_filter
439+
)
440+
441+
def test_similarity_hybrid_search(self, vs_custom):
442+
results = vs_custom.similarity_search(
443+
"foo", k=1, hybrid_search_config=HybridSearchConfig()
444+
)
445+
assert len(results) == 1
446+
assert results == [Document(page_content="foo", id=ids[0])]
447+
448+
results = vs_custom.similarity_search(
449+
"bar",
450+
k=1,
451+
hybrid_search_config=HybridSearchConfig(),
452+
)
453+
assert results == [Document(page_content="bar", id=ids[1])]
454+
455+
results = vs_custom.similarity_search(
456+
"foo",
457+
k=1,
458+
filter={"mycontent": {"$ne": "baz"}},
459+
hybrid_search_config=HybridSearchConfig(
460+
fusion_function=reciprocal_rank_fusion
461+
),
462+
)
463+
assert results == [Document(page_content="foo", id=ids[0])]

0 commit comments

Comments
 (0)