diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java index 31b0fe116cd3d..7de2af477d9f8 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RankEvalIT.java @@ -28,6 +28,7 @@ import org.elasticsearch.index.rankeval.ExpectedReciprocalRank; import org.elasticsearch.index.rankeval.MeanReciprocalRank; import org.elasticsearch.index.rankeval.PrecisionAtK; +import org.elasticsearch.index.rankeval.RecallAtK; import org.elasticsearch.index.rankeval.RankEvalRequest; import org.elasticsearch.index.rankeval.RankEvalResponse; import org.elasticsearch.index.rankeval.RankEvalSpec; @@ -130,9 +131,9 @@ private static List createTestEvaluationSpec() { */ public void testMetrics() throws IOException { List specifications = createTestEvaluationSpec(); - List> metrics = Arrays.asList(PrecisionAtK::new, MeanReciprocalRank::new, DiscountedCumulativeGain::new, - () -> new ExpectedReciprocalRank(1)); - double expectedScores[] = new double[] {0.4285714285714286, 0.75, 1.6408962261063627, 0.4407738095238095}; + List> metrics = Arrays.asList(PrecisionAtK::new, RecallAtK::new, + MeanReciprocalRank::new, DiscountedCumulativeGain::new, () -> new ExpectedReciprocalRank(1)); + double expectedScores[] = new double[] {0.4285714285714286, 1.0, 0.75, 1.6408962261063627, 0.4407738095238095}; int i = 0; for (Supplier metricSupplier : metrics) { RankEvalSpec spec = new RankEvalSpec(specifications, metricSupplier.get()); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index e7bc5d4a923b1..a8e8037930741 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -98,6 +98,7 @@ import org.elasticsearch.index.rankeval.MeanReciprocalRank; import org.elasticsearch.index.rankeval.MetricDetail; import org.elasticsearch.index.rankeval.PrecisionAtK; +import org.elasticsearch.index.rankeval.RecallAtK; import org.elasticsearch.join.aggregations.ChildrenAggregationBuilder; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHits; @@ -696,7 +697,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(57, namedXContents.size()); + assertEquals(59, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) { @@ -710,13 +711,15 @@ public void testProvidedNamedXContents() { assertEquals(Integer.valueOf(3), categories.get(Aggregation.class)); assertTrue(names.contains(ChildrenAggregationBuilder.NAME)); assertTrue(names.contains(MatrixStatsAggregationBuilder.NAME)); - assertEquals(Integer.valueOf(4), categories.get(EvaluationMetric.class)); + assertEquals(Integer.valueOf(5), categories.get(EvaluationMetric.class)); assertTrue(names.contains(PrecisionAtK.NAME)); + assertTrue(names.contains(RecallAtK.NAME)); assertTrue(names.contains(DiscountedCumulativeGain.NAME)); assertTrue(names.contains(MeanReciprocalRank.NAME)); assertTrue(names.contains(ExpectedReciprocalRank.NAME)); - assertEquals(Integer.valueOf(4), categories.get(MetricDetail.class)); + assertEquals(Integer.valueOf(5), categories.get(MetricDetail.class)); assertTrue(names.contains(PrecisionAtK.NAME)); + assertTrue(names.contains(RecallAtK.NAME)); assertTrue(names.contains(MeanReciprocalRank.NAME)); assertTrue(names.contains(DiscountedCumulativeGain.NAME)); assertTrue(names.contains(ExpectedReciprocalRank.NAME)); diff --git a/docs/reference/search/rank-eval.asciidoc b/docs/reference/search/rank-eval.asciidoc index 57488f2459516..3e9311558fa19 100644 --- a/docs/reference/search/rank-eval.asciidoc +++ b/docs/reference/search/rank-eval.asciidoc @@ -201,20 +201,21 @@ will be used. The following metrics are supported: [[k-precision]] ===== Precision at K (P@k) -This metric measures the number of relevant results in the top k search results. -It's a form of the well-known -https://en.wikipedia.org/wiki/Information_retrieval#Precision[Precision] metric -that only looks at the top k documents. It is the fraction of relevant documents -in those first k results. A precision at 10 (P@10) value of 0.6 then means six -out of the 10 top hits are relevant with respect to the user's information need. - -P@k works well as a simple evaluation metric that has the benefit of being easy -to understand and explain. Documents in the collection need to be rated as either -relevant or irrelevant with respect to the current query. P@k does not take -into account the position of the relevant documents within the top k results, -so a ranking of ten results that contains one relevant result in position 10 is -equally as good as a ranking of ten results that contains one relevant result -in position 1. +This metric measures the proportion of relevant results in the top k search results. +It's a form of the well-known +https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Precision[Precision] +metric that only looks at the top k documents. It is the fraction of relevant +documents in those first k results. A precision at 10 (P@10) value of 0.6 then +means 6 out of the 10 top hits are relevant with respect to the user's +information need. + +P@k works well as a simple evaluation metric that has the benefit of being easy +to understand and explain. Documents in the collection need to be rated as either +relevant or irrelevant with respect to the current query. P@k is a set-based +metric and does not take into account the position of the relevant documents +within the top k results, so a ranking of ten results that contains one +relevant result in position 10 is equally as good as a ranking of ten results +that contains one relevant result in position 1. [source,console] -------------------------------- @@ -251,6 +252,58 @@ If set to 'true', unlabeled documents are ignored and neither count as relevant |======================================================================= +[float] +[[k-recall]] +===== Recall at K (R@k) + +This metric measures the total number of relevant results in the top k search +results. It's a form of the well-known +https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Recall[Recall] +metric. It is the fraction of relevant documents in those first k results +relative to all possible relevant results. A recall at 10 (R@10) value of 0.5 then +means 4 out of 8 relevant documents, with respect to the user's information +need, were retrieved in the 10 top hits. + +R@k works well as a simple evaluation metric that has the benefit of being easy +to understand and explain. Documents in the collection need to be rated as either +relevant or irrelevant with respect to the current query. R@k is a set-based +metric and does not take into account the position of the relevant documents +within the top k results, so a ranking of ten results that contains one +relevant result in position 10 is equally as good as a ranking of ten results +that contains one relevant result in position 1. + +[source,console] +-------------------------------- +GET /twitter/_rank_eval +{ + "requests": [ + { + "id": "JFK query", + "request": { "query": { "match_all": {}}}, + "ratings": [] + }], + "metric": { + "recall": { + "k" : 20, + "relevant_rating_threshold": 1 + } + } +} +-------------------------------- +// TEST[setup:twitter] + +The `recall` metric takes the following optional parameters + +[cols="<,<",options="header",] +|======================================================================= +|Parameter |Description +|`k` |sets the maximum number of documents retrieved per query. This value will act in place of the usual `size` parameter +in the query. Defaults to 10. +|`relevant_rating_threshold` |sets the rating threshold above which documents are considered to be +"relevant". Defaults to `1`. +|======================================================================= + + [float] ===== Mean reciprocal rank diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MetricDetail.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MetricDetail.java index da77825c0c867..0273658d82b7d 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MetricDetail.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/MetricDetail.java @@ -26,7 +26,7 @@ import java.io.IOException; /** - * Details about a specific {@link EvaluationMetric} that should be included in the resonse. + * Details about a specific {@link EvaluationMetric} that should be included in the response. */ public interface MetricDetail extends ToXContentObject, NamedWriteable { diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java index bb5a579ead6ee..71c121b352be9 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/PrecisionAtK.java @@ -40,10 +40,10 @@ /** * Metric implementing Precision@K - * (https://en.wikipedia.org/wiki/Information_retrieval#Precision_at_K).
+ * (https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Precision).
* By default documents with a rating equal or bigger than 1 are considered to - * be "relevant" for this calculation. This value can be changes using the - * relevant_rating_threshold` parameter.
+ * be "relevant" for this calculation. This value can be changed using the + * `relevant_rating_threshold` parameter.
* The `ignore_unlabeled` parameter (default to false) controls if unrated * documents should be ignored. * The `k` parameter (defaults to 10) controls the search window size. @@ -52,19 +52,21 @@ public class PrecisionAtK implements EvaluationMetric { public static final String NAME = "precision"; - private static final ParseField RELEVANT_RATING_FIELD = new ParseField("relevant_rating_threshold"); + private static final int DEFAULT_RELEVANT_RATING_THRESHOLD = 1; + private static final boolean DEFAULT_IGNORE_UNLABELED = false; + private static final int DEFAULT_K = 10; + + private static final ParseField RELEVANT_RATING_THRESHOLD_FIELD = new ParseField("relevant_rating_threshold"); private static final ParseField IGNORE_UNLABELED_FIELD = new ParseField("ignore_unlabeled"); private static final ParseField K_FIELD = new ParseField("k"); - private static final int DEFAULT_K = 10; - + private final int relevantRatingThreshold; private final boolean ignoreUnlabeled; - private final int relevantRatingThreshhold; private final int k; /** * Metric implementing Precision@K. - * @param threshold + * @param relevantRatingThreshold * ratings equal or above this value will be considered relevant. * @param ignoreUnlabeled * Controls how unlabeled documents in the search hits are treated. @@ -74,53 +76,67 @@ public class PrecisionAtK implements EvaluationMetric { * @param k * controls the window size for the search results the metric takes into account */ - public PrecisionAtK(int threshold, boolean ignoreUnlabeled, int k) { - if (threshold < 0) { + public PrecisionAtK(int relevantRatingThreshold, boolean ignoreUnlabeled, int k) { + if (relevantRatingThreshold < 0) { throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer."); } if (k <= 0) { throw new IllegalArgumentException("Window size k must be positive."); } - this.relevantRatingThreshhold = threshold; + this.relevantRatingThreshold = relevantRatingThreshold; this.ignoreUnlabeled = ignoreUnlabeled; this.k = k; } + public PrecisionAtK(boolean ignoreUnlabeled) { + this(DEFAULT_RELEVANT_RATING_THRESHOLD, ignoreUnlabeled, DEFAULT_K); + } + public PrecisionAtK() { - this(1, false, DEFAULT_K); + this(DEFAULT_RELEVANT_RATING_THRESHOLD, DEFAULT_IGNORE_UNLABELED, DEFAULT_K); + } + + PrecisionAtK(StreamInput in) throws IOException { + this(in.readVInt(), in.readBoolean(), in.readVInt()); } - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, - args -> { - Integer threshHold = (Integer) args[0]; - Boolean ignoreUnlabeled = (Boolean) args[1]; - Integer k = (Integer) args[2]; - return new PrecisionAtK(threshHold == null ? 1 : threshHold, - ignoreUnlabeled == null ? false : ignoreUnlabeled, - k == null ? DEFAULT_K : k); - }); + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { + Integer relevantRatingThreshold = (Integer) args[0]; + Boolean ignoreUnlabeled = (Boolean) args[1]; + Integer k = (Integer) args[2]; + return new PrecisionAtK( + relevantRatingThreshold == null ? DEFAULT_RELEVANT_RATING_THRESHOLD : relevantRatingThreshold, + ignoreUnlabeled == null ? DEFAULT_IGNORE_UNLABELED : ignoreUnlabeled, + k == null ? DEFAULT_K : k); + }); static { - PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_FIELD); + PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_THRESHOLD_FIELD); PARSER.declareBoolean(optionalConstructorArg(), IGNORE_UNLABELED_FIELD); PARSER.declareInt(optionalConstructorArg(), K_FIELD); } - PrecisionAtK(StreamInput in) throws IOException { - relevantRatingThreshhold = in.readVInt(); - ignoreUnlabeled = in.readBoolean(); - k = in.readVInt(); + public static PrecisionAtK fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); } - int getK() { - return this.k; + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(getRelevantRatingThreshold()); + out.writeBoolean(getIgnoreUnlabeled()); + out.writeVInt(getK()); } @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeVInt(relevantRatingThreshhold); - out.writeBoolean(ignoreUnlabeled); - out.writeVInt(k); + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startObject(NAME); + builder.field(RELEVANT_RATING_THRESHOLD_FIELD.getPreferredName(), getRelevantRatingThreshold()); + builder.field(IGNORE_UNLABELED_FIELD.getPreferredName(), getIgnoreUnlabeled()); + builder.field(K_FIELD.getPreferredName(), getK()); + builder.endObject(); + builder.endObject(); + return builder; } @Override @@ -133,7 +149,7 @@ public String getWriteableName() { * "relevant" for this metric. Defaults to 1. */ public int getRelevantRatingThreshold() { - return relevantRatingThreshhold; + return relevantRatingThreshold; } /** @@ -143,61 +159,66 @@ public boolean getIgnoreUnlabeled() { return ignoreUnlabeled; } + public int getK() { + return k; + } + @Override public OptionalInt forcedSearchSize() { - return OptionalInt.of(k); + return OptionalInt.of(getK()); } - public static PrecisionAtK fromXContent(XContentParser parser) { - return PARSER.apply(parser, null); + /** + * Binarizes a rating based on the relevant rating threshold. + */ + private boolean isRelevant(int rating) { + return rating >= getRelevantRatingThreshold(); + } + + /** + * Should we count unlabeled documents? This is the inverse of {@link #getIgnoreUnlabeled()}. + */ + private boolean shouldCountUnlabeled() { + return !getIgnoreUnlabeled(); } /** - * Compute precisionAtN based on provided relevant document IDs. + * Compute precision at k based on provided relevant document IDs. * - * @return precision at n for above {@link SearchResult} list. + * @return precision at k for above {@link SearchResult} list. **/ @Override public EvalQueryQuality evaluate(String taskId, SearchHit[] hits, - List ratedDocs) { - int truePositives = 0; - int falsePositives = 0; + List ratedDocs) { + List ratedSearchHits = joinHitsWithRatings(hits, ratedDocs); + + int relevantRetrieved = 0; + int retrieved = 0; + for (RatedSearchHit hit : ratedSearchHits) { OptionalInt rating = hit.getRating(); if (rating.isPresent()) { - if (rating.getAsInt() >= this.relevantRatingThreshhold) { - truePositives++; - } else { - falsePositives++; + retrieved++; + if (isRelevant(rating.getAsInt())) { + relevantRetrieved++; } - } else if (ignoreUnlabeled == false) { - falsePositives++; + } else if (shouldCountUnlabeled()) { + retrieved++; } } + double precision = 0.0; - if (truePositives + falsePositives > 0) { - precision = (double) truePositives / (truePositives + falsePositives); + if (retrieved > 0) { + precision = (double) relevantRetrieved / retrieved; } + EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, precision); - evalQueryQuality.setMetricDetails( - new PrecisionAtK.Detail(truePositives, truePositives + falsePositives)); + evalQueryQuality.setMetricDetails(new PrecisionAtK.Detail(relevantRetrieved, retrieved)); evalQueryQuality.addHitsAndRatings(ratedSearchHits); return evalQueryQuality; } - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.startObject(NAME); - builder.field(RELEVANT_RATING_FIELD.getPreferredName(), this.relevantRatingThreshhold); - builder.field(IGNORE_UNLABELED_FIELD.getPreferredName(), this.ignoreUnlabeled); - builder.field(K_FIELD.getPreferredName(), this.k); - builder.endObject(); - builder.endObject(); - return builder; - } - @Override public final boolean equals(Object obj) { if (this == obj) { @@ -207,20 +228,21 @@ public final boolean equals(Object obj) { return false; } PrecisionAtK other = (PrecisionAtK) obj; - return Objects.equals(relevantRatingThreshhold, other.relevantRatingThreshhold) - && Objects.equals(k, other.k) - && Objects.equals(ignoreUnlabeled, other.ignoreUnlabeled); + return Objects.equals(relevantRatingThreshold, other.relevantRatingThreshold) + && Objects.equals(ignoreUnlabeled, other.ignoreUnlabeled) + && Objects.equals(k, other.k); } @Override public final int hashCode() { - return Objects.hash(relevantRatingThreshhold, ignoreUnlabeled, k); + return Objects.hash(relevantRatingThreshold, ignoreUnlabeled, k); } public static final class Detail implements MetricDetail { - private static final ParseField DOCS_RETRIEVED_FIELD = new ParseField("docs_retrieved"); private static final ParseField RELEVANT_DOCS_RETRIEVED_FIELD = new ParseField("relevant_docs_retrieved"); + private static final ParseField DOCS_RETRIEVED_FIELD = new ParseField("docs_retrieved"); + private int relevantRetrieved; private int retrieved; @@ -230,21 +252,11 @@ public static final class Detail implements MetricDetail { } Detail(StreamInput in) throws IOException { - this.relevantRetrieved = in.readVInt(); - this.retrieved = in.readVInt(); - } - - @Override - public XContentBuilder innerToXContent(XContentBuilder builder, Params params) - throws IOException { - builder.field(RELEVANT_DOCS_RETRIEVED_FIELD.getPreferredName(), relevantRetrieved); - builder.field(DOCS_RETRIEVED_FIELD.getPreferredName(), retrieved); - return builder; + this(in.readVInt(), in.readVInt()); } - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, true, args -> { - return new Detail((Integer) args[0], (Integer) args[1]); - }); + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, true, args -> new Detail((Integer) args[0], (Integer) args[1])); static { PARSER.declareInt(constructorArg(), RELEVANT_DOCS_RETRIEVED_FIELD); @@ -257,8 +269,16 @@ public static Detail fromXContent(XContentParser parser) { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeVInt(relevantRetrieved); - out.writeVInt(retrieved); + out.writeVLong(relevantRetrieved); + out.writeVLong(retrieved); + } + + @Override + public XContentBuilder innerToXContent(XContentBuilder builder, Params params) + throws IOException { + builder.field(RELEVANT_DOCS_RETRIEVED_FIELD.getPreferredName(), relevantRetrieved); + builder.field(DOCS_RETRIEVED_FIELD.getPreferredName(), retrieved); + return builder; } @Override diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java index 7eddcf9dff644..1995e7f42df60 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java @@ -31,8 +31,10 @@ public class RankEvalNamedXContentProvider implements NamedXContentProvider { @Override public List getNamedXContentParsers() { List namedXContent = new ArrayList<>(); - namedXContent.add( - new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionAtK.NAME), PrecisionAtK::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(PrecisionAtK.NAME), + PrecisionAtK::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(RecallAtK.NAME), + RecallAtK::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(MeanReciprocalRank.NAME), MeanReciprocalRank::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(EvaluationMetric.class, new ParseField(DiscountedCumulativeGain.NAME), @@ -42,6 +44,8 @@ public List getNamedXContentParsers() { namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(PrecisionAtK.NAME), PrecisionAtK.Detail::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(RecallAtK.NAME), + RecallAtK.Detail::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME), MeanReciprocalRank.Detail::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(DiscountedCumulativeGain.NAME), diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java index 9f21a6065900e..33bc5928a3699 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java @@ -58,12 +58,14 @@ public List getRestHandlers(Settings settings, RestController restC public List getNamedWriteables() { List namedWriteables = new ArrayList<>(); namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, PrecisionAtK.NAME, PrecisionAtK::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, RecallAtK.NAME, RecallAtK::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new)); namedWriteables.add( new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new)); namedWriteables.add( new NamedWriteableRegistry.Entry(EvaluationMetric.class, ExpectedReciprocalRank.NAME, ExpectedReciprocalRank::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, RecallAtK.NAME, RecallAtK.Detail::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new)); namedWriteables.add( new NamedWriteableRegistry.Entry(MetricDetail.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain.Detail::new)); diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RecallAtK.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RecallAtK.java new file mode 100644 index 0000000000000..3fe2f2c2d82a7 --- /dev/null +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RecallAtK.java @@ -0,0 +1,280 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.rankeval; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.SearchHit; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import java.util.OptionalInt; + +import javax.naming.directory.SearchResult; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings; + +/** + * Metric implementing Recall@K + * (https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Recall).
+ * By default documents with a rating equal or bigger than 1 are considered to + * be "relevant" for this calculation. This value can be changed using the + * `relevant_rating_threshold` parameter.
+ * The `k` parameter (defaults to 10) controls the search window size. + */ +public class RecallAtK implements EvaluationMetric { + + public static final String NAME = "recall"; + + private static final int DEFAULT_RELEVANT_RATING_THRESHOLD = 1; + private static final int DEFAULT_K = 10; + + private static final ParseField RELEVANT_RATING_THRESHOLD_FIELD = new ParseField("relevant_rating_threshold"); + private static final ParseField K_FIELD = new ParseField("k"); + + private final int relevantRatingThreshold; + private final int k; + + /** + * Metric implementing Recall@K. + * @param relevantRatingThreshold + * ratings equal or above this value will be considered relevant. + * @param k + * controls the window size for the search results the metric takes into account + */ + public RecallAtK(int relevantRatingThreshold, int k) { + if (relevantRatingThreshold < 0) { + throw new IllegalArgumentException("Relevant rating threshold for precision must be positive integer."); + } + if (k <= 0) { + throw new IllegalArgumentException("Window size k must be positive."); + } + this.relevantRatingThreshold = relevantRatingThreshold; + this.k = k; + } + + public RecallAtK() { + this(DEFAULT_RELEVANT_RATING_THRESHOLD, DEFAULT_K); + } + + RecallAtK(StreamInput in) throws IOException { + this(in.readVInt(), in.readVInt()); + } + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, args -> { + Integer relevantRatingThreshold = (Integer) args[0]; + Integer k = (Integer) args[1]; + return new RecallAtK( + relevantRatingThreshold == null ? DEFAULT_RELEVANT_RATING_THRESHOLD : relevantRatingThreshold, + k == null ? DEFAULT_K : k); + }); + + static { + PARSER.declareInt(optionalConstructorArg(), RELEVANT_RATING_THRESHOLD_FIELD); + PARSER.declareInt(optionalConstructorArg(), K_FIELD); + } + + public static RecallAtK fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(getRelevantRatingThreshold()); + out.writeVInt(getK()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startObject(NAME); + builder.field(RELEVANT_RATING_THRESHOLD_FIELD.getPreferredName(), getRelevantRatingThreshold()); + builder.field(K_FIELD.getPreferredName(), getK()); + builder.endObject(); + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + /** + * Return the rating threshold above which ratings are considered to be + * "relevant" for this metric. Defaults to 1. + */ + public int getRelevantRatingThreshold() { + return relevantRatingThreshold; + } + + public int getK() { + return k; + } + + @Override + public OptionalInt forcedSearchSize() { + return OptionalInt.of(getK()); + } + + /** + * Binarizes a rating based on the relevant rating threshold. + */ + private boolean isRelevant(int rating) { + return rating >= getRelevantRatingThreshold(); + } + + /** + * Compute recall at k based on provided relevant document IDs. + * + * @return recall at k for above {@link SearchResult} list. + **/ + @Override + public EvalQueryQuality evaluate(String taskId, SearchHit[] hits, + List ratedDocs) { + + List ratedSearchHits = joinHitsWithRatings(hits, ratedDocs); + + int relevantRetrieved = 0; + for (RatedSearchHit hit : ratedSearchHits) { + OptionalInt rating = hit.getRating(); + if (rating.isPresent() && isRelevant(rating.getAsInt())) { + relevantRetrieved++; + } + } + + int relevant = 0; + for (RatedDocument rd : ratedDocs) { + if(isRelevant(rd.getRating())) { + relevant++; + } + } + + double recall = 0.0; + if (relevant > 0) { + recall = (double) relevantRetrieved / relevant; + } + + EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, recall); + evalQueryQuality.setMetricDetails(new RecallAtK.Detail(relevantRetrieved, relevant)); + evalQueryQuality.addHitsAndRatings(ratedSearchHits); + return evalQueryQuality; + } + + @Override + public final boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + RecallAtK other = (RecallAtK) obj; + return Objects.equals(relevantRatingThreshold, other.relevantRatingThreshold) + && Objects.equals(k, other.k); + } + + @Override + public final int hashCode() { + return Objects.hash(relevantRatingThreshold, k); + } + + public static final class Detail implements MetricDetail { + + private static final ParseField RELEVANT_DOCS_RETRIEVED_FIELD = new ParseField("relevant_docs_retrieved"); + private static final ParseField RELEVANT_DOCS_FIELD = new ParseField("relevant_docs"); + private long relevantRetrieved; + private long relevant; + + Detail(long relevantRetrieved, long relevant) { + this.relevantRetrieved = relevantRetrieved; + this.relevant = relevant; + } + + Detail(StreamInput in) throws IOException { + this.relevantRetrieved = in.readVLong(); + this.relevant = in.readVLong(); + } + + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>(NAME, true, args -> new Detail((Integer) args[0], (Integer) args[1])); + + static { + PARSER.declareInt(constructorArg(), RELEVANT_DOCS_RETRIEVED_FIELD); + PARSER.declareInt(constructorArg(), RELEVANT_DOCS_FIELD); + } + + public static Detail fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(relevantRetrieved); + out.writeVLong(relevant); + } + + @Override + public XContentBuilder innerToXContent(XContentBuilder builder, Params params) + throws IOException { + builder.field(RELEVANT_DOCS_RETRIEVED_FIELD.getPreferredName(), relevantRetrieved); + builder.field(RELEVANT_DOCS_FIELD.getPreferredName(), relevant); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + public long getRelevantRetrieved() { + return relevantRetrieved; + } + + public long getRelevant() { + return relevant; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + RecallAtK.Detail other = (RecallAtK.Detail) obj; + return Objects.equals(relevantRetrieved, other.relevantRetrieved) + && Objects.equals(relevant, other.relevant); + } + + @Override + public int hashCode() { + return Objects.hash(relevantRetrieved, relevant); + } + } +} diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionAtKTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionAtKTests.java index aa5be25f86aef..f7e0b0dc21c73 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionAtKTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/PrecisionAtKTests.java @@ -46,25 +46,25 @@ public class PrecisionAtKTests extends ESTestCase { - private static final int IRRELEVANT_RATING_0 = 0; - private static final int RELEVANT_RATING_1 = 1; + private static final int IRRELEVANT_RATING = 0; + private static final int RELEVANT_RATING = 1; - public void testPrecisionAtFiveCalculation() { + public void testCalculation() { List rated = new ArrayList<>(); - rated.add(createRatedDoc("test", "0", RELEVANT_RATING_1)); + rated.add(createRatedDoc("test", "0", RELEVANT_RATING)); EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated, "test"), rated); assertEquals(1, evaluated.metricScore(), 0.00001); assertEquals(1, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved()); assertEquals(1, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved()); } - public void testPrecisionAtFiveIgnoreOneResult() { + public void testIgnoreOneResult() { List rated = new ArrayList<>(); - rated.add(createRatedDoc("test", "0", RELEVANT_RATING_1)); - rated.add(createRatedDoc("test", "1", RELEVANT_RATING_1)); - rated.add(createRatedDoc("test", "2", RELEVANT_RATING_1)); - rated.add(createRatedDoc("test", "3", RELEVANT_RATING_1)); - rated.add(createRatedDoc("test", "4", IRRELEVANT_RATING_0)); + rated.add(createRatedDoc("test", "0", RELEVANT_RATING)); + rated.add(createRatedDoc("test", "1", RELEVANT_RATING)); + rated.add(createRatedDoc("test", "2", RELEVANT_RATING)); + rated.add(createRatedDoc("test", "3", RELEVANT_RATING)); + rated.add(createRatedDoc("test", "4", IRRELEVANT_RATING)); EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated, "test"), rated); assertEquals((double) 4 / 5, evaluated.metricScore(), 0.00001); assertEquals(4, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved()); @@ -73,10 +73,10 @@ public void testPrecisionAtFiveIgnoreOneResult() { /** * test that the relevant rating threshold can be set to something larger than - * 1. e.g. we set it to 2 here and expect dics 0-2 to be not relevant, doc 3 and - * 4 to be relevant + * 1. e.g. we set it to 2 here and expect docs 0-1 to be not relevant, docs 2-4 + * to be relevant */ - public void testPrecisionAtFiveRelevanceThreshold() { + public void testRelevanceThreshold() { List rated = new ArrayList<>(); rated.add(createRatedDoc("test", "0", 0)); rated.add(createRatedDoc("test", "1", 1)); @@ -92,13 +92,15 @@ public void testPrecisionAtFiveRelevanceThreshold() { public void testPrecisionAtFiveCorrectIndex() { List rated = new ArrayList<>(); - rated.add(createRatedDoc("test_other", "0", RELEVANT_RATING_1)); - rated.add(createRatedDoc("test_other", "1", RELEVANT_RATING_1)); - rated.add(createRatedDoc("test", "0", RELEVANT_RATING_1)); - rated.add(createRatedDoc("test", "1", RELEVANT_RATING_1)); - rated.add(createRatedDoc("test", "2", IRRELEVANT_RATING_0)); + rated.add(createRatedDoc("test_other", "0", RELEVANT_RATING)); + rated.add(createRatedDoc("test_other", "1", RELEVANT_RATING)); + rated.add(createRatedDoc("test", "0", RELEVANT_RATING)); + rated.add(createRatedDoc("test", "1", RELEVANT_RATING)); + rated.add(createRatedDoc("test", "2", IRRELEVANT_RATING)); // the following search hits contain only the last three documents - EvalQueryQuality evaluated = (new PrecisionAtK()).evaluate("id", toSearchHits(rated.subList(2, 5), "test"), rated); + List ratedSubList = rated.subList(2, 5); + PrecisionAtK precisionAtK = new PrecisionAtK(1, false, 5); + EvalQueryQuality evaluated = (precisionAtK).evaluate("id", toSearchHits(ratedSubList, "test"), rated); assertEquals((double) 2 / 3, evaluated.metricScore(), 0.00001); assertEquals(2, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved()); assertEquals(3, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved()); @@ -106,8 +108,8 @@ public void testPrecisionAtFiveCorrectIndex() { public void testIgnoreUnlabeled() { List rated = new ArrayList<>(); - rated.add(createRatedDoc("test", "0", RELEVANT_RATING_1)); - rated.add(createRatedDoc("test", "1", RELEVANT_RATING_1)); + rated.add(createRatedDoc("test", "0", RELEVANT_RATING)); + rated.add(createRatedDoc("test", "1", RELEVANT_RATING)); // add an unlabeled search hit SearchHit[] searchHits = Arrays.copyOf(toSearchHits(rated, "test"), 3); searchHits[2] = new SearchHit(2, "2", Collections.emptyMap()); @@ -119,7 +121,7 @@ public void testIgnoreUnlabeled() { assertEquals(3, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved()); // also try with setting `ignore_unlabeled` - PrecisionAtK prec = new PrecisionAtK(1, true, 10); + PrecisionAtK prec = new PrecisionAtK(true); evaluated = prec.evaluate("id", searchHits, rated); assertEquals((double) 2 / 2, evaluated.metricScore(), 0.00001); assertEquals(2, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved()); @@ -138,7 +140,7 @@ public void testNoRatedDocs() throws Exception { assertEquals(5, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRetrieved()); // also try with setting `ignore_unlabeled` - PrecisionAtK prec = new PrecisionAtK(1, true, 10); + PrecisionAtK prec = new PrecisionAtK(true); evaluated = prec.evaluate("id", hits, Collections.emptyList()); assertEquals(0.0d, evaluated.metricScore(), 0.00001); assertEquals(0, ((PrecisionAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved()); diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java index e0899b451af11..3dace4dea50c0 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RankEvalSpecTests.java @@ -73,6 +73,7 @@ private static List randomList(Supplier randomSupplier) { static RankEvalSpec createTestItem() { Supplier metric = randomFrom(Arrays.asList( () -> PrecisionAtKTests.createTestItem(), + () -> RecallAtKTests.createTestItem(), () -> MeanReciprocalRankTests.createTestItem(), () -> DiscountedCumulativeGainTests.createTestItem())); @@ -149,6 +150,7 @@ private static RankEvalSpec copy(RankEvalSpec original) throws IOException { List namedWriteables = new ArrayList<>(); namedWriteables.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, PrecisionAtK.NAME, PrecisionAtK::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, RecallAtK.NAME, RecallAtK::new)); namedWriteables.add( new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(EvaluationMetric.class, MeanReciprocalRank.NAME, MeanReciprocalRank::new)); diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RecallAtKTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RecallAtKTests.java new file mode 100644 index 0000000000000..990a7751fd2f4 --- /dev/null +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/RecallAtKTests.java @@ -0,0 +1,248 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.rankeval; + +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentParseException; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; +import static org.elasticsearch.test.XContentTestUtils.insertRandomFields; +import static org.hamcrest.CoreMatchers.containsString; + +public class RecallAtKTests extends ESTestCase { + + private static final int IRRELEVANT_RATING = 0; + private static final int RELEVANT_RATING = 1; + + public void testCalculation() { + List rated = new ArrayList<>(); + rated.add(createRatedDoc("test", "0", RELEVANT_RATING)); + + EvalQueryQuality evaluated = (new RecallAtK()).evaluate("id", toSearchHits(rated, "test"), rated); + assertEquals(1, evaluated.metricScore(), 0.00001); + assertEquals(1, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved()); + assertEquals(1, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant()); + } + + public void testIgnoreOneResult() { + List rated = new ArrayList<>(); + rated.add(createRatedDoc("test", "0", RELEVANT_RATING)); + rated.add(createRatedDoc("test", "1", RELEVANT_RATING)); + rated.add(createRatedDoc("test", "2", RELEVANT_RATING)); + rated.add(createRatedDoc("test", "3", RELEVANT_RATING)); + rated.add(createRatedDoc("test", "4", IRRELEVANT_RATING)); + + EvalQueryQuality evaluated = (new RecallAtK()).evaluate("id", toSearchHits(rated, "test"), rated); + assertEquals((double) 4 / 4, evaluated.metricScore(), 0.00001); + assertEquals(4, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved()); + assertEquals(4, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant()); + } + + /** + * Test that the relevant rating threshold can be set to something larger than + * 1. e.g. we set it to 2 here and expect docs 0-1 to be not relevant, docs 2-4 + * to be relevant, and only 0-3 are hits. + */ + public void testRelevanceThreshold() { + List rated = new ArrayList<>(); + rated.add(createRatedDoc("test", "0", 0)); // not relevant, hit + rated.add(createRatedDoc("test", "1", 1)); // not relevant, hit + rated.add(createRatedDoc("test", "2", 2)); // relevant, hit + rated.add(createRatedDoc("test", "3", 3)); // relevant + rated.add(createRatedDoc("test", "4", 4)); // relevant + + RecallAtK recallAtN = new RecallAtK(2, 5); + + EvalQueryQuality evaluated = recallAtN.evaluate("id", toSearchHits(rated.subList(0,3), "test"), rated); + assertEquals((double) 1 / 3, evaluated.metricScore(), 0.00001); + assertEquals(1, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved()); + assertEquals(3, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant()); + } + + public void testCorrectIndex() { + List rated = new ArrayList<>(); + rated.add(createRatedDoc("test_other", "0", RELEVANT_RATING)); + rated.add(createRatedDoc("test_other", "1", RELEVANT_RATING)); + rated.add(createRatedDoc("test", "0", RELEVANT_RATING)); + rated.add(createRatedDoc("test", "1", RELEVANT_RATING)); + rated.add(createRatedDoc("test", "2", IRRELEVANT_RATING)); + + // the following search hits contain only the last three documents + List ratedSubList = rated.subList(2, 5); + + EvalQueryQuality evaluated = (new RecallAtK(1, 5)).evaluate("id", toSearchHits(ratedSubList, "test"), rated); + assertEquals((double) 2 / 4, evaluated.metricScore(), 0.00001); + assertEquals(2, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved()); + assertEquals(4, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant()); + } + + public void testNoRatedDocs() throws Exception { + int k = 5; + SearchHit[] hits = new SearchHit[k]; + for (int i = 0; i < k; i++) { + hits[i] = new SearchHit(i, i + "", Collections.emptyMap()); + hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0), null, OriginalIndices.NONE)); + } + + EvalQueryQuality evaluated = (new RecallAtK()).evaluate("id", hits, Collections.emptyList()); + assertEquals(0.0d, evaluated.metricScore(), 0.00001); + assertEquals(0, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved()); + assertEquals(0, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant()); + } + + public void testNoResults() throws Exception { + EvalQueryQuality evaluated = (new RecallAtK()).evaluate("id", new SearchHit[0], Collections.emptyList()); + assertEquals(0.0d, evaluated.metricScore(), 0.00001); + assertEquals(0, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved()); + assertEquals(0, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant()); + } + + public void testNoResultsWithRatedDocs() throws Exception { + List rated = new ArrayList<>(); + rated.add(createRatedDoc("test", "0", RELEVANT_RATING)); + + EvalQueryQuality evaluated = (new RecallAtK()).evaluate("id", new SearchHit[0], rated); + assertEquals(0.0d, evaluated.metricScore(), 0.00001); + assertEquals(0, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevantRetrieved()); + assertEquals(1, ((RecallAtK.Detail) evaluated.getMetricDetails()).getRelevant()); + } + + public void testParseFromXContent() throws IOException { + String xContent = " {\n" + " \"relevant_rating_threshold\" : 2" + "}"; + try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) { + RecallAtK recallAtK = RecallAtK.fromXContent(parser); + assertEquals(2, recallAtK.getRelevantRatingThreshold()); + } + } + + public void testCombine() { + RecallAtK metric = new RecallAtK(); + List partialResults = new ArrayList<>(3); + partialResults.add(new EvalQueryQuality("a", 0.1)); + partialResults.add(new EvalQueryQuality("b", 0.2)); + partialResults.add(new EvalQueryQuality("c", 0.6)); + assertEquals(0.3, metric.combine(partialResults), Double.MIN_VALUE); + } + + public void testInvalidRelevantThreshold() { + expectThrows(IllegalArgumentException.class, () -> new RecallAtK(-1, 10)); + } + + public void testInvalidK() { + expectThrows(IllegalArgumentException.class, () -> new RecallAtK(1, -10)); + } + + public static RecallAtK createTestItem() { + return new RecallAtK(randomIntBetween(0, 10), randomIntBetween(1, 50)); + } + + public void testXContentRoundtrip() throws IOException { + RecallAtK testItem = createTestItem(); + XContentBuilder builder = XContentFactory.contentBuilder(randomFrom(XContentType.values())); + XContentBuilder shuffled = shuffleXContent(testItem.toXContent(builder, ToXContent.EMPTY_PARAMS)); + try (XContentParser itemParser = createParser(shuffled)) { + itemParser.nextToken(); + itemParser.nextToken(); + RecallAtK parsedItem = RecallAtK.fromXContent(itemParser); + assertNotSame(testItem, parsedItem); + assertEquals(testItem, parsedItem); + assertEquals(testItem.hashCode(), parsedItem.hashCode()); + } + } + + public void testXContentParsingIsNotLenient() throws IOException { + RecallAtK testItem = createTestItem(); + XContentType xContentType = randomFrom(XContentType.values()); + BytesReference originalBytes = toShuffledXContent(testItem, xContentType, ToXContent.EMPTY_PARAMS, randomBoolean()); + BytesReference withRandomFields = insertRandomFields(xContentType, originalBytes, null, random()); + try (XContentParser parser = createParser(xContentType.xContent(), withRandomFields)) { + parser.nextToken(); + parser.nextToken(); + XContentParseException exception = expectThrows(XContentParseException.class, () -> RecallAtK.fromXContent(parser)); + assertThat(exception.getMessage(), containsString("[recall] unknown field")); + } + } + + public void testSerialization() throws IOException { + RecallAtK original = createTestItem(); + RecallAtK deserialized = ESTestCase.copyWriteable(original, new NamedWriteableRegistry(Collections.emptyList()), + RecallAtK::new); + assertEquals(deserialized, original); + assertEquals(deserialized.hashCode(), original.hashCode()); + assertNotSame(deserialized, original); + } + + public void testEqualsAndHash() throws IOException { + checkEqualsAndHashCode(createTestItem(), RecallAtKTests::copy, RecallAtKTests::mutate); + } + + private static RecallAtK copy(RecallAtK original) { + return new RecallAtK(original.getRelevantRatingThreshold(), original.forcedSearchSize().getAsInt()); + } + + private static RecallAtK mutate(RecallAtK original) { + RecallAtK recallAtK; + switch (randomIntBetween(0, 1)) { + case 0: + recallAtK = new RecallAtK( + randomValueOtherThan(original.getRelevantRatingThreshold(), () -> randomIntBetween(0, 10)), + original.forcedSearchSize().getAsInt()); + break; + case 1: + recallAtK = new RecallAtK( + original.getRelevantRatingThreshold(), + original.forcedSearchSize().getAsInt() + 1); + break; + default: + throw new IllegalStateException("The test should only allow two parameters mutated"); + } + return recallAtK; + } + + private static SearchHit[] toSearchHits(List rated, String index) { + SearchHit[] hits = new SearchHit[rated.size()]; + for (int i = 0; i < rated.size(); i++) { + hits[i] = new SearchHit(i, i + "", Collections.emptyMap()); + hits[i].shard(new SearchShardTarget("testnode", new ShardId(index, "uuid", 0), null, OriginalIndices.NONE)); + } + return hits; + } + + private static RatedDocument createRatedDoc(String index, String id, int rating) { + return new RatedDocument(index, id, rating); + } +}