From 78a4be304f7fc44ff70cfaf9be971a9a55ff6d5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Przemys=C5=82aw=20Witek?= Date: Thu, 17 Oct 2019 17:59:22 +0200 Subject: [PATCH] Make num_top_classes parameter's default value equal to 2 (#48119) --- .../client/ml/dataframe/Classification.java | 29 ++++++-- .../client/MachineLearningIT.java | 7 +- .../MlClientDocumentationIT.java | 1 + .../ml/dataframe/ClassificationTests.java | 1 + .../ml/put-data-frame-analytics.asciidoc | 1 + .../ml/dataframe/analyses/Classification.java | 12 +++- .../analyses/ClassificationTests.java | 66 +++++++++++++++---- .../ml/integration/ClassificationIT.java | 4 +- .../test/ml/data_frame_analytics_crud.yml | 2 +- 9 files changed, 96 insertions(+), 27 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java index fb9234d25b84e..d4e7bce5ec442 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java @@ -48,6 +48,7 @@ public static Builder builder(String dependentVariable) { static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); + static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -61,7 +62,8 @@ public static Builder builder(String dependentVariable) { (Integer) a[4], (Double) a[5], (String) a[6], - (Double) a[7])); + (Double) a[7], + (Integer) a[8])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -72,6 +74,7 @@ public static Builder builder(String dependentVariable) { PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES); } private final String dependentVariable; @@ -82,10 +85,11 @@ public static Builder builder(String dependentVariable) { private final Double featureBagFraction; private final String predictionFieldName; private final Double trainingPercent; + private final Integer numTopClasses; private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, - @Nullable Double trainingPercent) { + @Nullable Double trainingPercent, @Nullable Integer numTopClasses) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; this.gamma = gamma; @@ -94,6 +98,7 @@ private Classification(String dependentVariable, @Nullable Double lambda, @Nulla this.featureBagFraction = featureBagFraction; this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent; + this.numTopClasses = numTopClasses; } @Override @@ -133,6 +138,10 @@ public Double getTrainingPercent() { return trainingPercent; } + public Integer getNumTopClasses() { + return numTopClasses; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -158,6 +167,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (trainingPercent != null) { builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent); } + if (numTopClasses != null) { + builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); + } builder.endObject(); return builder; } @@ -165,7 +177,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public int hashCode() { return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent); + trainingPercent, numTopClasses); } @Override @@ -180,7 +192,8 @@ public boolean equals(Object o) { && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) && Objects.equals(featureBagFraction, that.featureBagFraction) && Objects.equals(predictionFieldName, that.predictionFieldName) - && Objects.equals(trainingPercent, that.trainingPercent); + && Objects.equals(trainingPercent, that.trainingPercent) + && Objects.equals(numTopClasses, that.numTopClasses); } @Override @@ -197,6 +210,7 @@ public static class Builder { private Double featureBagFraction; private String predictionFieldName; private Double trainingPercent; + private Integer numTopClasses; private Builder(String dependentVariable) { this.dependentVariable = Objects.requireNonNull(dependentVariable); @@ -237,9 +251,14 @@ public Builder setTrainingPercent(Double trainingPercent) { return this; } + public Builder setNumTopClasses(Integer numTopClasses) { + this.numTopClasses = numTopClasses; + return this; + } + public Classification build() { return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent); + trainingPercent, numTopClasses); } } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index ac323e6a2ea01..429dbb2d5030b 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1296,8 +1296,7 @@ public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception { .setDest(DataFrameAnalyticsDest.builder() .setIndex("put-test-dest-index") .build()) - .setAnalysis(org.elasticsearch.client.ml.dataframe.Regression - .builder("my_dependent_variable") + .setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable") .setTrainingPercent(80.0) .build()) .setDescription("this is a regression") @@ -1331,9 +1330,9 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti .setDest(DataFrameAnalyticsDest.builder() .setIndex("put-test-dest-index") .build()) - .setAnalysis(org.elasticsearch.client.ml.dataframe.Classification - .builder("my_dependent_variable") + .setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") .setTrainingPercent(80.0) + .setNumTopClasses(1) .build()) .setDescription("this is a classification") .build(); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 717e8c04c9d10..9bfe943e2c093 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -2951,6 +2951,7 @@ public void testPutDataFrameAnalytics() throws Exception { .setFeatureBagFraction(0.4) // <6> .setPredictionFieldName("my_prediction_field_name") // <7> .setTrainingPercent(50.0) // <8> + .setNumTopClasses(1) // <9> .build(); // end::put-data-frame-analytics-classification diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java index 9f0a418178dd5..98f060cc8534a 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java @@ -34,6 +34,7 @@ public static Classification randomClassification() { .setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false)) .setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10)) .setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true)) + .setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10)) .build(); } diff --git a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc index 4ca4c31ecf574..c4e7184de7e04 100644 --- a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc +++ b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc @@ -118,6 +118,7 @@ include-tagged::{doc-tests-file}[{api}-classification] <6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1]. <7> The name of the prediction field in the results object. <8> The percentage of training-eligible rows to be used in training. Defaults to 100%. +<9> The number of top classes to be reported in the results. Defaults to 2. ===== Regression diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index ea1cdc2c28b70..edb92f4ce000f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -67,6 +67,12 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU .flatMap(Set::stream) .collect(Collectors.toSet())); + /** + * As long as we only support binary classification it makes sense to always report both classes with their probabilities. + * This way the user can see if the prediction was made with confidence they need. + */ + private static final int DEFAULT_NUM_TOP_CLASSES = 2; + private final String dependentVariable; private final BoostedTreeParams boostedTreeParams; private final String predictionFieldName; @@ -87,7 +93,7 @@ public Classification(String dependentVariable, this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE); this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME); this.predictionFieldName = predictionFieldName; - this.numTopClasses = numTopClasses == null ? 0 : numTopClasses; + this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses; this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent; } @@ -107,6 +113,10 @@ public String getDependentVariable() { return dependentVariable; } + public int getNumTopClasses() { + return numTopClasses; + } + public double getTrainingPercent() { return trainingPercent; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 2cc1fae8eee07..59df68e794425 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -19,6 +19,8 @@ public class ClassificationTests extends AbstractSerializingTestCase { + private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0); + @Override protected Classification doParseInstance(XContentParser parser) throws IOException { return Classification.fromXContent(parser, false); @@ -43,32 +45,68 @@ protected Writeable.Reader instanceReader() { return Classification::new; } - public void testConstructor_GivenTrainingPercentIsNull() { - Classification classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, null); - assertThat(classification.getTrainingPercent(), equalTo(100.0)); - } - - public void testConstructor_GivenTrainingPercentIsBoundary() { - Classification classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 1.0); - assertThat(classification.getTrainingPercent(), equalTo(1.0)); - classification = new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0); - assertThat(classification.getTrainingPercent(), equalTo(100.0)); - } - public void testConstructor_GivenTrainingPercentIsLessThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 0.999)); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999)); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenTrainingPercentIsGreaterThan100() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0), "result", 3, 100.0001)); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001)); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } + public void testConstructor_GivenNumTopClassesIsLessThanZero() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0)); + + assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); + } + + public void testConstructor_GivenNumTopClassesIsGreaterThan1000() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0)); + + assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); + } + + public void testGetNumTopClasses() { + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0); + assertThat(classification.getNumTopClasses(), equalTo(7)); + + // Boundary condition: num_top_classes == 0 + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0); + assertThat(classification.getNumTopClasses(), equalTo(0)); + + // Boundary condition: num_top_classes == 1000 + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0); + assertThat(classification.getNumTopClasses(), equalTo(1000)); + + // num_top_classes == null, default applied + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0); + assertThat(classification.getNumTopClasses(), equalTo(2)); + } + + public void testGetTrainingPercent() { + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0); + assertThat(classification.getTrainingPercent(), equalTo(50.0)); + + // Boundary condition: training_percent == 1.0 + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0); + assertThat(classification.getTrainingPercent(), equalTo(1.0)); + + // Boundary condition: training_percent == 100.0 + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0); + assertThat(classification.getTrainingPercent(), equalTo(100.0)); + + // training_percent == null, default applied + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null); + assertThat(classification.getTrainingPercent(), equalTo(100.0)); + } + public void testFieldCardinalityLimitsIsNonNull() { assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(nullValue()))); } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 64dcddfa9fc42..183ee9ab27bf6 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -83,7 +83,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES))); assertThat(resultsObject.containsKey("is_training"), is(true)); assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(KEYWORD_FIELD))); - assertThat(resultsObject.containsKey("top_classes"), is(false)); + assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf); } assertProgress(jobId, 100, 100, 100, 100); @@ -120,7 +120,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti assertThat((String) resultsObject.get("keyword-field_prediction"), is(in(KEYWORD_FIELD_VALUES))); assertThat(resultsObject.containsKey("is_training"), is(true)); assertThat(resultsObject.get("is_training"), is(true)); - assertThat(resultsObject.containsKey("top_classes"), is(false)); + assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES, String::valueOf); } assertProgress(jobId, 100, 100, 100, 100); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index 2e44618cb761a..b8bea46422b43 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -1810,7 +1810,7 @@ setup: "maximum_number_trees": 400, "feature_bag_fraction": 0.3, "training_percent": 60.3, - "num_top_classes": 0 + "num_top_classes": 2 } }} - is_true: create_time