@@ -252,6 +252,24 @@ def tree(self):
252252 return render_tree (self )
253253
254254
@dataclasses.dataclass
class KNNExpression:
    """A K-nearest-neighbours clause that can be attached to a search query.

    The expression itself only carries placeholders (``$K`` and
    ``$knn_ref_vector``); the concrete values are supplied separately
    through :attr:`query_params`.
    """

    # Number of neighbours to request; sent as the ``K`` query parameter.
    k: int
    # Model field holding the indexed vector; only its ``name`` is read here.
    vector_field: ModelField
    # Reference vector passed through unchanged as ``knn_ref_vector``.
    reference_vector: bytes

    def __str__(self):
        """Return the parameterised KNN clause for the vector field."""
        return "KNN $K @{} $knn_ref_vector".format(self.vector_field.name)

    @property
    def query_params(self) -> Dict[str, Union[str, bytes]]:
        """Map each placeholder used by ``__str__`` to its concrete value."""
        return dict(K=str(self.k), knn_ref_vector=self.reference_vector)

    @property
    def score_field(self) -> str:
        """Name of the result field derived from the vector field's name."""
        return "__{}_score".format(self.vector_field.name)
272+
255273ExpressionOrNegated = Union [Expression , NegatedExpression ]
256274
257275
@@ -349,8 +367,9 @@ def __init__(
349367 self ,
350368 expressions : Sequence [ExpressionOrNegated ],
351369 model : Type ["RedisModel" ],
370+ knn : Optional [KNNExpression ] = None ,
352371 offset : int = 0 ,
353- limit : int = DEFAULT_PAGE_SIZE ,
372+ limit : Optional [ int ] = None ,
354373 page_size : int = DEFAULT_PAGE_SIZE ,
355374 sort_fields : Optional [List [str ]] = None ,
356375 nocontent : bool = False ,
@@ -364,13 +383,16 @@ def __init__(
364383
365384 self .expressions = expressions
366385 self .model = model
386+ self .knn = knn
367387 self .offset = offset
368- self .limit = limit
388+ self .limit = limit or ( self . knn . k if self . knn else DEFAULT_PAGE_SIZE )
369389 self .page_size = page_size
370390 self .nocontent = nocontent
371391
372392 if sort_fields :
373393 self .sort_fields = self .validate_sort_fields (sort_fields )
394+ elif self .knn :
395+ self .sort_fields = [self .knn .score_field ]
374396 else :
375397 self .sort_fields = []
376398
@@ -425,11 +447,26 @@ def query(self):
425447 if self ._query :
426448 return self ._query
427449 self ._query = self .resolve_redisearch_query (self .expression )
450+ if self .knn :
451+ self ._query = (
452+ self ._query
453+ if self ._query .startswith ("(" ) or self ._query == "*"
454+ else f"({ self ._query } )"
455+ ) + f"=>[{ self .knn } ]"
428456 return self ._query
429457
@property
def query_params(self):
    """Return the KNN parameters flattened to ``[name, value, name, value, ...]``.

    A fresh list is built on every access. It is empty when no KNN
    expression is attached to the query.
    """
    flattened: List[Union[str, bytes]] = []
    if not self.knn:
        return flattened
    # Keep each name immediately followed by its value, in dict order.
    for name, value in self.knn.query_params.items():
        flattened.extend((name, value))
    return flattened
464+
430465 def validate_sort_fields (self , sort_fields : List [str ]):
431466 for sort_field in sort_fields :
432467 field_name = sort_field .lstrip ("-" )
468+ if self .knn and field_name == self .knn .score_field :
469+ continue
433470 if field_name not in self .model .__fields__ :
434471 raise QueryNotSupportedError (
435472 f"You tried sort by { field_name } , but that field "
@@ -728,10 +765,27 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
728765 return result
729766
730767 async def execute (self , exhaust_results = True , return_raw_result = False ):
731- args = ["ft.search" , self .model .Meta .index_name , self .query , * self .pagination ]
768+ args : List [Union [str , bytes ]] = [
769+ "FT.SEARCH" ,
770+ self .model .Meta .index_name ,
771+ self .query ,
772+ * self .pagination ,
773+ ]
732774 if self .sort_fields :
733775 args += self .resolve_redisearch_sort_fields ()
734776
777+ if self .query_params :
778+ args += ["PARAMS" , str (len (self .query_params ))] + self .query_params
779+
780+ if self .knn :
781+ # Ensure DIALECT is at least 2
782+ if "DIALECT" not in args :
783+ args += ["DIALECT" , "2" ]
784+ else :
785+ i_dialect = args .index ("DIALECT" ) + 1
786+ if int (args [i_dialect ]) < 2 :
787+ args [i_dialect ] = "2"
788+
735789 if self .nocontent :
736790 args .append ("NOCONTENT" )
737791
@@ -917,11 +971,13 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
917971 sortable = kwargs .pop ("sortable" , Undefined )
918972 index = kwargs .pop ("index" , Undefined )
919973 full_text_search = kwargs .pop ("full_text_search" , Undefined )
974+ vector_options = kwargs .pop ("vector_options" , None )
920975 super ().__init__ (default = default , ** kwargs )
921976 self .primary_key = primary_key
922977 self .sortable = sortable
923978 self .index = index
924979 self .full_text_search = full_text_search
980+ self .vector_options = vector_options
925981
926982
927983class RelationshipInfo (Representation ):
@@ -935,6 +991,94 @@ def __init__(
935991 self .link_model = link_model
936992
937993
@dataclasses.dataclass
class VectorFieldOptions:
    """Index options for a vector field, rendered as a ``VECTOR`` schema clause.

    Instances are normally created through :meth:`flat` or :meth:`hnsw`,
    which fix the algorithm and expose only the parameters relevant to it.
    Optional parameters left as ``None`` are omitted from :attr:`schema`.
    """

    class ALGORITHM(Enum):
        FLAT = "FLAT"
        HNSW = "HNSW"

    class TYPE(Enum):
        FLOAT32 = "FLOAT32"
        FLOAT64 = "FLOAT64"

    class DISTANCE_METRIC(Enum):
        L2 = "L2"
        IP = "IP"
        COSINE = "COSINE"

    algorithm: ALGORITHM
    type: TYPE
    dimension: int
    distance_metric: DISTANCE_METRIC

    # Common optional parameters
    initial_cap: Optional[int] = None

    # Optional parameters for FLAT
    block_size: Optional[int] = None

    # Optional parameters for HNSW
    m: Optional[int] = None
    ef_construction: Optional[int] = None
    ef_runtime: Optional[int] = None
    epsilon: Optional[float] = None

    @classmethod
    def flat(
        cls,
        type: TYPE,
        dimension: int,
        distance_metric: DISTANCE_METRIC,
        initial_cap: Optional[int] = None,
        block_size: Optional[int] = None,
    ) -> "VectorFieldOptions":
        """Build options for the FLAT algorithm.

        A classmethod so that subclasses get instances of their own type.
        """
        return cls(
            algorithm=cls.ALGORITHM.FLAT,
            type=type,
            dimension=dimension,
            distance_metric=distance_metric,
            initial_cap=initial_cap,
            block_size=block_size,
        )

    @classmethod
    def hnsw(
        cls,
        type: TYPE,
        dimension: int,
        distance_metric: DISTANCE_METRIC,
        initial_cap: Optional[int] = None,
        m: Optional[int] = None,
        ef_construction: Optional[int] = None,
        ef_runtime: Optional[int] = None,
        epsilon: Optional[float] = None,
    ) -> "VectorFieldOptions":
        """Build options for the HNSW algorithm.

        A classmethod so that subclasses get instances of their own type.
        """
        return cls(
            algorithm=cls.ALGORITHM.HNSW,
            type=type,
            dimension=dimension,
            distance_metric=distance_metric,
            initial_cap=initial_cap,
            m=m,
            ef_construction=ef_construction,
            ef_runtime=ef_runtime,
            epsilon=epsilon,
        )

    @property
    def schema(self) -> str:
        """Render ``VECTOR <algorithm> <count> <NAME> <value> ...``.

        ``<count>`` is the number of tokens that follow it (names and
        values together). Only declared dataclass fields are emitted, so
        ad-hoc attributes set on the instance can never leak into the
        schema or skew the count.
        """
        tokens = []
        for spec in dataclasses.fields(self):
            if spec.name == "algorithm":
                # The algorithm is emitted in the clause header, not as a pair.
                continue
            value = getattr(self, spec.name)
            if value is None:
                continue
            tokens.append("DIM" if spec.name == "dimension" else spec.name.upper())
            tokens.append(value.name if isinstance(value, Enum) else str(value))

        return " ".join([f"VECTOR {self.algorithm.name} {len(tokens)}", *tokens])
1080+
1081+
9381082def Field (
9391083 default : Any = Undefined ,
9401084 * ,
@@ -964,6 +1108,7 @@ def Field(
9641108 sortable : Union [bool , UndefinedType ] = Undefined ,
9651109 index : Union [bool , UndefinedType ] = Undefined ,
9661110 full_text_search : Union [bool , UndefinedType ] = Undefined ,
1111+ vector_options : Optional [VectorFieldOptions ] = None ,
9671112 schema_extra : Optional [Dict [str , Any ]] = None ,
9681113) -> Any :
9691114 current_schema_extra = schema_extra or {}
@@ -991,6 +1136,7 @@ def Field(
9911136 sortable = sortable ,
9921137 index = index ,
9931138 full_text_search = full_text_search ,
1139+ vector_options = vector_options ,
9941140 ** current_schema_extra ,
9951141 )
9961142 field_info ._validate ()
@@ -1083,6 +1229,10 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
10831229 new_class ._meta .primary_key = PrimaryKey (
10841230 name = field_name , field = field
10851231 )
1232+ if field .field_info .vector_options :
1233+ score_attr = f"_{ field_name } _score"
1234+ setattr (new_class , score_attr , None )
1235+ new_class .__annotations__ [score_attr ] = Union [float , None ]
10861236
10871237 if not getattr (new_class ._meta , "global_key_prefix" , None ):
10881238 new_class ._meta .global_key_prefix = getattr (
@@ -1216,8 +1366,12 @@ def db(cls):
12161366 return cls ._meta .database
12171367
@classmethod
def find(
    cls,
    *expressions: Union[Any, Expression],
    knn: Optional[KNNExpression] = None,
) -> FindQuery:
    """Create a ``FindQuery`` over this model.

    Args:
        *expressions: Filter expressions forwarded to the query as a tuple.
        knn: Optional KNN expression forwarded to the query; ``None``
            means no KNN clause.

    Returns:
        A new ``FindQuery`` bound to ``cls``.
    """
    query = FindQuery(expressions=expressions, model=cls, knn=knn)
    return query
12211375
12221376 @classmethod
12231377 def from_redis (cls , res : Any ):
@@ -1237,7 +1391,7 @@ def to_string(s):
12371391 for i in range (1 , len (res ), step ):
12381392 if res [i + offset ] is None :
12391393 continue
1240- fields = dict (
1394+ fields : Dict [ str , str ] = dict (
12411395 zip (
12421396 map (to_string , res [i + offset ][::2 ]),
12431397 map (to_string , res [i + offset ][1 ::2 ]),
@@ -1247,6 +1401,9 @@ def to_string(s):
12471401 if fields .get ("$" ):
12481402 json_fields = json .loads (fields .pop ("$" ))
12491403 doc = cls (** json_fields )
1404+ for k , v in fields .items ():
1405+ if k .startswith ("__" ) and k .endswith ("_score" ):
1406+ setattr (doc , k [1 :], float (v ))
12501407 else :
12511408 doc = cls (** fields )
12521409
@@ -1474,7 +1631,13 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
14741631 embedded_cls = embedded_cls [0 ]
14751632 schema = cls .schema_for_type (name , embedded_cls , field_info )
14761633 elif any (issubclass (typ , t ) for t in NUMERIC_TYPES ):
1477- schema = f"{ name } NUMERIC"
1634+ vector_options : Optional [VectorFieldOptions ] = getattr (
1635+ field_info , "vector_options" , None
1636+ )
1637+ if vector_options :
1638+ schema = f"{ name } { vector_options .schema } "
1639+ else :
1640+ schema = f"{ name } NUMERIC"
14781641 elif issubclass (typ , str ):
14791642 if getattr (field_info , "full_text_search" , False ) is True :
14801643 schema = (
@@ -1623,10 +1786,22 @@ def schema_for_type(
16231786 # Not a class, probably a type annotation
16241787 field_is_model = False
16251788
1789+ vector_options : Optional [VectorFieldOptions ] = getattr (
1790+ field_info , "vector_options" , None
1791+ )
1792+ try :
1793+ is_vector = vector_options and any (
1794+ issubclass (get_args (typ )[0 ], t ) for t in NUMERIC_TYPES
1795+ )
1796+ except IndexError :
1797+ raise RedisModelError (
1798+ f"Vector field '{ name } ' must be annotated as a container type"
1799+ )
1800+
16261801 # When we encounter a list or model field, we need to descend
16271802 # into the values of the list or the fields of the model to
16281803 # find any values marked as indexed.
1629- if is_container_type :
1804+ if is_container_type and not is_vector :
16301805 field_type = get_origin (typ )
16311806 embedded_cls = get_args (typ )
16321807 if not embedded_cls :
@@ -1689,7 +1864,9 @@ def schema_for_type(
16891864 )
16901865
16911866 # TODO: GEO field
1692- if parent_is_container_type or parent_is_model_in_container :
1867+ if is_vector and vector_options :
1868+ schema = f"{ path } AS { index_field_name } { vector_options .schema } "
1869+ elif parent_is_container_type or parent_is_model_in_container :
16931870 if typ is not str :
16941871 raise RedisModelError (
16951872 "In this Preview release, list and tuple fields can only "
0 commit comments