Skip to content

Commit 77a731b

Browse files
committed
Add early return logic in results merge
1 parent ae244a2 commit 77a731b

File tree

2 files changed

+150
-2
lines changed

2 files changed

+150
-2
lines changed

pinecone/grpc/query_results_aggregator.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,13 @@ def _process_matches(self, matches, ns, heap_item_fn):
134134
if len(self.heap) < self.top_k:
135135
heapq.heappush(self.heap, heap_item_fn(match, ns))
136136
else:
137+
# Assume we have dotproduct scores sorted in descending order
138+
if self.is_dotproduct and match["score"] < self.heap[0][0]:
139+
# No further matches can improve the top-K heap
140+
break
141+
elif not self.is_dotproduct and match["score"] > -self.heap[0][0]:
142+
# No further matches can improve the top-K heap
143+
break
137144
heapq.heappushpop(self.heap, heap_item_fn(match, ns))
138145

139146
def add_results(self, results: Dict[str, Any]):
@@ -156,9 +163,9 @@ def add_results(self, results: Dict[str, Any]):
156163
self.is_dotproduct = self._is_dotproduct_index(matches)
157164

158165
if self.is_dotproduct:
159-
self._process_matches(matches, ns, self._dotproduct_heap_item)
166+
self._process_matches2(matches, ns, self._dotproduct_heap_item)
160167
else:
161-
self._process_matches(matches, ns, self._non_dotproduct_heap_item)
168+
self._process_matches2(matches, ns, self._non_dotproduct_heap_item)
162169

163170
def get_results(self) -> QueryNamespacesResults:
164171
if self.read:

tests/unit_grpc/test_query_results_aggregator.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
QueryResultsAggregatorInvalidTopKError,
44
QueryResultsAggregregatorNotEnoughResultsError,
55
)
6+
import random
67
import pytest
78

89

@@ -125,6 +126,146 @@ def test_correctly_handles_dotproduct_metric(self):
125126
assert results.matches[1].id == "1" # 0.9
126127
assert results.matches[2].id == "8" # 0.88
127128

129+
def test_still_correct_with_early_return(self):
    """Merge two ascending-score result sets and check the combined top five.

    Both namespaces return five matches sorted by ascending score. With
    ``top_k=5`` the merged result is expected to hold the five lowest
    scores overall, and the tie at 0.1 is expected to list id "1" (added
    first) ahead of id "6".
    """
    aggregator = QueryResultsAggregator(top_k=5)

    # (id, score) pairs per namespace, already in ascending score order.
    ns1_scored_ids = [("1", 0.1), ("2", 0.11), ("3", 0.12), ("4", 0.13), ("5", 0.14)]
    ns2_scored_ids = [("6", 0.10), ("7", 0.101), ("8", 0.12), ("9", 0.13), ("10", 0.14)]

    for namespace, scored_ids in (("ns1", ns1_scored_ids), ("ns2", ns2_scored_ids)):
        aggregator.add_results(
            {
                "matches": [
                    {"id": match_id, "score": score, "values": []}
                    for match_id, score in scored_ids
                ],
                "usage": {"readUnits": 5},
                "namespace": namespace,
            }
        )

    merged = aggregator.get_results()
    assert merged.usage.read_units == 10
    assert len(merged.matches) == 5
    # Order matters: lowest score first, first-added namespace wins the tie.
    assert [match.id for match in merged.matches] == ["1", "6", "7", "2", "3"]
166+
167+
def test_still_correct_with_early_return_generated_nont_dotproduct(self):
    """Check the merged top-k against a brute-force sort for ascending scores.

    Five namespaces each contribute 999 randomly scored matches sorted in
    ascending order (lower score first). The aggregator's output is compared
    with the ``top_k`` lowest scores of the concatenated input, so an early
    exit that drops a qualifying match anywhere in the top-k is detected.
    """
    top_k = 1000
    # Private, seeded generator: the data is identical on every run, so a
    # failure can be reproduced, and the global random state is left alone.
    rng = random.Random(77731)
    aggregator = QueryResultsAggregator(top_k=top_k)

    all_matches = []
    for ns_index in range(5):
        # Same id ranges as before: 1-999, 1001-1999, ..., 4001-4999.
        first_id = ns_index * 1000 + 1
        matches = [
            {"id": f"id{i}", "score": rng.random(), "values": []}
            for i in range(first_id, first_id + 999)
        ]
        matches.sort(key=lambda match: match["score"], reverse=False)
        aggregator.add_results(
            {
                "matches": matches,
                "namespace": f"ns{ns_index + 1}",
                "usage": {"readUnits": 5},
            }
        )
        all_matches.extend(matches)

    # Reference answer: sort everything and keep the best top_k scores.
    all_matches.sort(key=lambda match: match["score"], reverse=False)
    expected_scores = [match["score"] for match in all_matches[:top_k]]

    results = aggregator.get_results()
    assert results.usage.read_units == 25
    assert len(results.matches) == top_k
    assert [match.score for match in results.matches] == expected_scores
217+
218+
def test_still_correct_with_early_return_generated_dotproduct(self):
    """Check the merged top-k against a brute-force sort for descending scores.

    Five namespaces each contribute 999 randomly scored matches sorted in
    descending order (higher score first). The aggregator's output is
    compared with the ``top_k`` highest scores of the concatenated input, so
    an early exit that drops a qualifying match anywhere in the top-k is
    detected.
    """
    top_k = 1000
    # Private, seeded generator: the data is identical on every run, so a
    # failure can be reproduced, and the global random state is left alone.
    rng = random.Random(77732)
    aggregator = QueryResultsAggregator(top_k=top_k)

    all_matches = []
    for ns_index in range(5):
        # Same id ranges as before: 1-999, 1001-1999, ..., 4001-4999.
        first_id = ns_index * 1000 + 1
        matches = [
            {"id": f"id{i}", "score": rng.random(), "values": []}
            for i in range(first_id, first_id + 999)
        ]
        matches.sort(key=lambda match: match["score"], reverse=True)
        aggregator.add_results(
            {
                "matches": matches,
                "namespace": f"ns{ns_index + 1}",
                "usage": {"readUnits": 5},
            }
        )
        all_matches.extend(matches)

    # Reference answer: sort everything and keep the best top_k scores.
    all_matches.sort(key=lambda match: match["score"], reverse=True)
    expected_scores = [match["score"] for match in all_matches[:top_k]]

    results = aggregator.get_results()
    assert results.usage.read_units == 25
    assert len(results.matches) == top_k
    assert [match.score for match in results.matches] == expected_scores
268+
128269

129270
class TestQueryResultsAggregatorOutputUX:
130271
def test_can_interact_with_attributes(self):

0 commit comments

Comments
 (0)