From da013d1fea2e739141299be82ea7c3ec73ef3efa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Thu, 7 Jun 2018 15:36:02 +0200 Subject: [PATCH 1/3] Add details section for dcg ranking metric While the other two ranking evaluation metrics (precicion and reciprocal rank) already provide a more detailed output for how their score is calculated, the discounted cumulative gain metric (dcg) and its normalized variant are lacking this until now. Its not really clear which level of detail might be useful for debugging and understanding the final metric calculation, but this change adds a `metric_details` section to REST output that contains some information about the evaluation details. --- .../rankeval/DiscountedCumulativeGain.java | 142 ++++++++++++++++-- .../RankEvalNamedXContentProvider.java | 2 + .../index/rankeval/RankEvalPlugin.java | 5 +- .../DiscountedCumulativeGainTests.java | 26 +++- .../index/rankeval/EvalQueryQualityTests.java | 14 +- 5 files changed, 170 insertions(+), 19 deletions(-) diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java index 13926d7d362ff..01a6e35299b29 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGain.java @@ -36,6 +36,7 @@ import java.util.Optional; import java.util.stream.Collectors; +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.index.rankeval.EvaluationMetric.joinHitsWithRatings; @@ -129,26 +130,31 @@ public EvalQueryQuality evaluate(String taskId, SearchHit[] hits, .collect(Collectors.toList()); List ratedHits = joinHitsWithRatings(hits, ratedDocs); List ratingsInSearchHits = new ArrayList<>(ratedHits.size()); + int unratedResults = 0; for (RatedSearchHit hit : ratedHits) { - // unknownDocRating might be null, which means it will be unrated docs are - // ignored in the dcg calculation - // we still need to add them as a placeholder so the rank of the subsequent - // ratings is correct + // unknownDocRating might be null, in which case unrated docs will be ignored in the dcg calculation. + // we still need to add them as a placeholder so the rank of the subsequent ratings is correct ratingsInSearchHits.add(hit.getRating().orElse(unknownDocRating)); + if (hit.getRating().isPresent() == false) { + unratedResults++; + } } - double dcg = computeDCG(ratingsInSearchHits); + final double dcg = computeDCG(ratingsInSearchHits); + double result = dcg; + double idcg = 0; if (normalize) { Collections.sort(allRatings, Comparator.nullsLast(Collections.reverseOrder())); - double idcg = computeDCG(allRatings.subList(0, Math.min(ratingsInSearchHits.size(), allRatings.size()))); - if (idcg > 0) { - dcg = dcg / idcg; + idcg = computeDCG(allRatings.subList(0, Math.min(ratingsInSearchHits.size(), allRatings.size()))); + if (idcg != 0) { + result = dcg / idcg; } else { - dcg = 0; + result = 0; } } - EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, dcg); + EvalQueryQuality evalQueryQuality = new EvalQueryQuality(taskId, result); evalQueryQuality.addHitsAndRatings(ratedHits); + evalQueryQuality.setMetricDetails(new Detail(dcg, idcg, unratedResults)); return evalQueryQuality; } @@ -167,7 +173,7 @@ private static double computeDCG(List ratings) { private static final ParseField K_FIELD = new ParseField("k"); private static final ParseField NORMALIZE_FIELD = new ParseField("normalize"); private static final ParseField UNKNOWN_DOC_RATING_FIELD = new ParseField("unknown_doc_rating"); - private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("dcg_at", false, + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("dcg", false, args -> { Boolean normalized = (Boolean) args[0]; Integer optK = (Integer) args[2]; @@ -217,4 +223,118 @@ public final boolean equals(Object obj) { public final int hashCode() { return Objects.hash(normalize, unknownDocRating, k); } + + public static final class Detail implements MetricDetail { + + private static ParseField DCG_FIELD = new ParseField("dcg"); + private static ParseField IDCG_FIELD = new ParseField("ideal_dcg"); + private static ParseField NDCG_FIELD = new ParseField("normalized_dcg"); + private static ParseField UNRATED_FIELD = new ParseField("unrated_docs"); + private final double dcg; + private final double idcg; + private final int unratedDocs; + + Detail(double dcg, double idcg, int unratedDocs) { + this.dcg = dcg; + this.idcg = idcg; + this.unratedDocs = unratedDocs; + } + + Detail(StreamInput in) throws IOException { + this.dcg = in.readDouble(); + this.idcg = in.readDouble(); + this.unratedDocs = in.readVInt(); + } + + @Override + public + String getMetricName() { + return NAME; + } + + @Override + public XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(DCG_FIELD.getPreferredName(), this.dcg); + if (this.idcg != 0) { + builder.field(IDCG_FIELD.getPreferredName(), this.idcg); + builder.field(NDCG_FIELD.getPreferredName(), this.dcg / this.idcg); + } + builder.field(UNRATED_FIELD.getPreferredName(), this.unratedDocs); + return builder; + } + + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, true, args -> { + return new Detail((Double) args[0], (Double) args[1] != null ? (Double) args[1] : 0.0d, (Integer) args[2]); + }); + + static { + PARSER.declareDouble(constructorArg(), DCG_FIELD); + PARSER.declareDouble(optionalConstructorArg(), IDCG_FIELD); + PARSER.declareInt(constructorArg(), UNRATED_FIELD); + } + + public static Detail fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(this.dcg); + out.writeDouble(this.idcg); + out.writeVInt(this.unratedDocs); + } + + @Override + public String getWriteableName() { + return NAME; + } + + /** + * @return the discounted cumulative gain + */ + public double getDCG() { + return this.dcg; + } + + /** + * @return the ideal discounted cumulative gain, can be 0 if nothing was computed, e.g. because no normalization was required + */ + public double getIDCG() { + return this.idcg; + } + + /** + * @return the normalized discounted cumulative gain, can be 0 if nothing was computed, e.g. because no normalization was required + */ + public double getNDCG() { + return (this.idcg != 0) ? this.dcg / this.idcg : 0; + } + + /** + * @return the number of unrated documents in the search results + */ + public Object getUnratedDocs() { + return this.unratedDocs; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + DiscountedCumulativeGain.Detail other = (DiscountedCumulativeGain.Detail) obj; + return (this.dcg == other.dcg && + this.idcg == other.idcg && + this.unratedDocs == other.unratedDocs); + } + + @Override + public int hashCode() { + return Objects.hash(this.dcg, this.idcg, this.unratedDocs); + } + } } + diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java index c5785ca3847d4..f2176113cdf9d 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalNamedXContentProvider.java @@ -41,6 +41,8 @@ public List getNamedXContentParsers() { PrecisionAtK.Detail::fromXContent)); namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(MeanReciprocalRank.NAME), MeanReciprocalRank.Detail::fromXContent)); + namedXContent.add(new NamedXContentRegistry.Entry(MetricDetail.class, new ParseField(DiscountedCumulativeGain.NAME), + DiscountedCumulativeGain.Detail::fromXContent)); return namedXContent; } } diff --git a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java index 884cf3bafdcda..8ac2b7fbee528 100644 --- a/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java +++ b/modules/rank-eval/src/main/java/org/elasticsearch/index/rankeval/RankEvalPlugin.java @@ -61,8 +61,9 @@ public List getNamedWriteables() { namedWriteables.add( new NamedWriteableRegistry.Entry(EvaluationMetric.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, PrecisionAtK.NAME, PrecisionAtK.Detail::new)); - namedWriteables - .add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(MetricDetail.class, MeanReciprocalRank.NAME, MeanReciprocalRank.Detail::new)); + namedWriteables.add( + new NamedWriteableRegistry.Entry(MetricDetail.class, DiscountedCumulativeGain.NAME, DiscountedCumulativeGain.Detail::new)); return namedWriteables; } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java index 64337786b1eb6..24ac600a11398 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/DiscountedCumulativeGainTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.index.rankeval; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.text.Text; @@ -254,9 +255,8 @@ private void assertParsedCorrect(String xContent, Integer expectedUnknownDocRati public static DiscountedCumulativeGain createTestItem() { boolean normalize = randomBoolean(); - Integer unknownDocRating = Integer.valueOf(randomIntBetween(0, 1000)); - - return new DiscountedCumulativeGain(normalize, unknownDocRating, 10); + Integer unknownDocRating = frequently() ? Integer.valueOf(randomIntBetween(0, 1000)) : null; + return new DiscountedCumulativeGain(normalize, unknownDocRating, randomIntBetween(1, 10)); } public void testXContentRoundtrip() throws IOException { @@ -283,7 +283,25 @@ public void testXContentParsingIsNotLenient() throws IOException { parser.nextToken(); XContentParseException exception = expectThrows(XContentParseException.class, () -> DiscountedCumulativeGain.fromXContent(parser)); - assertThat(exception.getMessage(), containsString("[dcg_at] unknown field")); + assertThat(exception.getMessage(), containsString("[dcg] unknown field")); + } + } + + public void testMetricDetails() { + double dcg = randomDoubleBetween(0, 1, true); + double idcg = randomBoolean() ? 0.0 : randomDoubleBetween(0, 1, true); + double expectedNdcg = idcg != 0 ? dcg / idcg : 0.0; + int unratedDocs = randomIntBetween(0, 100); + DiscountedCumulativeGain.Detail detail = new DiscountedCumulativeGain.Detail(dcg, idcg, unratedDocs); + assertEquals(dcg, detail.getDCG(), 0.0); + assertEquals(idcg, detail.getIDCG(), 0.0); + assertEquals(expectedNdcg, detail.getNDCG(), 0.0); + assertEquals(unratedDocs, detail.getUnratedDocs()); + if (idcg != 0) { + assertEquals("{\"dcg\":{\"dcg\":" + dcg + ",\"ideal_dcg\":" + idcg + ",\"normalized_dcg\":" + expectedNdcg + + ",\"unrated_docs\":" + unratedDocs + "}}", Strings.toString(detail)); + } else { + assertEquals("{\"dcg\":{\"dcg\":" + dcg + ",\"unrated_docs\":" + unratedDocs + "}}", Strings.toString(detail)); } } diff --git a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java index 112cf4eaaf72e..e9fae6b5c63ee 100644 --- a/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java +++ b/modules/rank-eval/src/test/java/org/elasticsearch/index/rankeval/EvalQueryQualityTests.java @@ -68,10 +68,20 @@ public static EvalQueryQuality randomEvalQueryQuality() { EvalQueryQuality evalQueryQuality = new EvalQueryQuality(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, true)); if (randomBoolean()) { - if (randomBoolean()) { + int metricDetail = randomIntBetween(0, 2); + switch (metricDetail) { + case 0: evalQueryQuality.setMetricDetails(new PrecisionAtK.Detail(randomIntBetween(0, 1000), randomIntBetween(0, 1000))); - } else { + break; + case 1: evalQueryQuality.setMetricDetails(new MeanReciprocalRank.Detail(randomIntBetween(0, 1000))); + break; + case 2: + evalQueryQuality.setMetricDetails(new DiscountedCumulativeGain.Detail(randomDoubleBetween(0, 1, true), + randomBoolean() ? randomDoubleBetween(0, 1, true) : 0, randomInt())); + break; + default: + throw new IllegalArgumentException("illegal randomized value in test"); } } evalQueryQuality.addHitsAndRatings(ratedHits); From 34bc510beea9a154282d590f6041ba7b5509438d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Fri, 8 Jun 2018 14:00:00 +0200 Subject: [PATCH 2/3] iter --- .../org/elasticsearch/client/RestHighLevelClientTests.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 9084a547c162d..8e4f2b2008fe2 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.client; import com.fasterxml.jackson.core.JsonParseException; + import org.apache.http.Header; import org.apache.http.HttpEntity; import org.apache.http.HttpHost; @@ -625,9 +626,10 @@ public void testProvidedNamedXContents() { assertTrue(names.contains(PrecisionAtK.NAME)); assertTrue(names.contains(DiscountedCumulativeGain.NAME)); assertTrue(names.contains(MeanReciprocalRank.NAME)); - assertEquals(Integer.valueOf(2), categories.get(MetricDetail.class)); + assertEquals(Integer.valueOf(3), categories.get(MetricDetail.class)); assertTrue(names.contains(PrecisionAtK.NAME)); assertTrue(names.contains(MeanReciprocalRank.NAME)); + assertTrue(names.contains(DiscountedCumulativeGain.NAME)); } private static class TrackingActionListener implements ActionListener { From b5c2d95b7183c1f04eb4a8c3ad343c403863530a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Christoph=20B=C3=BCscher?= Date: Wed, 13 Jun 2018 22:24:48 +0200 Subject: [PATCH 3/3] iter --- .../java/org/elasticsearch/client/RestHighLevelClientTests.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java index 8e4f2b2008fe2..cadbf5da795d9 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/RestHighLevelClientTests.java @@ -608,7 +608,7 @@ public void testDefaultNamedXContents() { public void testProvidedNamedXContents() { List namedXContents = RestHighLevelClient.getProvidedNamedXContents(); - assertEquals(7, namedXContents.size()); + assertEquals(8, namedXContents.size()); Map, Integer> categories = new HashMap<>(); List names = new ArrayList<>(); for (NamedXContentRegistry.Entry namedXContent : namedXContents) {