-from typing import Optional, Union, List, Dict, Awaitable
+from typing import Optional, Union, List, Dict, Awaitable, Any

 from tqdm.asyncio import tqdm
 import asyncio
@@ -91,14 +91,14 @@ async def upsert(
         max_concurrent_requests: Optional[int] = None,
         semaphore: Optional[asyncio.Semaphore] = None,
         **kwargs,
-    ) -> Awaitable[UpsertResponse]:
+    ) -> UpsertResponse:
         timeout = kwargs.pop("timeout", None)
         vectors = list(map(VectorFactoryGRPC.build, vectors))
         semaphore = self._get_semaphore(max_concurrent_requests, semaphore)

         if batch_size is None:
             return await self._upsert_batch(
-                vectors, namespace, timeout=timeout, semaphore=semaphore, **kwargs
+                vectors=vectors, namespace=namespace, timeout=timeout, semaphore=semaphore, **kwargs
             )

         if not isinstance(batch_size, int) or batch_size <= 0:
@@ -132,7 +132,7 @@ async def _upsert_batch(
         namespace: Optional[str],
         timeout: Optional[int] = None,
         **kwargs,
-    ) -> Awaitable[UpsertResponse]:
+    ) -> UpsertResponse:
         args_dict = parse_non_empty_args([("namespace", namespace)])
         request = UpsertRequest(vectors=vectors, **args_dict)
         return await self.runner.run_asyncio(
@@ -151,7 +151,7 @@ async def _query(
         sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None,
         semaphore: Optional[asyncio.Semaphore] = None,
         **kwargs,
-    ) -> Awaitable[Dict]:
+    ) -> dict[str, Any]:
         if vector is not None and id is not None:
             raise ValueError("Cannot specify both `id` and `vector`")

@@ -182,7 +182,8 @@ async def _query(
         response = await self.runner.run_asyncio(
             self.stub.Query, request, timeout=timeout, semaphore=semaphore
         )
-        return json_format.MessageToDict(response)
+        parsed = json_format.MessageToDict(response)
+        return parsed

     async def query(
         self,
@@ -196,7 +197,7 @@ async def query(
         sparse_vector: Optional[Union[GRPCSparseValues, SparseVectorTypedDict]] = None,
         semaphore: Optional[asyncio.Semaphore] = None,
         **kwargs,
-    ) -> Awaitable[QueryResponse]:
+    ) -> QueryResponse:
         """
         The Query operation searches a namespace, using a query vector.
         It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
@@ -257,9 +258,9 @@ async def query(

     async def composite_query(
         self,
-        vector: Optional[List[float]] = None,
-        namespaces: Optional[List[str]] = None,
-        top_k: Optional[int] = 10,
+        vector: List[float],
+        namespaces: List[str],
+        top_k: Optional[int] = None,
         filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
         include_values: Optional[bool] = None,
         include_metadata: Optional[bool] = None,
@@ -268,17 +269,23 @@ async def composite_query(
         max_concurrent_requests: Optional[int] = None,
         semaphore: Optional[asyncio.Semaphore] = None,
         **kwargs,
-    ) -> Awaitable[CompositeQueryResults]:
+    ) -> CompositeQueryResults:
         aggregator_lock = asyncio.Lock()
         semaphore = self._get_semaphore(max_concurrent_requests, semaphore)

-        # The caller may only want the topK=1 result across all queries,
+        if len(namespaces) == 0:
+            raise ValueError("At least one namespace must be specified")
+        if len(vector) == 0:
+            raise ValueError("Query vector must not be empty")
+
+        # The caller may only want the top_k=1 result across all queries,
         # but we need to get at least 2 results from each query in order to
         # aggregate them correctly. So we'll temporarily set topK to 2 for the
         # subqueries, and then we'll take the topK=1 results from the aggregated
         # results.
-        aggregator = QueryResultsAggregator(top_k=top_k)
-        subquery_topk = top_k if top_k > 2 else 2
+        overall_topk = top_k if top_k is not None else 10
+        aggregator = QueryResultsAggregator(top_k=overall_topk)
+        subquery_topk = overall_topk if overall_topk > 2 else 2

         target_namespaces = set(namespaces)  # dedup namespaces
         query_tasks = [
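
A note on the return annotations changed throughout this diff: for an async def, the declared return type is the value the caller gets after awaiting, so the old Awaitable[UpsertResponse], Awaitable[QueryResponse], and Awaitable[CompositeQueryResults] annotations described the awaited result as yet another awaitable. A minimal standalone sketch of that convention, using a stand-in dataclass rather than the real gRPC response type:

import asyncio
from dataclasses import dataclass


@dataclass
class FakeUpsertResponse:  # stand-in for the real UpsertResponse message
    upserted_count: int


async def upsert_sketch(count: int) -> FakeUpsertResponse:
    # An async def annotated `-> FakeUpsertResponse` is already an awaitable
    # of FakeUpsertResponse from the caller's side; `await` yields the value.
    return FakeUpsertResponse(upserted_count=count)


async def main() -> None:
    response = await upsert_sketch(3)  # typed FakeUpsertResponse, not Awaitable[...]
    print(response.upserted_count)


if __name__ == "__main__":
    asyncio.run(main())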
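
The composite_query changes make vector and namespaces required, reject empty inputs, and move the default top_k of 10 out of the signature. A sketch of that same validation and defaulting in isolation; resolve_topk is a hypothetical helper for illustration, not part of the client:

from typing import List, Optional, Tuple


def resolve_topk(
    top_k: Optional[int], vector: List[float], namespaces: List[str]
) -> Tuple[int, int]:
    # Mirrors the validation and top_k handling added in this diff.
    if len(namespaces) == 0:
        raise ValueError("At least one namespace must be specified")
    if len(vector) == 0:
        raise ValueError("Query vector must not be empty")

    overall_topk = top_k if top_k is not None else 10
    # Each per-namespace subquery requests at least 2 results so the
    # aggregator can merge them correctly, even when the caller wants top_k=1.
    subquery_topk = overall_topk if overall_topk > 2 else 2
    return overall_topk, subquery_topk


assert resolve_topk(1, [0.1, 0.2], ["ns1", "ns2"]) == (1, 2)
assert resolve_topk(None, [0.1, 0.2], ["ns1"]) == (10, 10)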