1919
2020package org .elasticsearch .index .similarity ;
2121
22+ import org .apache .logging .log4j .LogManager ;
23+ import org .apache .lucene .index .FieldInvertState ;
24+ import org .apache .lucene .index .IndexOptions ;
25+ import org .apache .lucene .search .CollectionStatistics ;
26+ import org .apache .lucene .search .Explanation ;
27+ import org .apache .lucene .search .TermStatistics ;
2228import org .apache .lucene .search .similarities .BM25Similarity ;
2329import org .apache .lucene .search .similarities .BooleanSimilarity ;
2430import org .apache .lucene .search .similarities .ClassicSimilarity ;
2531import org .apache .lucene .search .similarities .PerFieldSimilarityWrapper ;
2632import org .apache .lucene .search .similarities .Similarity ;
33+ import org .apache .lucene .search .similarities .Similarity .SimScorer ;
34+ import org .apache .lucene .util .BytesRef ;
2735import org .elasticsearch .Version ;
2836import org .elasticsearch .common .TriFunction ;
2937import org .elasticsearch .common .logging .DeprecationLogger ;
30- import org .elasticsearch .common .logging .Loggers ;
3138import org .elasticsearch .common .settings .Settings ;
3239import org .elasticsearch .index .AbstractIndexComponent ;
3340import org .elasticsearch .index .IndexModule ;
4451
4552public final class SimilarityService extends AbstractIndexComponent {
4653
47- private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger (Loggers .getLogger (SimilarityService .class ));
54+ private static final DeprecationLogger DEPRECATION_LOGGER = new DeprecationLogger (LogManager .getLogger (SimilarityService .class ));
4855 public static final String DEFAULT_SIMILARITY = "BM25" ;
4956 private static final String CLASSIC_SIMILARITY = "classic" ;
5057 private static final Map <String , Function <Version , Supplier <Similarity >>> DEFAULTS ;
@@ -131,8 +138,14 @@ public SimilarityService(IndexSettings indexSettings, ScriptService scriptServic
131138 }
132139 TriFunction <Settings , Version , ScriptService , Similarity > defaultFactory = BUILT_IN .get (typeName );
133140 TriFunction <Settings , Version , ScriptService , Similarity > factory = similarities .getOrDefault (typeName , defaultFactory );
134- final Similarity similarity = factory .apply (providerSettings , indexSettings .getIndexVersionCreated (), scriptService );
135- providers .put (name , () -> similarity );
141+ Similarity similarity = factory .apply (providerSettings , indexSettings .getIndexVersionCreated (), scriptService );
142+ validateSimilarity (indexSettings .getIndexVersionCreated (), similarity );
143+ if (BUILT_IN .containsKey (typeName ) == false || "scripted" .equals (typeName )) {
144+ // We don't trust custom similarities
145+ similarity = new NonNegativeScoresSimilarity (similarity );
146+ }
147+ final Similarity similarityF = similarity ; // like similarity but final
148+ providers .put (name , () -> similarityF );
136149 }
137150 for (Map .Entry <String , Function <Version , Supplier <Similarity >>> entry : DEFAULTS .entrySet ()) {
138151 providers .put (entry .getKey (), entry .getValue ().apply (indexSettings .getIndexVersionCreated ()));
@@ -151,7 +164,7 @@ public Similarity similarity(MapperService mapperService) {
151164 defaultSimilarity ;
152165 }
153166
154-
167+
155168 public SimilarityProvider getSimilarity (String name ) {
156169 Supplier <Similarity > sim = similarities .get (name );
157170 if (sim == null ) {
@@ -182,4 +195,80 @@ public Similarity get(String name) {
182195 return (fieldType != null && fieldType .similarity () != null ) ? fieldType .similarity ().get () : defaultSimilarity ;
183196 }
184197 }
198+
199+ static void validateSimilarity (Version indexCreatedVersion , Similarity similarity ) {
200+ validateScoresArePositive (indexCreatedVersion , similarity );
201+ validateScoresDoNotDecreaseWithFreq (indexCreatedVersion , similarity );
202+ validateScoresDoNotIncreaseWithNorm (indexCreatedVersion , similarity );
203+ }
204+
205+ private static void validateScoresArePositive (Version indexCreatedVersion , Similarity similarity ) {
206+ CollectionStatistics collectionStats = new CollectionStatistics ("some_field" , 1200 , 1100 , 3000 , 2000 );
207+ TermStatistics termStats = new TermStatistics (new BytesRef ("some_value" ), 100 , 130 );
208+ SimScorer scorer = similarity .scorer (2f , collectionStats , termStats );
209+ FieldInvertState state = new FieldInvertState (indexCreatedVersion .major , "some_field" ,
210+ IndexOptions .DOCS_AND_FREQS , 20 , 20 , 0 , 50 , 10 , 3 ); // length = 20, no overlap
211+ final long norm = similarity .computeNorm (state );
212+ for (int freq = 1 ; freq <= 10 ; ++freq ) {
213+ float score = scorer .score (freq , norm );
214+ if (score < 0 ) {
215+ fail (indexCreatedVersion , "Similarities should not return negative scores:\n " +
216+ scorer .explain (Explanation .match (freq , "term freq" ), norm ));
217+ }
218+ }
219+ }
220+
221+ private static void validateScoresDoNotDecreaseWithFreq (Version indexCreatedVersion , Similarity similarity ) {
222+ CollectionStatistics collectionStats = new CollectionStatistics ("some_field" , 1200 , 1100 , 3000 , 2000 );
223+ TermStatistics termStats = new TermStatistics (new BytesRef ("some_value" ), 100 , 130 );
224+ SimScorer scorer = similarity .scorer (2f , collectionStats , termStats );
225+ FieldInvertState state = new FieldInvertState (indexCreatedVersion .major , "some_field" ,
226+ IndexOptions .DOCS_AND_FREQS , 20 , 20 , 0 , 50 , 10 , 3 ); // length = 20, no overlap
227+ final long norm = similarity .computeNorm (state );
228+ float previousScore = 0 ;
229+ for (int freq = 1 ; freq <= 10 ; ++freq ) {
230+ float score = scorer .score (freq , norm );
231+ if (score < previousScore ) {
232+ fail (indexCreatedVersion , "Similarity scores should not decrease when term frequency increases:\n " +
233+ scorer .explain (Explanation .match (freq - 1 , "term freq" ), norm ) + "\n " +
234+ scorer .explain (Explanation .match (freq , "term freq" ), norm ));
235+ }
236+ previousScore = score ;
237+ }
238+ }
239+
240+ private static void validateScoresDoNotIncreaseWithNorm (Version indexCreatedVersion , Similarity similarity ) {
241+ CollectionStatistics collectionStats = new CollectionStatistics ("some_field" , 1200 , 1100 , 3000 , 2000 );
242+ TermStatistics termStats = new TermStatistics (new BytesRef ("some_value" ), 100 , 130 );
243+ SimScorer scorer = similarity .scorer (2f , collectionStats , termStats );
244+
245+ long previousNorm = 0 ;
246+ float previousScore = Float .MAX_VALUE ;
247+ for (int length = 1 ; length <= 10 ; ++length ) {
248+ FieldInvertState state = new FieldInvertState (indexCreatedVersion .major , "some_field" ,
249+ IndexOptions .DOCS_AND_FREQS , length , length , 0 , 50 , 10 , 3 ); // length = 20, no overlap
250+ final long norm = similarity .computeNorm (state );
251+ if (Long .compareUnsigned (previousNorm , norm ) > 0 ) {
252+ // esoteric similarity, skip this check
253+ break ;
254+ }
255+ float score = scorer .score (1 , norm );
256+ if (score > previousScore ) {
257+ fail (indexCreatedVersion , "Similarity scores should not increase when norm increases:\n " +
258+ scorer .explain (Explanation .match (1 , "term freq" ), norm - 1 ) + "\n " +
259+ scorer .explain (Explanation .match (1 , "term freq" ), norm ));
260+ }
261+ previousScore = score ;
262+ previousNorm = norm ;
263+ }
264+ }
265+
266+ private static void fail (Version indexCreatedVersion , String message ) {
267+ if (indexCreatedVersion .onOrAfter (Version .V_7_0_0_alpha1 )) {
268+ throw new IllegalArgumentException (message );
269+ } else if (indexCreatedVersion .onOrAfter (Version .V_6_5_0 )) {
270+ DEPRECATION_LOGGER .deprecated (message );
271+ }
272+ }
273+
185274}
0 commit comments