@@ -27,82 +27,127 @@ def get_row(doc_id: str, score: float, content: str = "content") -> RowMapping:
2727
2828class  TestWeightedSumRanking :
2929    def  test_empty_inputs (self ) ->  None :
30+         """Tests that the function handles empty inputs gracefully.""" 
3031        results  =  weighted_sum_ranking ([], [])
3132        assert  results  ==  []
3233
33-     def  test_primary_only (self ) ->  None :
34+     def  test_primary_only_cosine_default (self ) ->  None :
35+         """Tests ranking with only primary results using default cosine distance.""" 
3436        primary  =  [get_row ("p1" , 0.8 ), get_row ("p2" , 0.6 )]
35-         # Expected scores: p1 = 0.8 * 0.5 = 0.4, p2 = 0.6 * 0.5 = 0.3 
36-         results  =  weighted_sum_ranking (  # type: ignore 
37+         # --- Calculation (Cosine = lower is better) --- 
38+         # Scores: [0.8, 0.6]. Range: 0.2. Min: 0.6. 
39+         # p1 norm: 1.0 - ((0.8 - 0.6) / 0.2) = 0.0 
40+         # p2 norm: 1.0 - ((0.6 - 0.6) / 0.2) = 1.0 
41+         # Weighted (0.5): p1 = 0.0, p2 = 0.5 
42+         # Order: p2, p1 
43+         results  =  weighted_sum_ranking (
3744            primary ,  # type: ignore 
3845            [],
39-             primary_results_weight = 0.5 ,
40-             secondary_results_weight = 0.5 ,
4146        )
4247        assert  len (results ) ==  2 
43-         assert  results [0 ]["id_val" ] ==  "p1 " 
44-         assert  results [0 ]["distance" ] ==  pytest .approx (0.4  )
45-         assert  results [1 ]["id_val" ] ==  "p2 " 
46-         assert  results [1 ]["distance" ] ==  pytest .approx (0.3  )
48+         assert  results [0 ]["id_val" ] ==  "p2 " 
49+         assert  results [0 ]["distance" ] ==  pytest .approx (0.5  )
50+         assert  results [1 ]["id_val" ] ==  "p1 " 
51+         assert  results [1 ]["distance" ] ==  pytest .approx (0.0  )
4752
4853    def  test_secondary_only (self ) ->  None :
49-         secondary  =  [get_row ("s1" , 0.9 ), get_row ("s2" , 0.7 )]
50-         # Expected scores: s1 = 0.9 * 0.5 = 0.45, s2 = 0.7 * 0.5 = 0.35 
54+         """Tests ranking with only secondary (keyword) results.""" 
55+         secondary  =  [get_row ("s1" , 15.0 ), get_row ("s2" , 5.0 )]
56+         # --- Calculation (Keyword = higher is better) --- 
57+         # Scores: [15.0, 5.0]. Range: 10.0. Min: 5.0. 
58+         # s1 norm: (15.0 - 5.0) / 10.0 = 1.0 
59+         # s2 norm: (5.0 - 5.0) / 10.0 = 0.0 
60+         # Weighted (0.5): s1 = 0.5, s2 = 0.0 
61+         # Order: s1, s2 
5162        results  =  weighted_sum_ranking (
5263            [],
5364            secondary ,  # type: ignore 
54-             primary_results_weight = 0.5 ,
55-             secondary_results_weight = 0.5 ,
5665        )
5766        assert  len (results ) ==  2 
5867        assert  results [0 ]["id_val" ] ==  "s1" 
59-         assert  results [0 ]["distance" ] ==  pytest .approx (0.45  )
68+         assert  results [0 ]["distance" ] ==  pytest .approx (0.5  )
6069        assert  results [1 ]["id_val" ] ==  "s2" 
61-         assert  results [1 ]["distance" ] ==  pytest .approx (0.35  )
70+         assert  results [1 ]["distance" ] ==  pytest .approx (0.0  )
6271
63-     def  test_mixed_results_default_weights (self ) ->  None :
72+     def  test_mixed_results_cosine (self ) ->  None :
73+         """Tests combining cosine (lower is better) and keyword (higher is better) scores.""" 
6474        primary  =  [get_row ("common" , 0.8 ), get_row ("p_only" , 0.7 )]
65-         secondary  =  [get_row ("common" , 0.9 ), get_row ("s_only" , 0.6 )]
66-         # Weights are 0.5, 0.5 
67-         # common_score = (0.8 * 0.5) + (0.9 * 0.5) = 0.4 + 0.45 = 0.85 
68-         # p_only_score = (0.7 * 0.5) = 0.35 
69-         # s_only_score = (0.6 * 0.5) = 0.30 
70-         # Order: common (0.85), p_only (0.35), s_only (0.30) 
71- 
72-         results  =  weighted_sum_ranking (primary , secondary )  # type: ignore 
75+         secondary  =  [get_row ("common" , 9.0 ), get_row ("s_only" , 6.0 )]
76+         # --- Calculation --- 
77+         # Primary norm (inverted): common=0.0, p_only=1.0 
78+         # Secondary norm: common=1.0, s_only=0.0 
79+         # Weighted (0.5): 
80+         # common = (0.0 * 0.5) + (1.0 * 0.5) = 0.5 
81+         # p_only = (1.0 * 0.5) + 0 = 0.5 
82+         # s_only = 0 + (0.0 * 0.5) = 0.0 
83+         results  =  weighted_sum_ranking (
84+             primary ,  # type: ignore 
85+             secondary ,  # type: ignore 
86+         )
7387        assert  len (results ) ==  3 
74-         assert  results [0 ]["id_val" ] ==  "common" 
75-         assert  results [0 ]["distance" ] ==  pytest .approx (0.85 )
76-         assert  results [1 ]["id_val" ] ==  "p_only" 
77-         assert  results [1 ]["distance" ] ==  pytest .approx (0.35 )
88+         # Check that the top two results have the correct score and IDs (order may vary) 
89+         top_ids  =  {res ["id_val" ] for  res  in  results [:2 ]}
90+         assert  top_ids  ==  {"common" , "p_only" }
91+         assert  results [0 ]["distance" ] ==  pytest .approx (0.5 )
92+         assert  results [1 ]["distance" ] ==  pytest .approx (0.5 )
7893        assert  results [2 ]["id_val" ] ==  "s_only" 
79-         assert  results [2 ]["distance" ] ==  pytest .approx (0.30  )
94+         assert  results [2 ]["distance" ] ==  pytest .approx (0.0  )
8095
81-     def  test_mixed_results_custom_weights (self ) ->  None :
82-         primary  =  [get_row ("d1" , 1.0 )]  # p_w=0.2 -> 0.2 
83-         secondary  =  [get_row ("d1" , 0.5 )]  # s_w=0.8 -> 0.4 
84-         # Expected: d1_score = (1.0 * 0.2) + (0.5 * 0.8) = 0.2 + 0.4 = 0.6 
96+     def  test_primary_max_inner_product (self ) ->  None :
97+         """Tests using MAX_INNER_PRODUCT (higher is better) for primary search.""" 
98+         primary  =  [get_row ("best" , 0.9 ), get_row ("worst" , 0.1 )]
99+         secondary  =  [get_row ("best" , 20.0 ), get_row ("worst" , 5.0 )]
100+         # --- Calculation --- 
101+         # Primary norm (NOT inverted): best=1.0, worst=0.0 
102+         # Secondary norm: best=1.0, worst=0.0 
103+         # Weighted (0.5): 
104+         # best = (1.0 * 0.5) + (1.0 * 0.5) = 1.0 
105+         # worst = (0.0 * 0.5) + (0.0 * 0.5) = 0.0 
106+         results  =  weighted_sum_ranking (
107+             primary ,  # type: ignore 
108+             secondary ,  # type: ignore 
109+             distance_strategy = DistanceStrategy .INNER_PRODUCT ,
110+         )
111+         assert  len (results ) ==  2 
112+         assert  results [0 ]["id_val" ] ==  "best" 
113+         assert  results [0 ]["distance" ] ==  pytest .approx (1.0 )
114+         assert  results [1 ]["id_val" ] ==  "worst" 
115+         assert  results [1 ]["distance" ] ==  pytest .approx (0.0 )
85116
117+     def  test_primary_euclidean (self ) ->  None :
118+         """Tests using EUCLIDEAN (lower is better) for primary search.""" 
119+         primary  =  [get_row ("closer" , 10.5 ), get_row ("farther" , 25.5 )]
120+         secondary  =  [get_row ("closer" , 100.0 ), get_row ("farther" , 10.0 )]
121+         # --- Calculation --- 
122+         # Primary norm (inverted): closer=1.0, farther=0.0 
123+         # Secondary norm: closer=1.0, farther=0.0 
124+         # Weighted (0.5): 
125+         # closer = (1.0 * 0.5) + (1.0 * 0.5) = 1.0 
126+         # farther = (0.0 * 0.5) + (0.0 * 0.5) = 0.0 
86127        results  =  weighted_sum_ranking (
87128            primary ,  # type: ignore 
88129            secondary ,  # type: ignore 
89-             primary_results_weight = 0.2 ,
90-             secondary_results_weight = 0.8 ,
130+             distance_strategy = DistanceStrategy .EUCLIDEAN ,
91131        )
92-         assert  len (results ) ==  1 
93-         assert  results [0 ]["id_val" ] ==  "d1" 
94-         assert  results [0 ]["distance" ] ==  pytest .approx (0.6 )
132+         assert  len (results ) ==  2 
133+         assert  results [0 ]["id_val" ] ==  "closer" 
134+         assert  results [0 ]["distance" ] ==  pytest .approx (1.0 )
135+         assert  results [1 ]["id_val" ] ==  "farther" 
136+         assert  results [1 ]["distance" ] ==  pytest .approx (0.0 )
95137
96138    def  test_fetch_top_k (self ) ->  None :
139+         """Tests that fetch_top_k correctly limits the number of results.""" 
97140        primary  =  [get_row (f"p{ i }  " , (10  -  i ) /  10.0 ) for  i  in  range (5 )]
98-         # Scores: 1.0, 0.9, 0.8, 0.7, 0.6 
99-         # Weighted (0.5): 0.5, 0.45, 0.4, 0.35, 0.3 
100-         results  =  weighted_sum_ranking (primary , [], fetch_top_k = 2 )  # type: ignore 
141+         # p0=1.0, p1=0.9, p2=0.8, p3=0.7, p4=0.6 
142+         # The best scores (lowest distance) are p4 and p3 
143+         results  =  weighted_sum_ranking (
144+             primary ,  # type: ignore 
145+             [],
146+             fetch_top_k = 2 ,
147+         )
101148        assert  len (results ) ==  2 
102-         assert  results [0 ]["id_val" ] ==  "p0" 
103-         assert  results [0 ]["distance" ] ==  pytest .approx (0.5 )
104-         assert  results [1 ]["id_val" ] ==  "p1" 
105-         assert  results [1 ]["distance" ] ==  pytest .approx (0.45 )
149+         assert  results [0 ]["id_val" ] ==  "p4"   # Has the best normalized score 
150+         assert  results [1 ]["id_val" ] ==  "p3" 
106151
107152
108153class  TestReciprocalRankFusion :
0 commit comments