@@ -81,39 +81,85 @@ def __repr__(self):
8181 )
8282
8383
class QueryResultsAggregationEmptyResultsError(Exception):
    """Raised when a query response has no matches to infer the metric type from."""

    def __init__(self, namespace: str):
        # The namespace is echoed back because a misspelled namespace is the
        # cause this message suggests checking first.
        message = (
            "Cannot infer metric type from empty query results. "
            f"Query result for namespace '{namespace}' is empty. "
            "Have you spelled the namespace name correctly?"
        )
        super().__init__(message)
89+
90+
class QueryResultsAggregregatorNotEnoughResultsError(Exception):
    """Raised when there are too few matches to infer the similarity metric.

    Args:
        top_k: The top_k value reported in the message.
        num_results: The number of matches that were actually received.
    """

    def __init__(self, top_k: int, num_results: int):
        message = (
            "Cannot interpret results without at least two matches. "
            "In order to aggregate results from multiple queries, top_k must be "
            "greater than 1 in order to correctly infer the similarity metric "
            "from scores. "
            f"Expected at least {top_k} results but got {num_results}."
        )
        super().__init__(message)
96+
97+
class QueryResultsAggregatorInvalidTopKError(Exception):
    """Raised when an aggregator is given a top_k that is not a positive integer."""

    def __init__(self, top_k: int):
        message = f"Invalid top_k value {top_k}. top_k must be a positive integer."
        super().__init__(message)
101+
102+
class QueryResultsAggregator:
    """Merge the matches of several query responses into a single top_k list.

    Call ``add_results`` once per response, then ``get_results`` to obtain the
    combined result. For indexes that are not inferred to be dotproduct, the
    ``top_k`` matches with the lowest scores are kept and returned in ascending
    score order. Aggregation for dotproduct indexes is not implemented yet.
    """

    def __init__(self, top_k: int):
        """Create an empty aggregator.

        Args:
            top_k: Maximum number of matches to keep across all added results.

        Raises:
            QueryResultsAggregatorInvalidTopKError: If ``top_k`` is less than 1.
        """
        if top_k < 1:
            raise QueryResultsAggregatorInvalidTopKError(top_k)
        self.top_k = top_k
        self.usage_read_units = 0
        self.heap = []
        self.insertion_counter = 0
        # None until the first call to add_results infers the metric type.
        self.is_dotproduct = None
        self.read = False
        # Populated by the first call to get_results and returned by later calls.
        self.final_results = None

    def __is_dotproduct_index(self, matches):
        """Return True when the scores in ``matches`` never increase.

        The interpretation of the score depends on the similarity metric used.
        Unlike other index types, in indexes configured for dotproduct a
        higher score is better. We have to infer this is the case by
        inspecting the order of the scores in the results.
        """
        scores = [match.get("score") for match in matches]
        # Any increase between neighbouring scores means "not dotproduct".
        return not any(curr > prev for prev, curr in zip(scores, scores[1:]))

    def add_results(self, results: QueryResponse):
        """Fold the matches and read-unit usage of one query response into the aggregate.

        Args:
            results: A query response read through ``.get`` for the keys
                "matches", "namespace" and "usage" (with nested "readUnits").

        Raises:
            ValueError: If ``get_results`` has already been called.
            QueryResultsAggregationEmptyResultsError: If the metric type has not
                been inferred yet and ``results`` has no matches.
            QueryResultsAggregregatorNotEnoughResultsError: If the metric type
                has not been inferred yet and ``results`` has exactly one match.
            NotImplementedError: If the index was inferred to be dotproduct.
        """
        if self.read:
            # get_results drains the heap while building the ordered results, so
            # anything added afterwards could no longer be merged with the
            # earlier matches.
            raise ValueError("Results have already been read. Cannot add more results.")

        matches = results.get("matches", [])
        ns = results.get("namespace")
        self.usage_read_units += results.get("usage", {}).get("readUnits", 0)

        if self.is_dotproduct is None:
            # The metric type is inferred from score ordering, which needs at
            # least two matches to compare.
            if len(matches) == 0:
                raise QueryResultsAggregationEmptyResultsError(ns)
            if len(matches) == 1:
                raise QueryResultsAggregregatorNotEnoughResultsError(self.top_k, len(matches))
            self.is_dotproduct = self.__is_dotproduct_index(matches)

        if self.is_dotproduct:
            raise NotImplementedError("Dotproduct indexes are not yet supported.")

        for match in matches:
            self.insertion_counter += 1
            # Scores are negated so the min-heap evicts the highest score first.
            # The unique, negated insertion counter breaks ties so the match
            # payloads themselves are never compared.
            entry = (-match.get("score"), -self.insertion_counter, match, ns)
            if len(self.heap) < self.top_k:
                heapq.heappush(self.heap, entry)
            else:
                heapq.heappushpop(self.heap, entry)

    def get_results(self) -> CompositeQueryResults:
        """Return the aggregated results, building them on the first call.

        The first call drains the heap and marks the aggregator as read; later
        calls return the same cached object.
        """
        if self.read:
            return self.final_results
        self.read = True

        # Popping yields the highest remaining score first; reversing the list
        # puts the matches in ascending score order.
        self.final_results = CompositeQueryResults(
            usage=Usage(read_units=self.usage_read_units),
            matches=[
                ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
            ][::-1],
        )
        return self.final_results
0 commit comments