11import itertools
22import time
3- from typing import Dict , Union
3+ from typing import Dict , Optional , Union
44
55from redis .client import Pipeline
66
@@ -363,7 +363,11 @@ def info(self):
363363 it = map (to_string , res )
364364 return dict (zip (it , it ))
365365
366- def get_params_args (self , query_params : Dict [str , Union [str , int , float ]]):
366+ def get_params_args (
367+ self , query_params : Union [Dict [str , Union [str , int , float ]], None ]
368+ ):
369+ if query_params is None :
370+ return []
367371 args = []
368372 if len (query_params ) > 0 :
369373 args .append ("params" )
@@ -383,8 +387,7 @@ def _mk_query_args(self, query, query_params: Dict[str, Union[str, int, float]])
383387 raise ValueError (f"Bad query type { type (query )} " )
384388
385389 args += query .get_args ()
386- if query_params is not None :
387- args += self .get_params_args (query_params )
390+ args += self .get_params_args (query_params )
388391
389392 return args , query
390393
@@ -459,8 +462,7 @@ def aggregate(
459462 cmd = [CURSOR_CMD , "READ" , self .index_name ] + query .build_args ()
460463 else :
461464 raise ValueError ("Bad query" , query )
462- if query_params is not None :
463- cmd += self .get_params_args (query_params )
465+ cmd += self .get_params_args (query_params )
464466
465467 raw = self .execute_command (* cmd )
466468 return self ._get_aggregate_result (raw , query , has_cursor )
@@ -485,16 +487,22 @@ def _get_aggregate_result(self, raw, query, has_cursor):
485487
486488 return AggregateResult (rows , cursor , schema )
487489
488- def profile (self , query , limited = False ):
490+ def profile (
491+ self ,
492+ query : Union [str , Query , AggregateRequest ],
493+ limited : bool = False ,
494+ query_params : Optional [Dict [str , Union [str , int , float ]]] = None ,
495+ ):
489496 """
490497 Performs a search or aggregate command and collects performance
491498 information.
492499
493500 ### Parameters
494501
495- **query**: This can be either an `AggregateRequest`, `Query` or
496- string.
502+ **query**: This can be either an `AggregateRequest`, `Query` or string.
497503 **limited**: If set to True, removes details of reader iterator.
504+ **query_params**: Define one or more value parameters.
505+ Each parameter has a name and a value.
498506
499507 """
500508 st = time .time ()
@@ -509,6 +517,7 @@ def profile(self, query, limited=False):
509517 elif isinstance (query , Query ):
510518 cmd [2 ] = "SEARCH"
511519 cmd += query .get_args ()
520+ cmd += self .get_params_args (query_params )
512521 else :
513522 raise ValueError ("Must provide AggregateRequest object or " "Query object." )
514523
@@ -907,8 +916,7 @@ async def aggregate(
907916 cmd = [CURSOR_CMD , "READ" , self .index_name ] + query .build_args ()
908917 else :
909918 raise ValueError ("Bad query" , query )
910- if query_params is not None :
911- cmd += self .get_params_args (query_params )
919+ cmd += self .get_params_args (query_params )
912920
913921 raw = await self .execute_command (* cmd )
914922 return self ._get_aggregate_result (raw , query , has_cursor )
0 commit comments