Skip to content

Commit 57ceb2c

Browse files
fix: lint
1 parent 73d4400 commit 57ceb2c

File tree

1 file changed

+37
-30
lines changed

1 file changed

+37
-30
lines changed

tests/unit_tests/v2/test_hybrid_search_config.py

Lines changed: 37 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -21,35 +21,41 @@ def get_row(doc_id: str, score: float, content: str = "content") -> dict:
2121

2222

2323
class TestWeightedSumRanking:
24-
def test_empty_inputs(self):
24+
def test_empty_inputs(self) -> None:
2525
results = weighted_sum_ranking([], [])
2626
assert results == []
2727

28-
def test_primary_only(self):
28+
def test_primary_only(self) -> None:
2929
primary = [get_row("p1", 0.8), get_row("p2", 0.6)]
3030
# Expected scores: p1 = 0.8 * 0.5 = 0.4, p2 = 0.6 * 0.5 = 0.3
31-
results = weighted_sum_ranking(
32-
primary, [], primary_results_weight=0.5, secondary_results_weight=0.5
31+
results = weighted_sum_ranking( # type: ignore
32+
primary, # type: ignore
33+
[],
34+
primary_results_weight=0.5,
35+
secondary_results_weight=0.5,
3336
)
3437
assert len(results) == 2
3538
assert results[0]["id_val"] == "p1"
3639
assert results[0]["distance"] == pytest.approx(0.4)
3740
assert results[1]["id_val"] == "p2"
3841
assert results[1]["distance"] == pytest.approx(0.3)
3942

40-
def test_secondary_only(self):
43+
def test_secondary_only(self) -> None:
4144
secondary = [get_row("s1", 0.9), get_row("s2", 0.7)]
4245
# Expected scores: s1 = 0.9 * 0.5 = 0.45, s2 = 0.7 * 0.5 = 0.35
4346
results = weighted_sum_ranking(
44-
[], secondary, primary_results_weight=0.5, secondary_results_weight=0.5
47+
[],
48+
secondary, # type: ignore
49+
primary_results_weight=0.5,
50+
secondary_results_weight=0.5,
4551
)
4652
assert len(results) == 2
4753
assert results[0]["id_val"] == "s1"
4854
assert results[0]["distance"] == pytest.approx(0.45)
4955
assert results[1]["id_val"] == "s2"
5056
assert results[1]["distance"] == pytest.approx(0.35)
5157

52-
def test_mixed_results_default_weights(self):
58+
def test_mixed_results_default_weights(self) -> None:
5359
primary = [get_row("common", 0.8), get_row("p_only", 0.7)]
5460
secondary = [get_row("common", 0.9), get_row("s_only", 0.6)]
5561
# Weights are 0.5, 0.5
@@ -58,7 +64,7 @@ def test_mixed_results_default_weights(self):
5864
# s_only_score = (0.6 * 0.5) = 0.30
5965
# Order: common (0.85), p_only (0.35), s_only (0.30)
6066

61-
results = weighted_sum_ranking(primary, secondary)
67+
results = weighted_sum_ranking(primary, secondary) # type: ignore
6268
assert len(results) == 3
6369
assert results[0]["id_val"] == "common"
6470
assert results[0]["distance"] == pytest.approx(0.85)
@@ -67,24 +73,26 @@ def test_mixed_results_default_weights(self):
6773
assert results[2]["id_val"] == "s_only"
6874
assert results[2]["distance"] == pytest.approx(0.30)
6975

70-
def test_mixed_results_custom_weights(self):
76+
def test_mixed_results_custom_weights(self) -> None:
7177
primary = [get_row("d1", 1.0)] # p_w=0.2 -> 0.2
7278
secondary = [get_row("d1", 0.5)] # s_w=0.8 -> 0.4
7379
# Expected: d1_score = (1.0 * 0.2) + (0.5 * 0.8) = 0.2 + 0.4 = 0.6
7480

7581
results = weighted_sum_ranking(
76-
primary, secondary, primary_results_weight=0.2, secondary_results_weight=0.8
82+
primary, # type: ignore
83+
secondary, # type: ignore
84+
primary_results_weight=0.2,
85+
secondary_results_weight=0.8,
7786
)
7887
assert len(results) == 1
7988
assert results[0]["id_val"] == "d1"
8089
assert results[0]["distance"] == pytest.approx(0.6)
8190

82-
def test_fetch_top_k(self):
91+
def test_fetch_top_k(self) -> None:
8392
primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)]
8493
# Scores: 1.0, 0.9, 0.8, 0.7, 0.6
8594
# Weighted (0.5): 0.5, 0.45, 0.4, 0.35, 0.3
86-
secondary = []
87-
results = weighted_sum_ranking(primary, secondary, fetch_top_k=2)
95+
results = weighted_sum_ranking(primary, [], fetch_top_k=2) # type: ignore
8896
assert len(results) == 2
8997
assert results[0]["id_val"] == "p0"
9098
assert results[0]["distance"] == pytest.approx(0.5)
@@ -93,46 +101,46 @@ def test_fetch_top_k(self):
93101

94102

95103
class TestReciprocalRankFusion:
96-
def test_empty_inputs(self):
104+
def test_empty_inputs(self) -> None:
97105
results = reciprocal_rank_fusion([], [])
98106
assert results == []
99107

100-
def test_primary_only(self):
108+
def test_primary_only(self) -> None:
101109
primary = [
102110
get_row("p1", 0.8),
103111
get_row("p2", 0.6),
104112
] # p1 rank 0, p2 rank 1
105113
rrf_k = 60
106114
# p1_score = 1 / (0 + 60)
107115
# p2_score = 1 / (1 + 60)
108-
results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k)
116+
results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k) # type: ignore
109117
assert len(results) == 2
110118
assert results[0]["id_val"] == "p1"
111119
assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
112120
assert results[1]["id_val"] == "p2"
113121
assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))
114122

