From 120883f5d9032747595ff5d6a325ad2c89808409 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Tue, 14 Jan 2020 15:01:47 +0200 Subject: [PATCH] =?UTF-8?q?[7.x][ML]=20Add=20num=5Ftop=5Ffeature=5Fimporta?= =?UTF-8?q?nce=5Fvalues=20param=20to=20regression=20and=20classi=E2=80=A6?= =?UTF-8?q?=20(#50914)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a new parameter to regression and classification that enables computation of importance for the top most important features. The computation of the importance is based on SHAP (SHapley Additive exPlanations) method. Backport of #50914 --- .../client/ml/dataframe/Classification.java | 38 ++++++-- .../client/ml/dataframe/Regression.java | 38 ++++++-- .../client/MachineLearningIT.java | 12 +++ .../MlClientDocumentationIT.java | 18 ++-- .../ml/dataframe/ClassificationTests.java | 1 + .../client/ml/dataframe/RegressionTests.java | 1 + .../ml/put-data-frame-analytics.asciidoc | 16 ++-- .../apis/put-dfanalytics.asciidoc | 8 ++ docs/reference/ml/ml-shared.asciidoc | 8 ++ .../dataframe/analyses/BoostedTreeParams.java | 91 +++++++++++++++++-- .../ml/dataframe/analyses/Classification.java | 12 +-- .../ml/dataframe/analyses/Regression.java | 10 +- .../persistence/ElasticsearchMappings.java | 6 ++ .../ml/job/results/ReservedFieldNames.java | 1 + .../analyses/BoostedTreeParamsTests.java | 39 +++++--- .../analyses/ClassificationTests.java | 2 +- .../dataframe/analyses/RegressionTests.java | 2 +- .../ml/integration/ClassificationIT.java | 22 ++++- .../xpack/ml/integration/RegressionIT.java | 21 ++++- 19 files changed, 266 insertions(+), 80 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java index 9d384e6d86786..02861adc73845 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java @@ -46,6 +46,7 @@ public static Builder builder(String dependentVariable) { static final ParseField ETA = new ParseField("eta"); static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); @@ -62,10 +63,11 @@ public static Builder builder(String dependentVariable) { (Double) a[3], (Integer) a[4], (Double) a[5], - (String) a[6], - (Double) a[7], - (Integer) a[8], - (Long) a[9])); + (Integer) a[6], + (String) a[7], + (Double) a[8], + (Integer) a[9], + (Long) a[10])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -74,6 +76,7 @@ public static Builder builder(String dependentVariable) { PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES); @@ -86,13 +89,15 @@ public static Builder builder(String dependentVariable) { private final Double eta; private final Integer maximumNumberTrees; private final Double featureBagFraction; + private final Integer numTopFeatureImportanceValues; private final String predictionFieldName; private final Double trainingPercent; private final Integer numTopClasses; private final Long randomizeSeed; private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, - @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, + @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, + @Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName, @Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; @@ -100,6 +105,7 @@ private Classification(String dependentVariable, @Nullable Double lambda, @Nulla this.eta = eta; this.maximumNumberTrees = maximumNumberTrees; this.featureBagFraction = featureBagFraction; + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent; this.numTopClasses = numTopClasses; @@ -135,6 +141,10 @@ public Double getFeatureBagFraction() { return featureBagFraction; } + public Integer getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; + } + public String getPredictionFieldName() { return predictionFieldName; } @@ -170,6 +180,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (featureBagFraction != null) { builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } if (predictionFieldName != null) { builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } @@ -188,8 +201,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public int hashCode() { - return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, randomizeSeed, numTopClasses); + return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues, + predictionFieldName, trainingPercent, randomizeSeed, numTopClasses); } @Override @@ -203,6 +216,7 @@ public boolean equals(Object o) { && Objects.equals(eta, that.eta) && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues) && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(trainingPercent, that.trainingPercent) && Objects.equals(randomizeSeed, that.randomizeSeed) @@ -221,6 +235,7 @@ public static class Builder { private Double eta; private Integer maximumNumberTrees; private Double featureBagFraction; + private Integer numTopFeatureImportanceValues; private String predictionFieldName; private Double trainingPercent; private Integer numTopClasses; @@ -255,6 +270,11 @@ public Builder setFeatureBagFraction(Double featureBagFraction) { return this; } + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + public Builder setPredictionFieldName(String predictionFieldName) { this.predictionFieldName = predictionFieldName; return this; @@ -276,8 +296,8 @@ public Builder setNumTopClasses(Integer numTopClasses) { } public Classification build() { - return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, numTopClasses, randomizeSeed); + return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, + numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed); } } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java index fa55ee40b27fb..d7e374a2563a1 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java @@ -46,6 +46,7 @@ public static Builder builder(String dependentVariable) { static final ParseField ETA = new ParseField("eta"); static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); @@ -61,9 +62,10 @@ public static Builder builder(String dependentVariable) { (Double) a[3], (Integer) a[4], (Double) a[5], - (String) a[6], - (Double) a[7], - (Long) a[8])); + (Integer) a[6], + (String) a[7], + (Double) a[8], + (Long) a[9])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -72,6 +74,7 @@ public static Builder builder(String dependentVariable) { PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED); @@ -83,12 +86,14 @@ public static Builder builder(String dependentVariable) { private final Double eta; private final Integer maximumNumberTrees; private final Double featureBagFraction; + private final Integer numTopFeatureImportanceValues; private final String predictionFieldName; private final Double trainingPercent; private final Long randomizeSeed; - private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, - @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName, + private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, + @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, + @Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName, @Nullable Double trainingPercent, @Nullable Long randomizeSeed) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; @@ -96,6 +101,7 @@ private Regression(String dependentVariable, @Nullable Double lambda, @Nullable this.eta = eta; this.maximumNumberTrees = maximumNumberTrees; this.featureBagFraction = featureBagFraction; + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent; this.randomizeSeed = randomizeSeed; @@ -130,6 +136,10 @@ public Double getFeatureBagFraction() { return featureBagFraction; } + public Integer getNumTopFeatureImportanceValues() { + return numTopFeatureImportanceValues; + } + public String getPredictionFieldName() { return predictionFieldName; } @@ -161,6 +171,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (featureBagFraction != null) { builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } if (predictionFieldName != null) { builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } @@ -176,8 +189,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public int hashCode() { - return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, randomizeSeed); + return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues, + predictionFieldName, trainingPercent, randomizeSeed); } @Override @@ -191,6 +204,7 @@ public boolean equals(Object o) { && Objects.equals(eta, that.eta) && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues) && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(trainingPercent, that.trainingPercent) && Objects.equals(randomizeSeed, that.randomizeSeed); @@ -208,6 +222,7 @@ public static class Builder { private Double eta; private Integer maximumNumberTrees; private Double featureBagFraction; + private Integer numTopFeatureImportanceValues; private String predictionFieldName; private Double trainingPercent; private Long randomizeSeed; @@ -241,6 +256,11 @@ public Builder setFeatureBagFraction(Double featureBagFraction) { return this; } + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + public Builder setPredictionFieldName(String predictionFieldName) { this.predictionFieldName = predictionFieldName; return this; @@ -257,8 +277,8 @@ public Builder setRandomizeSeed(Long randomizeSeed) { } public Regression build() { - return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName, - trainingPercent, randomizeSeed); + return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, + numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed); } } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 247b726e00874..f9ed6f4e259f5 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1324,6 +1324,12 @@ public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception { .setPredictionFieldName("my_dependent_variable_prediction") .setTrainingPercent(80.0) .setRandomizeSeed(42L) + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setMaximumNumberTrees(10) + .setFeatureBagFraction(0.5) + .setNumTopFeatureImportanceValues(3) .build()) .setDescription("this is a regression") .build(); @@ -1361,6 +1367,12 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti .setTrainingPercent(80.0) .setRandomizeSeed(42L) .setNumTopClasses(1) + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setMaximumNumberTrees(10) + .setFeatureBagFraction(0.5) + .setNumTopFeatureImportanceValues(3) .build()) .setDescription("this is a classification") .build(); diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 860fe533fd359..142f1f1f66081 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -2975,10 +2975,11 @@ public void testPutDataFrameAnalytics() throws Exception { .setEta(5.5) // <4> .setMaximumNumberTrees(50) // <5> .setFeatureBagFraction(0.4) // <6> - .setPredictionFieldName("my_prediction_field_name") // <7> - .setTrainingPercent(50.0) // <8> - .setRandomizeSeed(1234L) // <9> - .setNumTopClasses(1) // <10> + .setNumTopFeatureImportanceValues(3) // <7> + .setPredictionFieldName("my_prediction_field_name") // <8> + .setTrainingPercent(50.0) // <9> + .setRandomizeSeed(1234L) // <10> + .setNumTopClasses(1) // <11> .build(); // end::put-data-frame-analytics-classification @@ -2989,9 +2990,10 @@ public void testPutDataFrameAnalytics() throws Exception { .setEta(5.5) // <4> .setMaximumNumberTrees(50) // <5> .setFeatureBagFraction(0.4) // <6> - .setPredictionFieldName("my_prediction_field_name") // <7> - .setTrainingPercent(50.0) // <8> - .setRandomizeSeed(1234L) // <9> + .setNumTopFeatureImportanceValues(3) // <7> + .setPredictionFieldName("my_prediction_field_name") // <8> + .setTrainingPercent(50.0) // <9> + .setRandomizeSeed(1234L) // <10> .build(); // end::put-data-frame-analytics-regression @@ -3670,7 +3672,7 @@ public void testPutTrainedModel() throws Exception { } { PutTrainedModelRequest request = new PutTrainedModelRequest(trainedModelConfig); - + // tag::put-trained-model-execute-listener ActionListener listener = new ActionListener() { @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java index 5ef8fdaef5a27..79d78c888880f 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java @@ -32,6 +32,7 @@ public static Classification randomClassification() { .setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true)) .setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000)) .setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false)) + .setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE)) .setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10)) .setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true)) .setRandomizeSeed(randomBoolean() ? null : randomLong()) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java index 02e41ecdff333..eedffb4740d78 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/RegressionTests.java @@ -32,6 +32,7 @@ public static Regression randomRegression() { .setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true)) .setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000)) .setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false)) + .setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE)) .setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10)) .setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true)) .build(); diff --git a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc index 2152eff5c0850..4be2011340210 100644 --- a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc +++ b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc @@ -117,10 +117,11 @@ include-tagged::{doc-tests-file}[{api}-classification] <4> The applied shrinkage. A double in [0.001, 1]. <5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000]. <6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1]. -<7> The name of the prediction field in the results object. -<8> The percentage of training-eligible rows to be used in training. Defaults to 100%. -<9> The seed to be used by the random generator that picks which rows are used in training. -<10> The number of top classes to be reported in the results. Defaults to 2. +<7> If set, feature importance for the top most important features will be computed. +<8> The name of the prediction field in the results object. +<9> The percentage of training-eligible rows to be used in training. Defaults to 100%. +<10> The seed to be used by the random generator that picks which rows are used in training. +<11> The number of top classes to be reported in the results. Defaults to 2. ===== Regression @@ -137,9 +138,10 @@ include-tagged::{doc-tests-file}[{api}-regression] <4> The applied shrinkage. A double in [0.001, 1]. <5> The maximum number of trees the forest is allowed to contain. An integer in [1, 2000]. <6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1]. -<7> The name of the prediction field in the results object. -<8> The percentage of training-eligible rows to be used in training. Defaults to 100%. -<9> The seed to be used by the random generator that picks which rows are used in training. +<7> If set, feature importance for the top most important features will be computed. +<8> The name of the prediction field in the results object. +<9> The percentage of training-eligible rows to be used in training. Defaults to 100%. +<10> The seed to be used by the random generator that picks which rows are used in training. ==== Analyzed fields diff --git a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc index 8ecc11e115f40..b38b42f3af8ca 100644 --- a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc @@ -150,6 +150,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction-field-name] (Optional, long) include::{docdir}/ml/ml-shared.asciidoc[tag=randomize-seed] +`analysis`.`classification`.`num_top_feature_importance_values`:::: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-feature-importance-values] + `analysis`.`classification`.`training_percent`:::: (Optional, integer) include::{docdir}/ml/ml-shared.asciidoc[tag=training-percent] @@ -229,6 +233,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=lambda] (Optional, string) include::{docdir}/ml/ml-shared.asciidoc[tag=prediction-field-name] +`analysis`.`regression`.`num_top_feature_importance_values`:::: +(Optional, integer) +include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-feature-importance-values] + `analysis`.`regression`.`training_percent`:::: (Optional, integer) include::{docdir}/ml/ml-shared.asciidoc[tag=training-percent] diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index 07e7f38d42fa8..8d6022232e842 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -637,6 +637,14 @@ end::include-model-definition[] tag::indices[] An array of index names. Wildcards are supported. For example: `["it_ops_metrics", "server*"]`. + +tag::num-top-feature-importance-values[] +Advanced configuration option. If set, feature importance for the top +most important features will be computed. Importance is calculated +using the SHAP (SHapley Additive exPlanations) method as described in +https://papers.nips.cc/paper/7062-a-unified-approach-to-interpreting-model-predictions.pdf[Lundberg, S. M., & Lee, S.-I. A Unified Approach to Interpreting Model Predictions. In NeurIPS 2017.]. +end::num-top-feature-importance-values[] + + -- NOTE: If any indices are in remote clusters then `cluster.remote.connect` must diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java index 0f06b08444f53..e0890c21377ca 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParams.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.analyses; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; @@ -34,6 +35,7 @@ public class BoostedTreeParams implements ToXContentFragment, Writeable { public static final ParseField ETA = new ParseField("eta"); public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees"); public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction"); + public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); static void declareFields(AbstractObjectParser parser) { parser.declareDouble(optionalConstructorArg(), LAMBDA); @@ -41,6 +43,7 @@ static void declareFields(AbstractObjectParser parser) { parser.declareDouble(optionalConstructorArg(), ETA); parser.declareInt(optionalConstructorArg(), MAXIMUM_NUMBER_TREES); parser.declareDouble(optionalConstructorArg(), FEATURE_BAG_FRACTION); + parser.declareInt(optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); } private final Double lambda; @@ -48,12 +51,14 @@ static void declareFields(AbstractObjectParser parser) { private final Double eta; private final Integer maximumNumberTrees; private final Double featureBagFraction; + private final Integer numTopFeatureImportanceValues; public BoostedTreeParams(@Nullable Double lambda, - @Nullable Double gamma, - @Nullable Double eta, - @Nullable Integer maximumNumberTrees, - @Nullable Double featureBagFraction) { + @Nullable Double gamma, + @Nullable Double eta, + @Nullable Integer maximumNumberTrees, + @Nullable Double featureBagFraction, + @Nullable Integer numTopFeatureImportanceValues) { if (lambda != null && lambda < 0) { throw ExceptionsHelper.badRequestException("[{}] must be a non-negative double", LAMBDA.getPreferredName()); } @@ -69,15 +74,16 @@ public BoostedTreeParams(@Nullable Double lambda, if (featureBagFraction != null && (featureBagFraction <= 0 || featureBagFraction > 1.0)) { throw ExceptionsHelper.badRequestException("[{}] must be a double in (0, 1]", FEATURE_BAG_FRACTION.getPreferredName()); } + if (numTopFeatureImportanceValues != null && numTopFeatureImportanceValues < 0) { + throw ExceptionsHelper.badRequestException("[{}] must be a non-negative integer", + NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()); + } this.lambda = lambda; this.gamma = gamma; this.eta = eta; this.maximumNumberTrees = maximumNumberTrees; this.featureBagFraction = featureBagFraction; - } - - public BoostedTreeParams() { - this(null, null, null, null, null); + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; } BoostedTreeParams(StreamInput in) throws IOException { @@ -86,6 +92,11 @@ public BoostedTreeParams() { eta = in.readOptionalDouble(); maximumNumberTrees = in.readOptionalVInt(); featureBagFraction = in.readOptionalDouble(); + if (in.getVersion().onOrAfter(Version.V_7_6_0)) { + numTopFeatureImportanceValues = in.readOptionalInt(); + } else { + numTopFeatureImportanceValues = null; + } } @Override @@ -95,6 +106,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalDouble(eta); out.writeOptionalVInt(maximumNumberTrees); out.writeOptionalDouble(featureBagFraction); + if (out.getVersion().onOrAfter(Version.V_7_6_0)) { + out.writeOptionalInt(numTopFeatureImportanceValues); + } } @Override @@ -114,6 +128,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (featureBagFraction != null) { builder.field(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + builder.field(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } return builder; } @@ -134,6 +151,9 @@ Map getParams() { if (featureBagFraction != null) { params.put(FEATURE_BAG_FRACTION.getPreferredName(), featureBagFraction); } + if (numTopFeatureImportanceValues != null) { + params.put(NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), numTopFeatureImportanceValues); + } return params; } @@ -146,11 +166,62 @@ public boolean equals(Object o) { && Objects.equals(gamma, that.gamma) && Objects.equals(eta, that.eta) && Objects.equals(maximumNumberTrees, that.maximumNumberTrees) - && Objects.equals(featureBagFraction, that.featureBagFraction); + && Objects.equals(featureBagFraction, that.featureBagFraction) + && Objects.equals(numTopFeatureImportanceValues, that.numTopFeatureImportanceValues); } @Override public int hashCode() { - return Objects.hash(lambda, gamma, eta, maximumNumberTrees, featureBagFraction); + return Objects.hash(lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues); + } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private Double lambda; + private Double gamma; + private Double eta; + private Integer maximumNumberTrees; + private Double featureBagFraction; + private Integer numTopFeatureImportanceValues; + + private Builder() {} + + public Builder setLambda(Double lambda) { + this.lambda = lambda; + return this; + } + + public Builder setGamma(Double gamma) { + this.gamma = gamma; + return this; + } + + public Builder setEta(Double eta) { + this.eta = eta; + return this; + } + + public Builder setMaximumNumberTrees(Integer maximumNumberTrees) { + this.maximumNumberTrees = maximumNumberTrees; + return this; + } + + public Builder setFeatureBagFraction(Double featureBagFraction) { + this.featureBagFraction = featureBagFraction; + return this; + } + + public Builder setNumTopFeatureImportanceValues(Integer numTopFeatureImportanceValues) { + this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; + return this; + } + + public BoostedTreeParams build() { + return new BoostedTreeParams(lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues); + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 5a6cc664edf46..24b814d19ed05 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -50,11 +50,11 @@ private static ConstructingObjectParser createParser(boole lenient, a -> new Classification( (String) a[0], - new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]), - (String) a[6], - (Integer) a[7], - (Double) a[8], - (Long) a[9])); + new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]), + (String) a[7], + (Integer) a[8], + (Double) a[9], + (Long) a[10])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); @@ -114,7 +114,7 @@ public Classification(String dependentVariable, } public Classification(String dependentVariable) { - this(dependentVariable, new BoostedTreeParams(), null, null, null, null); + this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null); } public Classification(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index fe2927591312a..83174a9aebfe3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -47,10 +47,10 @@ private static ConstructingObjectParser createParser(boolean l lenient, a -> new Regression( (String) a[0], - new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5]), - (String) a[6], - (Double) a[7], - (Long) a[8])); + new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]), + (String) a[7], + (Double) a[8], + (Long) a[9])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); @@ -85,7 +85,7 @@ public Regression(String dependentVariable, } public Regression(String dependentVariable) { - this(dependentVariable, new BoostedTreeParams(), null, null, null); + this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null); } public Regression(StreamInput in) throws IOException { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java index 462943861094d..a90f0d919707f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java @@ -471,6 +471,9 @@ public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws I .startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName()) .field(TYPE, DOUBLE) .endObject() + .startObject(BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()) + .field(TYPE, INTEGER) + .endObject() .startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName()) .field(TYPE, KEYWORD) .endObject() @@ -499,6 +502,9 @@ public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws I .startObject(BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName()) .field(TYPE, DOUBLE) .endObject() + .startObject(BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName()) + .field(TYPE, INTEGER) + .endObject() .startObject(Classification.PREDICTION_FIELD_NAME.getPreferredName()) .field(TYPE, KEYWORD) .endObject() diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index d96f57d06811e..23075b2b9df23 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -323,6 +323,7 @@ public final class ReservedFieldNames { BoostedTreeParams.ETA.getPreferredName(), BoostedTreeParams.MAXIMUM_NUMBER_TREES.getPreferredName(), BoostedTreeParams.FEATURE_BAG_FRACTION.getPreferredName(), + BoostedTreeParams.NUM_TOP_FEATURE_IMPORTANCE_VALUES.getPreferredName(), ElasticsearchMappings.CONFIG_TYPE, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java index 145533df407cd..6f3aff88846d9 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/BoostedTreeParamsTests.java @@ -23,7 +23,7 @@ protected BoostedTreeParams doParseInstance(XContentParser parser) throws IOExce new ConstructingObjectParser<>( BoostedTreeParams.NAME, true, - a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4])); + a -> new BoostedTreeParams((Double) a[0], (Double) a[1], (Double) a[2], (Integer) a[3], (Double) a[4], (Integer) a[5])); BoostedTreeParams.declareFields(objParser); return objParser.apply(parser, null); } @@ -34,12 +34,14 @@ protected BoostedTreeParams createTestInstance() { } public static BoostedTreeParams createRandom() { - Double lambda = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true); - Double gamma = randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true); - Double eta = randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true); - Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000); - Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false); - return new BoostedTreeParams(lambda, gamma, eta, maximumNumberTrees, featureBagFraction); + return BoostedTreeParams.builder() + .setLambda(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true)) + .setGamma(randomBoolean() ? null : randomDoubleBetween(0.0, Double.MAX_VALUE, true)) + .setEta(randomBoolean() ? null : randomDoubleBetween(0.001, 1.0, true)) + .setMaximumNumberTrees(randomBoolean() ? null : randomIntBetween(1, 2000)) + .setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false)) + .setNumTopFeatureImportanceValues(randomBoolean() ? null : randomIntBetween(0, Integer.MAX_VALUE)) + .build(); } @Override @@ -49,57 +51,64 @@ protected Writeable.Reader instanceReader() { public void testConstructor_GivenNegativeLambda() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(-0.00001, 0.0, 0.5, 500, 0.3)); + () -> BoostedTreeParams.builder().setLambda(-0.00001).build()); assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double")); } public void testConstructor_GivenNegativeGamma() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, -0.00001, 0.5, 500, 0.3)); + () -> BoostedTreeParams.builder().setGamma(-0.00001).build()); assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double")); } public void testConstructor_GivenEtaIsZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.0, 500, 0.3)); + () -> BoostedTreeParams.builder().setEta(0.0).build()); assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); } public void testConstructor_GivenEtaIsGreaterThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 1.00001, 500, 0.3)); + () -> BoostedTreeParams.builder().setEta(1.00001).build()); assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]")); } public void testConstructor_GivenMaximumNumberTreesIsZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 0, 0.3)); + () -> BoostedTreeParams.builder().setMaximumNumberTrees(0).build()); assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]")); } public void testConstructor_GivenMaximumNumberTreesIsGreaterThan2k() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 2001, 0.3)); + () -> BoostedTreeParams.builder().setMaximumNumberTrees(2001).build()); assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]")); } public void testConstructor_GivenFeatureBagFractionIsLessThanZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, -0.00001)); + () -> BoostedTreeParams.builder().setFeatureBagFraction(-0.00001).build()); assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); } public void testConstructor_GivenFeatureBagFractionIsGreaterThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.00001)); + () -> BoostedTreeParams.builder().setFeatureBagFraction(1.00001).build()); assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]")); } + + public void testConstructor_GivenTopFeatureImportanceValuesIsNegative() { + ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, + () -> BoostedTreeParams.builder().setNumTopFeatureImportanceValues(-1).build()); + + assertThat(e.getMessage(), equalTo("[num_top_feature_importance_values] must be a non-negative integer")); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 7a0af05071b7f..55afb76ef5c1f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -37,7 +37,7 @@ public class ClassificationTests extends AbstractSerializingTestCase { - private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0); + private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build(); @Override protected Classification doParseInstance(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java index c123a0553d190..ab9e12650e88a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java @@ -31,7 +31,7 @@ public class RegressionTests extends AbstractSerializingTestCase { - private static final BoostedTreeParams BOOSTED_TREE_PARAMS = new BoostedTreeParams(0.0, 0.0, 0.5, 500, 1.0); + private static final BoostedTreeParams BOOSTED_TREE_PARAMS = BoostedTreeParams.builder().build(); @Override protected Regression doParseInstance(XContentParser parser) throws IOException { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index a0e0c0b4ccae2..078c174445467 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.ml.integration; import com.google.common.collect.Ordering; - import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.admin.indices.get.GetIndexAction; import org.elasticsearch.action.admin.indices.get.GetIndexRequest; @@ -28,7 +27,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.Accuracy; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification.MulticlassConfusionMatrix; @@ -83,7 +81,14 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws String predictedClassField = KEYWORD_FIELD + "_prediction"; indexData(sourceIndex, 300, 50, KEYWORD_FIELD); - DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD)); + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, + new Classification( + KEYWORD_FIELD, + BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(), + null, + null, + null, + null)); registerAnalytics(config); putAnalytics(config); @@ -101,6 +106,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws assertThat(getFieldValue(resultsObject, predictedClassField), is(in(KEYWORD_FIELD_VALUES))); assertThat(getFieldValue(resultsObject, "is_training"), is(destDoc.containsKey(KEYWORD_FIELD))); assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES); + assertThat(resultsObject.keySet().stream().filter(k -> k.startsWith("feature_importance.")).findAny().isPresent(), is(true)); } assertProgress(jobId, 100, 100, 100, 100); @@ -175,7 +181,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId, sourceIndex, destIndex, null, - new Classification(dependentVariable, BoostedTreeParamsTests.createRandom(), null, numTopClasses, 50.0, null)); + new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, numTopClasses, 50.0, null)); registerAnalytics(config); putAnalytics(config); @@ -354,7 +360,13 @@ public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exceptio String firstJobId = "classification_two_jobs_with_same_randomize_seed_1"; String firstJobDestIndex = firstJobId + "_dest"; - BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0); + BoostedTreeParams boostedTreeParams = BoostedTreeParams.builder() + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setFeatureBagFraction(1.0) + .setMaximumNumberTrees(1) + .build(); DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null)); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 5ecab6f69d429..8b7350d9e13e7 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -18,7 +18,6 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState; import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams; -import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.junit.After; @@ -53,7 +52,14 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws initialize("regression_single_numeric_feature_and_mixed_data_set"); indexData(sourceIndex, 300, 50); - DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD)); + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, + new Regression( + DEPENDENT_VARIABLE_FIELD, + BoostedTreeParams.builder().setNumTopFeatureImportanceValues(1).build(), + null, + null, + null) + ); registerAnalytics(config); putAnalytics(config); @@ -78,6 +84,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws assertThat(resultsObject.containsKey("variable_prediction"), is(true)); assertThat(resultsObject.containsKey("is_training"), is(true)); assertThat(resultsObject.get("is_training"), is(destDoc.containsKey(DEPENDENT_VARIABLE_FIELD))); + assertThat(resultsObject.containsKey("feature_importance." + NUMERICAL_FEATURE_FIELD), is(true)); } assertProgress(jobId, 100, 100, 100, 100); @@ -141,7 +148,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception sourceIndex, destIndex, null, - new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParamsTests.createRandom(), null, 50.0, null)); + new Regression(DEPENDENT_VARIABLE_FIELD, BoostedTreeParams.builder().build(), null, 50.0, null)); registerAnalytics(config); putAnalytics(config); @@ -244,7 +251,13 @@ public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exceptio String firstJobId = "regression_two_jobs_with_same_randomize_seed_1"; String firstJobDestIndex = firstJobId + "_dest"; - BoostedTreeParams boostedTreeParams = new BoostedTreeParams(1.0, 1.0, 1.0, 1, 1.0); + BoostedTreeParams boostedTreeParams = BoostedTreeParams.builder() + .setLambda(1.0) + .setGamma(1.0) + .setEta(1.0) + .setFeatureBagFraction(1.0) + .setMaximumNumberTrees(1) + .build(); DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, new Regression(DEPENDENT_VARIABLE_FIELD, boostedTreeParams, null, 50.0, null));