@@ -418,7 +418,7 @@ def __init__(
418418 limit : Optional [int ] = None ,
419419 page_size : int = DEFAULT_PAGE_SIZE ,
420420 sort_fields : Optional [List [str ]] = None ,
421- return_fields : Optional [List [str ]] = None ,
421+ projected_fields : Optional [List [str ]] = None ,
422422 nocontent : bool = False ,
423423 ):
424424 if not has_redisearch (model .db ()):
@@ -443,10 +443,10 @@ def __init__(
443443 else :
444444 self .sort_fields = []
445445
446- if return_fields :
447- self .return_fields = self .validate_return_fields ( return_fields )
446+ if projected_fields :
447+ self .projected_fields = self .validate_projected_fields ( projected_fields )
448448 else :
449- self .return_fields = []
449+ self .projected_fields = []
450450
451451 self ._expression = None
452452 self ._query : Optional [str ] = None
@@ -505,18 +505,45 @@ def query(self):
505505 if self ._query .startswith ("(" ) or self ._query == "*"
506506 else f"({ self ._query } )"
507507 ) + f"=>[{ self .knn } ]"
508- if self .return_fields :
509- self ._query += f" RETURN { ',' .join (self .return_fields )} "
508+ # RETURN clause should be added to args, not to the query string
510509 return self ._query
511510
512- def validate_return_fields (self , return_fields : List [str ]):
513- for field in return_fields :
514- if field not in self .model .__fields__ : # type: ignore
511+ def validate_projected_fields (self , projected_fields : List [str ]):
512+ for field in projected_fields :
513+ if field not in self .model .model_fields : # type: ignore
515514 raise QueryNotSupportedError (
516515 f"You tried to return the field { field } , but that field "
517516 f"does not exist on the model { self .model } "
518517 )
519- return return_fields
518+ return projected_fields
519+
520+ def _parse_projected_results (self , res : Any ) -> List [Dict [str , Any ]]:
521+ """Parse results when using RETURN clause with specific fields."""
522+
523+ def to_string (s ):
524+ if isinstance (s , (str ,)):
525+ return s
526+ elif isinstance (s , bytes ):
527+ return s .decode (errors = "ignore" )
528+ else :
529+ return s
530+
531+ docs = []
532+ step = 2 # Because the result has content
533+ offset = 1 # The first item is the count of total matches.
534+
535+ for i in range (1 , len (res ), step ):
536+ if res [i + offset ] is None :
537+ continue
538+ # When using RETURN, we get flat key-value pairs
539+ fields : Dict [str , str ] = dict (
540+ zip (
541+ map (to_string , res [i + offset ][::2 ]),
542+ map (to_string , res [i + offset ][1 ::2 ]),
543+ )
544+ )
545+ docs .append (fields )
546+ return docs
520547
521548 @property
522549 def query_params (self ):
@@ -899,6 +926,12 @@ async def execute(
899926 if self .nocontent :
900927 args .append ("NOCONTENT" )
901928
929+ # Add RETURN clause to the args list, not to the query string
930+ if self .projected_fields :
931+ args .extend (
932+ ["RETURN" , str (len (self .projected_fields ))] + self .projected_fields
933+ )
934+
902935 if return_query_args :
903936 return self .model .Meta .index_name , args
904937
@@ -912,7 +945,12 @@ async def execute(
912945 if return_raw_result :
913946 return raw_result
914947 count = raw_result [0 ]
915- results = self .model .from_redis (raw_result , self .knn )
948+
949+ # If we're using field projection, return dictionaries instead of model instances
950+ if self .projected_fields :
951+ results = self ._parse_projected_results (raw_result )
952+ else :
953+ results = self .model .from_redis (raw_result , self .knn )
916954 self ._model_cache += results
917955
918956 if not exhaust_results :
@@ -966,11 +1004,11 @@ def sort_by(self, *fields: str):
9661004 if not fields :
9671005 return self
9681006 return self .copy (sort_fields = list (fields ))
969-
1007+
9701008 def return_fields (self , * fields : str ):
9711009 if not fields :
9721010 return self
973- return self .copy (return_fields = list (fields ))
1011+ return self .copy (projected_fields = list (fields ))
9741012
9751013 async def update (self , use_transaction = True , ** field_values ):
9761014 """
@@ -1546,9 +1584,7 @@ def find(
15461584 * expressions : Union [Any , Expression ],
15471585 knn : Optional [KNNExpression ] = None ,
15481586 ) -> FindQuery :
1549- return FindQuery (
1550- expressions = expressions , knn = knn , model = cls
1551- )
1587+ return FindQuery (expressions = expressions , knn = knn , model = cls )
15521588
15531589 @classmethod
15541590 def from_redis (cls , res : Any , knn : Optional [KNNExpression ] = None ):
0 commit comments