115-
def test_secondary_only(self):
123+
def test_secondary_only(self) -> None:
116124
secondary = [
117125
get_row("s1", 0.9),
118126
get_row("s2", 0.7),
119127
] # s1 rank 0, s2 rank 1
120128
rrf_k = 60
121-
results = reciprocal_rank_fusion([], secondary, rrf_k=rrf_k)
129+
results = reciprocal_rank_fusion([], secondary, rrf_k=rrf_k) # type: ignore
122130
assert len(results) == 2
123131
assert results[0]["id_val"] == "s1"
124132
assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
125133
assert results[1]["id_val"] == "s2"
126134
assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))
127135

128-
def test_mixed_results_default_k(self):
136+
def test_mixed_results_default_k(self) -> None:
129137
primary = [get_row("common", 0.8), get_row("p_only", 0.7)]
130138
secondary = [get_row("common", 0.9), get_row("s_only", 0.6)]
131139
rrf_k = 60
132140
# common_score = (1/(0+k))_prim + (1/(0+k))_sec = 2/k
133141
# p_only_score = (1/(1+k))_prim = 1/(k+1)
134142
# s_only_score = (1/(1+k))_sec = 1/(k+1)
135-
results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k)
143+
results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) # type: ignore
136144
assert len(results) == 3
137145
assert results[0]["id_val"] == "common"
138146
assert results[0]["distance"] == pytest.approx(2.0 / rrf_k)
@@ -143,32 +151,31 @@ def test_mixed_results_default_k(self):
143151
for score in next_scores:
144152
assert score == pytest.approx(1.0 / (1 + rrf_k))
145153

146-
def test_fetch_top_k_rrf(self):
154+
def test_fetch_top_k_rrf(self) -> None:
147155
primary = [get_row(f"p{i}", (10 - i) / 10.0) for i in range(5)]
148-
secondary = []
149156
rrf_k = 1
150-
results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k, fetch_top_k=2)
157+
results = reciprocal_rank_fusion(primary, [], rrf_k=rrf_k, fetch_top_k=2) # type: ignore
151158
assert len(results) == 2
152159
assert results[0]["id_val"] == "p0"
153160
assert results[0]["distance"] == pytest.approx(1.0 / (0 + rrf_k))
154161
assert results[1]["id_val"] == "p1"
155162
assert results[1]["distance"] == pytest.approx(1.0 / (1 + rrf_k))
156163

157-
def test_rrf_content_preservation(self):
164+
def test_rrf_content_preservation(self) -> None:
158165
primary = [get_row("doc1", 0.9, content="Primary Content")]
159166
secondary = [get_row("doc1", 0.8, content="Secondary Content")]
160167
# RRF processes primary then secondary. If a doc is in both,
161168
# the content from the secondary list will overwrite primary's.
162-
results = reciprocal_rank_fusion(primary, secondary, rrf_k=60)
169+
results = reciprocal_rank_fusion(primary, secondary, rrf_k=60) # type: ignore
163170
assert len(results) == 1
164171
assert results[0]["id_val"] == "doc1"
165172
assert results[0]["content_field"] == "Secondary Content"
166173

167174
# If only in primary
168-
results_prim_only = reciprocal_rank_fusion(primary, [], rrf_k=60)
175+
results_prim_only = reciprocal_rank_fusion(primary, [], rrf_k=60) # type: ignore
169176
assert results_prim_only[0]["content_field"] == "Primary Content"
170177

171-
def test_reordering_from_inputs_rrf(self):
178+
def test_reordering_from_inputs_rrf(self) -> None:
172179
"""
173180
Tests that RRF fused ranking can be different from both primary and secondary
174181
input rankings.
@@ -190,15 +197,15 @@ def test_reordering_from_inputs_rrf(self):
190197
# docA_score = 1/(0+1) [P] + 1/(2+1) [S] = 1 + 1/3 = 4/3
191198
# docB_score = 1/(1+1) [P] + 1/(1+1) [S] = 1/2 + 1/2 = 1
192199
# docC_score = 1/(2+1) [P] + 1/(0+1) [S] = 1/3 + 1 = 4/3
193-
results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k)
200+
results = reciprocal_rank_fusion(primary, secondary, rrf_k=rrf_k) # type: ignore
194201
assert len(results) == 3
195202
assert {results[0]["id_val"], results[1]["id_val"]} == {"docA", "docC"}
196203
assert results[0]["distance"] == pytest.approx(4.0 / 3.0)
197204
assert results[1]["distance"] == pytest.approx(4.0 / 3.0)
198205
assert results[2]["id_val"] == "docB"
199206
assert results[2]["distance"] == pytest.approx(1.0)
200207

201-
def test_reordering_from_inputs_weighted_sum(self):
208+
def test_reordering_from_inputs_weighted_sum(self) -> None:
202209
"""
203210
Tests that the fused ranking can be different from both primary and secondary
204211
input rankings.
@@ -214,7 +221,7 @@ def test_reordering_from_inputs_weighted_sum(self):
214221
primary = [get_row("docA", 0.9), get_row("docB", 0.7)]
215222
secondary = [get_row("docB", 0.8), get_row("docA", 0.2)]
216223

217-
results = weighted_sum_ranking(primary, secondary)
224+
results = weighted_sum_ranking(primary, secondary) # type: ignore
218225
assert len(results) == 2
219226
assert results[0]["id_val"] == "docB"
220227
assert results[0]["distance"] == pytest.approx(0.75)

0 commit comments

Comments
 (0)