|
1 | 1 | import logging |
2 | | -from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, cast |
| 2 | +from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast |
3 | 3 |
|
4 | 4 | from google.protobuf import json_format |
5 | 5 |
|
6 | 6 | from tqdm.autonotebook import tqdm |
| 7 | +from concurrent.futures import as_completed, Future |
| 8 | + |
7 | 9 |
|
8 | 10 | from .utils import ( |
9 | 11 | dict_to_proto_struct, |
|
35 | 37 | SparseValues as GRPCSparseValues, |
36 | 38 | ) |
37 | 39 | from pinecone import Vector as NonGRPCVector |
| 40 | +from pinecone.data.query_results_aggregator import QueryNamespacesResults, QueryResultsAggregator |
38 | 41 | from pinecone.core.grpc.protos.vector_service_pb2_grpc import VectorServiceStub |
39 | 42 | from .base import GRPCIndexBase |
40 | 43 | from .future import PineconeGrpcFuture |
@@ -402,6 +405,49 @@ def query( |
402 | 405 | json_response = json_format.MessageToDict(response) |
403 | 406 | return parse_query_response(json_response, _check_type=False) |
404 | 407 |
|
def query_namespaces(
    self,
    vector: List[float],
    namespaces: List[str],
    top_k: Optional[int] = None,
    filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
    include_values: Optional[bool] = None,
    include_metadata: Optional[bool] = None,
    sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None,
    **kwargs,
) -> QueryNamespacesResults:
    """Run one query against several namespaces and combine the responses.

    A ``self.query`` call is submitted to ``self.threadpool_executor`` for
    every distinct entry in ``namespaces``; each response is handed to a
    ``QueryResultsAggregator`` as soon as its future completes.

    Args:
        vector: Dense query vector; must contain at least one element.
        namespaces: Namespaces to search. Duplicates are queried only once.
        top_k: Value passed both to each per-namespace query and to the
            aggregator. Falls back to 10 when ``None``.
        filter: Forwarded unchanged to ``self.query``.
        include_values: Forwarded unchanged to ``self.query``.
        include_metadata: Forwarded unchanged to ``self.query``.
        sparse_vector: Forwarded unchanged to ``self.query``.
        **kwargs: Extra keyword arguments forwarded to ``self.query``.

    Returns:
        Whatever ``QueryResultsAggregator.get_results()`` yields once every
        per-namespace response has been added.

    Raises:
        ValueError: If ``namespaces`` is ``None`` or empty, or if ``vector``
            is empty.
    """
    if namespaces is None or len(namespaces) < 1:
        raise ValueError("At least one namespace must be specified")
    if len(vector) < 1:
        raise ValueError("Query vector must not be empty")

    # The same limit is used for each namespace query and for the merge.
    limit = 10 if top_k is None else top_k
    merged = QueryResultsAggregator(top_k=limit)

    pending = []
    # A set collapses repeated namespaces so each is queried a single time.
    for namespace_name in set(namespaces):
        pending.append(
            self.threadpool_executor.submit(
                self.query,
                vector=vector,
                namespace=namespace_name,
                top_k=limit,
                filter=filter,
                include_values=include_values,
                include_metadata=include_metadata,
                sparse_vector=sparse_vector,
                # Synchronous call inside the worker thread, so the
                # executor future resolves to the query response itself.
                async_req=False,
                **kwargs,
            )
        )

    # Consume in completion order; result() re-raises any exception the
    # underlying query raised.
    for finished in as_completed(cast(Iterable[Future], pending)):
        merged.add_results(finished.result())

    return merged.get_results()
| 450 | + |
405 | 451 | def update( |
406 | 452 | self, |
407 | 453 | id: str, |
|
0 commit comments