From ee38d95c18a31dcff8a7dc4a737bf9e8f65fd564 Mon Sep 17 00:00:00 2001 From: Jen Hamon Date: Wed, 13 Nov 2024 09:24:05 -0500 Subject: [PATCH 1/2] Fix type issue --- pinecone/grpc/index_grpc.py | 49 ++++++++++++++++++- .../integration/data/test_query_namespaces.py | 4 -- 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/pinecone/grpc/index_grpc.py b/pinecone/grpc/index_grpc.py index eba611b7..5b8f1157 100644 --- a/pinecone/grpc/index_grpc.py +++ b/pinecone/grpc/index_grpc.py @@ -1,9 +1,11 @@ import logging -from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, cast +from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast from google.protobuf import json_format from tqdm.autonotebook import tqdm +from concurrent.futures import as_completed, Future + from .utils import ( dict_to_proto_struct, @@ -35,6 +37,7 @@ SparseValues as GRPCSparseValues, ) from pinecone import Vector as NonGRPCVector +from pinecone.data.query_results_aggregator import QueryNamespacesResults, QueryResultsAggregator from pinecone.core.grpc.protos.vector_service_pb2_grpc import VectorServiceStub from .base import GRPCIndexBase from .future import PineconeGrpcFuture @@ -402,6 +405,50 @@ def query( json_response = json_format.MessageToDict(response) return parse_query_response(json_response, _check_type=False) + def query_namespaces( + self, + vector: List[float], + namespaces: List[str], + top_k: Optional[int] = None, + filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, + include_values: Optional[bool] = None, + include_metadata: Optional[bool] = None, + sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None, + **kwargs, + ) -> QueryNamespacesResults: + if namespaces is None or len(namespaces) == 0: + raise ValueError("At least one namespace must be specified") + if len(vector) == 0: + raise ValueError("Query vector must not be empty") + + overall_topk = top_k if top_k is not None else 10 + aggregator = QueryResultsAggregator(top_k=overall_topk) + + target_namespaces = set(namespaces) # dedup namespaces + futures = [ + self.query( + vector=vector, + namespace=ns, + top_k=overall_topk, + filter=filter, + include_values=include_values, + include_metadata=include_metadata, + sparse_vector=sparse_vector, + async_req=True, + **kwargs, + ) + for ns in target_namespaces + ] + + only_futures = cast(Iterable[Future], futures) + for future in as_completed(only_futures): + response = future.result() + json_result = json_format.MessageToDict(response) + aggregator.add_results(json_result) + + final_results = aggregator.get_results() + return final_results + def update( self, id: str, diff --git a/tests/integration/data/test_query_namespaces.py b/tests/integration/data/test_query_namespaces.py index e52c58b0..414cea69 100644 --- a/tests/integration/data/test_query_namespaces.py +++ b/tests/integration/data/test_query_namespaces.py @@ -1,5 +1,4 @@ import pytest -import os from ..helpers import random_string, poll_stats_for_namespace from pinecone.data.query_results_aggregator import ( QueryResultsAggregatorInvalidTopKError, @@ -9,9 +8,6 @@ from pinecone import Vector -@pytest.mark.skipif( - os.getenv("USE_GRPC") == "true", reason="query_namespaces currently only available via rest" -) class TestQueryNamespacesRest: def test_query_namespaces(self, idx): ns_prefix = random_string(5) From d01168bd56dee0c235df12cb015d3254ae6001e1 Mon Sep 17 00:00:00 2001 From: Jen Hamon Date: Wed, 13 Nov 2024 10:10:46 -0500 Subject: [PATCH 2/2] Use PoolThreadExecutor --- pinecone/grpc/base.py | 10 ++++++++++ pinecone/grpc/index_grpc.py | 11 +++++------ pinecone/grpc/pinecone.py | 4 +++- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/pinecone/grpc/base.py b/pinecone/grpc/base.py index cc4ca2c6..8964e72d 100644 --- a/pinecone/grpc/base.py +++ b/pinecone/grpc/base.py @@ -10,6 +10,7 @@ from pinecone import Config from .config import GRPCClientConfig from .grpc_runner import GrpcRunner +from concurrent.futures import ThreadPoolExecutor from pinecone_plugin_interface import load_and_install as install_plugins @@ -29,10 +30,12 @@ def __init__( config: Config, channel: Optional[Channel] = None, grpc_config: Optional[GRPCClientConfig] = None, + pool_threads: Optional[int] = None, _endpoint_override: Optional[str] = None, ): self.config = config self.grpc_client_config = grpc_config or GRPCClientConfig() + self.pool_threads = pool_threads self._endpoint_override = _endpoint_override @@ -58,6 +61,13 @@ def stub_openapi_client_builder(*args, **kwargs): except Exception as e: _logger.error(f"Error loading plugins in GRPCIndex: {e}") + @property + def threadpool_executor(self): + if self._pool is None: + pt = self.pool_threads or 10 + self._pool = ThreadPoolExecutor(max_workers=pt) + return self._pool + @property @abstractmethod def stub_class(self): diff --git a/pinecone/grpc/index_grpc.py b/pinecone/grpc/index_grpc.py index 5b8f1157..6791ae68 100644 --- a/pinecone/grpc/index_grpc.py +++ b/pinecone/grpc/index_grpc.py @@ -426,7 +426,8 @@ def query_namespaces( target_namespaces = set(namespaces) # dedup namespaces futures = [ - self.query( + self.threadpool_executor.submit( + self.query, vector=vector, namespace=ns, top_k=overall_topk, @@ -434,17 +435,15 @@ def query_namespaces( include_values=include_values, include_metadata=include_metadata, sparse_vector=sparse_vector, - async_req=True, + async_req=False, **kwargs, ) for ns in target_namespaces ] only_futures = cast(Iterable[Future], futures) - for future in as_completed(only_futures): - response = future.result() - json_result = json_format.MessageToDict(response) - aggregator.add_results(json_result) + for response in as_completed(only_futures): + aggregator.add_results(response.result()) final_results = aggregator.get_results() return final_results diff --git a/pinecone/grpc/pinecone.py b/pinecone/grpc/pinecone.py index af6a8baa..c78481ff 100644 --- a/pinecone/grpc/pinecone.py +++ b/pinecone/grpc/pinecone.py @@ -124,6 +124,8 @@ def Index(self, name: str = "", host: str = "", **kwargs): # Use host if it is provided, otherwise get host from describe_index index_host = host or self.index_host_store.get_host(self.index_api, self.config, name) + pt = kwargs.pop("pool_threads", None) or self.pool_threads + config = ConfigBuilder.build( api_key=self.config.api_key, host=index_host, @@ -131,4 +133,4 @@ def Index(self, name: str = "", host: str = "", **kwargs): proxy_url=self.config.proxy_url, ssl_ca_certs=self.config.ssl_ca_certs, ) - return GRPCIndex(index_name=name, config=config, **kwargs) + return GRPCIndex(index_name=name, config=config, pool_threads=pt, **kwargs)