From 0eee38b0c6cb9dec84a63ddf875cd89d073853c6 Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Tue, 25 Feb 2020 12:29:51 +0000 Subject: [PATCH 01/11] Add a new classification parameter: class assignment objective --- .../client/ml/dataframe/Classification.java | 54 ++++++++++++++++-- .../client/MachineLearningIT.java | 2 + .../MlClientDocumentationIT.java | 3 +- .../ml/dataframe/ClassificationTests.java | 1 + .../ml/put-data-frame-analytics.asciidoc | 3 +- .../apis/put-dfanalytics.asciidoc | 4 ++ docs/reference/ml/ml-shared.asciidoc | 8 +++ .../ml/dataframe/analyses/Classification.java | 55 +++++++++++++++++-- .../ml/job/results/ReservedFieldNames.java | 1 + .../xpack/core/ml/config_index_mappings.json | 3 + .../DataFrameAnalyticsConfigTests.java | 2 + .../analyses/ClassificationTests.java | 51 +++++++++++------ .../ml/integration/ClassificationIT.java | 1 + .../test/ml/data_frame_analytics_crud.yml | 2 + 14 files changed, 161 insertions(+), 29 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java index 02861adc73845..538cd43edf273 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java @@ -22,10 +22,12 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; +import java.util.Locale; import java.util.Objects; public class Classification implements DataFrameAnalysis { @@ -49,6 +51,7 @@ public static Builder builder(String dependentVariable) { static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); + static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective"); static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); @@ -66,8 +69,9 @@ public static Builder builder(String dependentVariable) { (Integer) a[6], (String) a[7], (Double) a[8], - (Integer) a[9], - (Long) a[10])); + (ClassAssignmentObjective) a[9], + (Integer) a[10], + (Long) a[11])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -79,6 +83,12 @@ public static Builder builder(String dependentVariable) { PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); + PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return ClassAssignmentObjective.fromString(p.text()); + } + throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); + }, CLASS_ASSIGNMENT_OBJECTIVE, ObjectParser.ValueType.STRING); PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES); PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED); } @@ -92,13 +102,15 @@ public static Builder builder(String dependentVariable) { private final Integer numTopFeatureImportanceValues; private final String predictionFieldName; private final Double trainingPercent; + private final ClassAssignmentObjective classAssignmentObjective; private final Integer numTopClasses; private final Long randomizeSeed; private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName, - @Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) { + @Nullable Double trainingPercent, @Nullable ClassAssignmentObjective classAssignmentObjective, + @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; this.gamma = gamma; @@ -108,6 +120,7 @@ private Classification(String dependentVariable, @Nullable Double lambda, @Nulla this.numTopFeatureImportanceValues = numTopFeatureImportanceValues; this.predictionFieldName = predictionFieldName; this.trainingPercent = trainingPercent; + this.classAssignmentObjective = classAssignmentObjective; this.numTopClasses = numTopClasses; this.randomizeSeed = randomizeSeed; } @@ -157,6 +170,10 @@ public Long getRandomizeSeed() { return randomizeSeed; } + public ClassAssignmentObjective getClassAssignmentObjective() { + return classAssignmentObjective; + } + public Integer getNumTopClasses() { return numTopClasses; } @@ -192,6 +209,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (randomizeSeed != null) { builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed); } + if (classAssignmentObjective != null) { + builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective); + } if (numTopClasses != null) { builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); } @@ -201,8 +221,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public int hashCode() { - return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, numTopFeatureImportanceValues, - predictionFieldName, trainingPercent, randomizeSeed, numTopClasses); + return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, + numTopFeatureImportanceValues, predictionFieldName, trainingPercent, randomizeSeed, classAssignmentObjective, + numTopClasses); } @Override @@ -220,6 +241,7 @@ public boolean equals(Object o) { && Objects.equals(predictionFieldName, that.predictionFieldName) && Objects.equals(trainingPercent, that.trainingPercent) && Objects.equals(randomizeSeed, that.randomizeSeed) + && Objects.equals(classAssignmentObjective, that.classAssignmentObjective) && Objects.equals(numTopClasses, that.numTopClasses); } @@ -228,6 +250,19 @@ public String toString() { return Strings.toString(this); } + public enum ClassAssignmentObjective { + MAXIMIZE_ACCURACY, MAXIMIZE_MINIMUM_RECALL; + + public static ClassAssignmentObjective fromString(String value) { + return ClassAssignmentObjective.valueOf(value.toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + } + public static class Builder { private String dependentVariable; private Double lambda; @@ -238,6 +273,7 @@ public static class Builder { private Integer numTopFeatureImportanceValues; private String predictionFieldName; private Double trainingPercent; + private ClassAssignmentObjective classAssignmentObjective; private Integer numTopClasses; private Long randomizeSeed; @@ -290,6 +326,11 @@ public Builder setRandomizeSeed(Long randomizeSeed) { return this; } + public Builder setClassAssignmentObjective(ClassAssignmentObjective classAssignmentObjective) { + this.classAssignmentObjective = classAssignmentObjective; + return this; + } + public Builder setNumTopClasses(Integer numTopClasses) { this.numTopClasses = numTopClasses; return this; @@ -297,7 +338,8 @@ public Builder setNumTopClasses(Integer numTopClasses) { public Classification build() { return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, - numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed); + numTopFeatureImportanceValues, predictionFieldName, trainingPercent, classAssignmentObjective, + numTopClasses, randomizeSeed); } } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 6fe08f8a507de..01badb24ca7dc 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1336,6 +1336,7 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti .setPredictionFieldName("my_dependent_variable_prediction") .setTrainingPercent(80.0) .setRandomizeSeed(42L) + .classAssignmentObjective("maximize_accuracy") .setNumTopClasses(1) .setLambda(1.0) .setGamma(1.0) @@ -1362,6 +1363,7 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti assertThat(createdConfig.getAnalyzedFields(), equalTo(config.getAnalyzedFields())); assertThat(createdConfig.getModelMemoryLimit(), equalTo(ByteSizeValue.parseBytesSizeValue("1gb", ""))); // default value assertThat(createdConfig.getDescription(), equalTo("this is a classification")); + assertThat(createdConfig.getClassAssignmentObjective(), equalTo(config.getClassificationObjective())); } public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index d1c3a5e657e0f..9c733c54a0d29 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -2979,7 +2979,8 @@ public void testPutDataFrameAnalytics() throws Exception { .setPredictionFieldName("my_prediction_field_name") // <8> .setTrainingPercent(50.0) // <9> .setRandomizeSeed(1234L) // <10> - .setNumTopClasses(1) // <11> + .setClassAssignmentObjective("maximize_accuracy") // <11> + .setNumTopClasses(1) // <12> .build(); // end::put-data-frame-analytics-classification diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java index 79d78c888880f..b70e8e86ca7d0 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java @@ -36,6 +36,7 @@ public static Classification randomClassification() { .setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10)) .setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true)) .setRandomizeSeed(randomBoolean() ? null : randomLong()) + .setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values())) .setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10)) .build(); } diff --git a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc index 4be2011340210..cf88d65ae9314 100644 --- a/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc +++ b/docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc @@ -121,7 +121,8 @@ include-tagged::{doc-tests-file}[{api}-classification] <8> The name of the prediction field in the results object. <9> The percentage of training-eligible rows to be used in training. Defaults to 100%. <10> The seed to be used by the random generator that picks which rows are used in training. -<11> The number of top classes to be reported in the results. Defaults to 2. +<11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall. +<12> The number of top classes to be reported in the results. Defaults to 2. ===== Regression diff --git a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc index 9c1b41c9b57ad..bc6b7ee32332b 100644 --- a/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc +++ b/docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc @@ -136,6 +136,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=gamma] (Optional, double) include::{docdir}/ml/ml-shared.asciidoc[tag=lambda] +`analysis`.`classification`.`class_assignment_objective`:::: +(Optional, string) +include::{docdir}/ml/ml-shared.asciidoc[tag=class-assignment-objective] + `analysis`.`classification`.`num_top_classes`:::: (Optional, integer) include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-classes] diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index f4f65fcd56fba..7638d1f6ad1f8 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -899,6 +899,14 @@ improve diversity in the ensemble. Therefore, only override this if you are confident that the value you choose is appropriate for the data set. end::n-neighbors[] +tag::class-assignment-objective[] +Defines the objective to optimize when assigning class labels. Available +objectives are maximize_accuracy and maximize_minimum_recall. When maximizing +accuracy class labels are choosen to maximize the number of correct predictions. +When maximizing minimum recall labels are choosen to maximize the minimum recall +for any class. Defaults to maximize_minimum_recall. +end::class-assignment-objective[] + tag::num-top-classes[] Defines the number of categories for which the predicted probabilities are reported. It must be non-negative. If it is greater than the diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index d9289a5bcf1ac..013489f2350bd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.index.mapper.FieldAliasMapper; @@ -21,6 +22,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -37,6 +39,7 @@ public class Classification implements DataFrameAnalysis { public static final ParseField DEPENDENT_VARIABLE = new ParseField("dependent_variable"); public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name"); + public static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective"); public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); public static final ParseField TRAINING_PERCENT = new ParseField("training_percent"); public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed"); @@ -54,12 +57,19 @@ private static ConstructingObjectParser createParser(boole (String) a[0], new BoostedTreeParams((Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (Integer) a[6]), (String) a[7], - (Integer) a[8], - (Double) a[9], - (Long) a[10])); + (ClassAssignmentObjective) a[8], + (Integer) a[9], + (Double) a[10], + (Long) a[11])); parser.declareString(constructorArg(), DEPENDENT_VARIABLE); BoostedTreeParams.declareFields(parser); parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME); + parser.declareField(optionalConstructorArg(), p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return ClassAssignmentObjective.fromString(p.text()); + } + throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); + }, CLASS_ASSIGNMENT_OBJECTIVE, ObjectParser.ValueType.STRING); parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES); parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT); parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED); @@ -89,6 +99,7 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU private final String dependentVariable; private final BoostedTreeParams boostedTreeParams; private final String predictionFieldName; + private final ClassAssignmentObjective classAssignmentObjective; private final int numTopClasses; private final double trainingPercent; private final long randomizeSeed; @@ -96,6 +107,7 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU public Classification(String dependentVariable, BoostedTreeParams boostedTreeParams, @Nullable String predictionFieldName, + @Nullable ClassAssignmentObjective classAssignmentObjective, @Nullable Integer numTopClasses, @Nullable Double trainingPercent, @Nullable Long randomizeSeed) { @@ -108,19 +120,26 @@ public Classification(String dependentVariable, this.dependentVariable = ExceptionsHelper.requireNonNull(dependentVariable, DEPENDENT_VARIABLE); this.boostedTreeParams = ExceptionsHelper.requireNonNull(boostedTreeParams, BoostedTreeParams.NAME); this.predictionFieldName = predictionFieldName == null ? dependentVariable + "_prediction" : predictionFieldName; + this.classAssignmentObjective = classAssignmentObjective == null ? + ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL : classAssignmentObjective; this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses; this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent; this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed; } public Classification(String dependentVariable) { - this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null); + this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null); } public Classification(StreamInput in) throws IOException { dependentVariable = in.readString(); boostedTreeParams = new BoostedTreeParams(in); predictionFieldName = in.readOptionalString(); + if (in.getVersion().onOrAfter(Version.V_7_7_0)) { + classAssignmentObjective = in.readEnum(ClassAssignmentObjective.class); + } else { + classAssignmentObjective = ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL; + } numTopClasses = in.readOptionalVInt(); trainingPercent = in.readDouble(); if (in.getVersion().onOrAfter(Version.V_7_6_0)) { @@ -142,6 +161,10 @@ public String getPredictionFieldName() { return predictionFieldName; } + public ClassAssignmentObjective getClassAssignmentObjective() { + return classAssignmentObjective; + } + public int getNumTopClasses() { return numTopClasses; } @@ -164,6 +187,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(dependentVariable); boostedTreeParams.writeTo(out); out.writeOptionalString(predictionFieldName); + if (out.getVersion().onOrAfter(Version.V_7_7_0)) { + out.writeEnum(classAssignmentObjective); + } out.writeOptionalVInt(numTopClasses); out.writeDouble(trainingPercent); if (out.getVersion().onOrAfter(Version.V_7_6_0)) { @@ -178,6 +204,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); boostedTreeParams.toXContent(builder, params); + if (version.onOrAfter(Version.V_7_7_0)) { + builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective); + } builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); if (predictionFieldName != null) { builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); @@ -195,6 +224,7 @@ public Map getParams(Map> extractedFields) { Map params = new HashMap<>(); params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); params.putAll(boostedTreeParams.getParams()); + params.put(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective); params.put(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); if (predictionFieldName != null) { params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); @@ -303,6 +333,7 @@ public boolean equals(Object o) { return Objects.equals(dependentVariable, that.dependentVariable) && Objects.equals(boostedTreeParams, that.boostedTreeParams) && Objects.equals(predictionFieldName, that.predictionFieldName) + && Objects.equals(classAssignmentObjective, that.classAssignmentObjective) && Objects.equals(numTopClasses, that.numTopClasses) && trainingPercent == that.trainingPercent && randomizeSeed == that.randomizeSeed; @@ -310,6 +341,20 @@ public boolean equals(Object o) { @Override public int hashCode() { - return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent, randomizeSeed); + return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective, + numTopClasses, trainingPercent, randomizeSeed); + } + + public enum ClassAssignmentObjective { + MAXIMIZE_ACCURACY, MAXIMIZE_MINIMUM_RECALL; + + public static ClassAssignmentObjective fromString(String value) { + return ClassAssignmentObjective.valueOf(value.toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java index e1be33e5ff8a2..9ba410c8d14e0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java @@ -316,6 +316,7 @@ public final class ReservedFieldNames { Classification.NAME.getPreferredName(), Classification.DEPENDENT_VARIABLE.getPreferredName(), Classification.PREDICTION_FIELD_NAME.getPreferredName(), + Classification.CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), Classification.NUM_TOP_CLASSES.getPreferredName(), Classification.TRAINING_PERCENT.getPreferredName(), BoostedTreeParams.LAMBDA.getPreferredName(), diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json index 7c10d342bea15..dbf4b68063362 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/config_index_mappings.json @@ -43,6 +43,9 @@ "maximum_number_trees" : { "type" : "integer" }, + "class_assignment_objective" : { + "type" : "keyword" + }, "num_top_classes" : { "type" : "integer" }, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java index 7dd0f6f9b0ca9..1c733e04504c8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfigTests.java @@ -155,12 +155,14 @@ protected void assertOnBWCObject(DataFrameAnalyticsConfig bwcSerializedObject, D bwcAnalysis = new Classification(bwcClassification.getDependentVariable(), bwcClassification.getBoostedTreeParams(), bwcClassification.getPredictionFieldName(), + bwcClassification.getClassAssignmentObjective(), bwcClassification.getNumTopClasses(), bwcClassification.getTrainingPercent(), 42L); testAnalysis = new Classification(testClassification.getDependentVariable(), testClassification.getBoostedTreeParams(), testClassification.getPredictionFieldName(), + testClassification.getClassAssignmentObjective(), testClassification.getNumTopClasses(), testClassification.getTrainingPercent(), 42L); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 0898a3fb476e8..87801f5c4458b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -54,17 +54,20 @@ public static Classification createRandom() { String dependentVariableName = randomAlphaOfLength(10); BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom(); String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10); + Classification.ClassAssignmentObjective classAssignmentObjective = randomBoolean() ? + null : randomFrom(Classification.ClassAssignmentObjective.values()); Integer numTopClasses = randomBoolean() ? null : randomIntBetween(0, 1000); Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true); Long randomizeSeed = randomBoolean() ? null : randomLong(); - return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, numTopClasses, trainingPercent, - randomizeSeed); + return new Classification(dependentVariableName, boostedTreeParams, predictionFieldName, classAssignmentObjective, + numTopClasses, trainingPercent, randomizeSeed); } public static Classification mutateForVersion(Classification instance, Version version) { return new Classification(instance.getDependentVariable(), BoostedTreeParamsTests.mutateForVersion(instance.getBoostedTreeParams(), version), instance.getPredictionFieldName(), + version.onOrAfter(Version.V_7_7_0) ? instance.getClassAssignmentObjective() : null, instance.getNumTopClasses(), instance.getTrainingPercent(), instance.getRandomizeSeed()); @@ -80,12 +83,14 @@ protected void assertOnBWCObject(Classification bwcSerializedObject, Classificat Classification newBwc = new Classification(bwcSerializedObject.getDependentVariable(), bwcSerializedObject.getBoostedTreeParams(), bwcSerializedObject.getPredictionFieldName(), + bwcSerializedObject.getClassAssignmentObjective(), bwcSerializedObject.getNumTopClasses(), bwcSerializedObject.getTrainingPercent(), 42L); Classification newInstance = new Classification(testInstance.getDependentVariable(), testInstance.getBoostedTreeParams(), testInstance.getPredictionFieldName(), + testInstance.getClassAssignmentObjective(), testInstance.getNumTopClasses(), testInstance.getTrainingPercent(), 42L); @@ -99,71 +104,85 @@ protected Writeable.Reader instanceReader() { public void testConstructor_GivenTrainingPercentIsLessThanOne() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 0.999, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 0.999, randomLong())); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenTrainingPercentIsGreaterThan100() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0001, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0001, randomLong())); assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]")); } public void testConstructor_GivenNumTopClassesIsLessThanZero() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", -1, 1.0, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, -1, 1.0, randomLong())); assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); } public void testConstructor_GivenNumTopClassesIsGreaterThan1000() { ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, - () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1001, 1.0, randomLong())); + () -> new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1001, 1.0, randomLong())); assertThat(e.getMessage(), equalTo("[num_top_classes] must be an integer in [0, 1000]")); } public void testGetPredictionFieldName() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong()); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong()); assertThat(classification.getPredictionFieldName(), equalTo("result")); - classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, 3, 50.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, null, null, 3, 50.0, randomLong()); assertThat(classification.getPredictionFieldName(), equalTo("foo_prediction")); } + public void testClassAssignmentObjective() { + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", + Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY, 7, 1.0, randomLong()); + assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY)); + + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", + Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, 7, 1.0, randomLong()); + assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL)); + + // class_assignment_objective == null, default applied + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong()); + assertThat(classification.getClassAssignmentObjective(), equalTo(Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL)); + } + public void testGetNumTopClasses() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 7, 1.0, randomLong()); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 7, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), equalTo(7)); // Boundary condition: num_top_classes == 0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 0, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 0, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), equalTo(0)); // Boundary condition: num_top_classes == 1000 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 1000, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1000, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), equalTo(1000)); // num_top_classes == null, default applied - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, null, 1.0, randomLong()); assertThat(classification.getNumTopClasses(), equalTo(2)); } public void testGetTrainingPercent() { - Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 50.0, randomLong()); + Classification classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 50.0, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(50.0)); // Boundary condition: training_percent == 1.0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 1.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 1.0, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(1.0)); // Boundary condition: training_percent == 100.0 - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, 100.0, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, 100.0, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(100.0)); // training_percent == null, default applied - classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", 3, null, randomLong()); + classification = new Classification("foo", BOOSTED_TREE_PARAMS, "result", null, 3, null, randomLong()); assertThat(classification.getTrainingPercent(), equalTo(100.0)); } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index c3242e535a6d7..f73b2c7ab457a 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -87,6 +87,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws null, null, null, + null, null)); registerAnalytics(config); putAnalytics(config); diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index afd6701af145c..88244e0381906 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -1834,6 +1834,7 @@ setup: "eta": 0.5, "maximum_number_trees": 400, "feature_bag_fraction": 0.3, + "class_assignment_objective": "maximize_accuracy", "training_percent": 60.3, "randomize_seed": 24 } @@ -1853,6 +1854,7 @@ setup: "prediction_field_name": "foo_prediction", "training_percent": 60.3, "randomize_seed": 24, + "class_assignment_objective": "maximize_accuracy", "num_top_classes": 2 } }} From edfebc27dfa95dca28110351d7e25bcaa0add610 Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Tue, 25 Feb 2020 13:49:11 +0000 Subject: [PATCH 02/11] Assorted fixes --- .../java/org/elasticsearch/client/MachineLearningIT.java | 4 ++-- .../client/documentation/MlClientDocumentationIT.java | 5 +++-- .../core/ml/dataframe/analyses/ClassificationTests.java | 3 +++ .../xpack/ml/integration/ClassificationIT.java | 6 +++--- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java index 01badb24ca7dc..b3fddc1039d46 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java @@ -1336,7 +1336,8 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti .setPredictionFieldName("my_dependent_variable_prediction") .setTrainingPercent(80.0) .setRandomizeSeed(42L) - .classAssignmentObjective("maximize_accuracy") + .setClassAssignmentObjective( + org.elasticsearch.client.ml.dataframe.Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY) .setNumTopClasses(1) .setLambda(1.0) .setGamma(1.0) @@ -1363,7 +1364,6 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti assertThat(createdConfig.getAnalyzedFields(), equalTo(config.getAnalyzedFields())); assertThat(createdConfig.getModelMemoryLimit(), equalTo(ByteSizeValue.parseBytesSizeValue("1gb", ""))); // default value assertThat(createdConfig.getDescription(), equalTo("this is a classification")); - assertThat(createdConfig.getClassAssignmentObjective(), equalTo(config.getClassificationObjective())); } public void testGetDataFrameAnalyticsConfig_SingleConfig() throws Exception { diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java index 9c733c54a0d29..89467b741d8d1 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java @@ -139,6 +139,7 @@ import org.elasticsearch.client.ml.datafeed.DatafeedStats; import org.elasticsearch.client.ml.datafeed.DatafeedUpdate; import org.elasticsearch.client.ml.datafeed.DelayedDataCheckConfig; +import org.elasticsearch.client.ml.dataframe.Classification; import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis; import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig; import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDest; @@ -2969,7 +2970,7 @@ public void testPutDataFrameAnalytics() throws Exception { // end::put-data-frame-analytics-outlier-detection-customized // tag::put-data-frame-analytics-classification - DataFrameAnalysis classification = org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") // <1> + DataFrameAnalysis classification = Classification.builder("my_dependent_variable") // <1> .setLambda(1.0) // <2> .setGamma(5.5) // <3> .setEta(5.5) // <4> @@ -2979,7 +2980,7 @@ public void testPutDataFrameAnalytics() throws Exception { .setPredictionFieldName("my_prediction_field_name") // <8> .setTrainingPercent(50.0) // <9> .setRandomizeSeed(1234L) // <10> - .setClassAssignmentObjective("maximize_accuracy") // <11> + .setClassAssignmentObjective(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY) // <11> .setNumTopClasses(1) // <12> .build(); // end::put-data-frame-analytics-classification diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 87801f5c4458b..67dd3a3e03eee 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -197,6 +197,7 @@ public void testGetParams() { equalTo( Map.of( "dependent_variable", "foo", + "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, "num_top_classes", 2, "prediction_field_name", "foo_prediction", "prediction_field_type", "bool"))); @@ -205,6 +206,7 @@ public void testGetParams() { equalTo( Map.of( "dependent_variable", "bar", + "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, "num_top_classes", 2, "prediction_field_name", "bar_prediction", "prediction_field_type", "int"))); @@ -213,6 +215,7 @@ public void testGetParams() { equalTo( Map.of( "dependent_variable", "baz", + "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, "num_top_classes", 2, "prediction_field_name", "baz_prediction", "prediction_field_type", "string"))); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index f73b2c7ab457a..6d8ab841c182b 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -181,7 +181,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId, sourceIndex, destIndex, null, - new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, numTopClasses, 50.0, null)); + new Classification(dependentVariable, BoostedTreeParams.builder().build(), null, null, numTopClasses, 50.0, null)); registerAnalytics(config); putAnalytics(config); @@ -426,7 +426,7 @@ public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exceptio .build(); DataFrameAnalyticsConfig firstJob = buildAnalytics(firstJobId, sourceIndex, firstJobDestIndex, null, - new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, null)); + new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, null)); registerAnalytics(firstJob); putAnalytics(firstJob); @@ -435,7 +435,7 @@ public void testTwoJobsWithSameRandomizeSeedUseSameTrainingSet() throws Exceptio long randomizeSeed = ((Classification) firstJob.getAnalysis()).getRandomizeSeed(); DataFrameAnalyticsConfig secondJob = buildAnalytics(secondJobId, sourceIndex, secondJobDestIndex, null, - new Classification(dependentVariable, boostedTreeParams, null, 1, 50.0, randomizeSeed)); + new Classification(dependentVariable, boostedTreeParams, null, null, 1, 50.0, randomizeSeed)); registerAnalytics(secondJob); putAnalytics(secondJob); From beaf0604114c81529211a72de5db637510661438 Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Wed, 4 Mar 2020 14:21:50 +0000 Subject: [PATCH 03/11] Typo --- docs/reference/ml/ml-shared.asciidoc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index ecdbd922768ca..d002fe2c60be8 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -286,8 +286,8 @@ end::chunking-config[] tag::class-assignment-objective[] Defines the objective to optimize when assigning class labels. Available objectives are maximize_accuracy and maximize_minimum_recall. When maximizing -accuracy class labels are choosen to maximize the number of correct predictions. -When maximizing minimum recall labels are choosen to maximize the minimum recall +accuracy class labels are chosen to maximize the number of correct predictions. +When maximizing minimum recall labels are chosen to maximize the minimum recall for any class. Defaults to maximize_minimum_recall. end::class-assignment-objective[] From d635bde1762e1cbff76d1a95dd168b26807f10b0 Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Wed, 4 Mar 2020 14:22:06 +0000 Subject: [PATCH 04/11] Backwards compatibility tests --- .../90_ml_data_frame_analytics_crud.yml | 25 ++++++++++++++++++ .../90_ml_data_frame_analytics_crud.yml | 23 ++++++++++++++++ .../90_ml_data_frame_analytics_crud.yml | 26 +++++++++++++++++++ 3 files changed, 74 insertions(+) diff --git a/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/mixed_cluster/90_ml_data_frame_analytics_crud.yml b/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/mixed_cluster/90_ml_data_frame_analytics_crud.yml index af304afc57db6..5049f3360a42e 100644 --- a/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/mixed_cluster/90_ml_data_frame_analytics_crud.yml +++ b/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/mixed_cluster/90_ml_data_frame_analytics_crud.yml @@ -92,6 +92,31 @@ - match: { data_frame_analytics.0.id: "old_cluster_regression_job" } - match: { data_frame_analytics.0.state: "stopped" } +--- +"Get old classification job": + + - do: + ml.get_data_frame_analytics: + id: "old_cluster_classification_job" + - match: { count: 1 } + - match: { data_frame_analytics.0.id: "old_cluster_classification_job" } + - match: { data_frame_analytics.0.source.index: ["bwc_ml_classification_job_source"] } + - match: { data_frame_analytics.0.source.query: {"term": { "user.keyword": "Kimchy" }} } + - match: { data_frame_analytics.0.dest.index: "old_cluster_classification_job_results" } + - match: { data_frame_analytics.0.analysis.classification.dependent_variable: "foo" } + - match: { data_frame_analytics.0.analysis.classification.training_percent: 100.0 } + - is_true: data_frame_analytics.0.analysis.classification.randomize_seed + +--- +"Get old classification job stats": + + - do: + ml.get_data_frame_analytics_stats: + id: "old_cluster_classification_job" + - match: { count: 1 } + - match: { data_frame_analytics.0.id: "old_cluster_classification_job" } + - match: { data_frame_analytics.0.state: "stopped" } + --- "Put an outlier_detection job on the mixed cluster": diff --git a/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/old_cluster/90_ml_data_frame_analytics_crud.yml b/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/old_cluster/90_ml_data_frame_analytics_crud.yml index fe160bba15f23..13b1382fe3e07 100644 --- a/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/old_cluster/90_ml_data_frame_analytics_crud.yml +++ b/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/old_cluster/90_ml_data_frame_analytics_crud.yml @@ -64,3 +64,26 @@ setup: } } - match: { id: "old_cluster_regression_job" } + +--- +"Put classification job on the old cluster": + + - do: + ml.put_data_frame_analytics: + id: "old_cluster_classification_job" + body: > + { + "source": { + "index": "bwc_ml_classification_job_source", + "query": {"term" : { "user.keyword" : "Kimchy" }} + }, + "dest": { + "index": "old_cluster_classification_job_results" + }, + "analysis": { + "classification":{ + "dependent_variable": "foo" + } + } + } + - match: { id: "old_cluster_classification_job" } diff --git a/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/upgraded_cluster/90_ml_data_frame_analytics_crud.yml b/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/upgraded_cluster/90_ml_data_frame_analytics_crud.yml index 14438883f0da1..b197d0c5cf433 100644 --- a/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/upgraded_cluster/90_ml_data_frame_analytics_crud.yml +++ b/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/upgraded_cluster/90_ml_data_frame_analytics_crud.yml @@ -52,6 +52,32 @@ - match: { data_frame_analytics.0.id: "old_cluster_regression_job" } - match: { data_frame_analytics.0.state: "stopped" } +--- +"Get old classification job": + + - do: + ml.get_data_frame_analytics: + id: "old_cluster_classification_job" + - match: { count: 1 } + - match: { data_frame_analytics.0.id: "old_cluster_classification_job" } + - match: { data_frame_analytics.0.source.index: ["bwc_ml_classification_job_source"] } + - match: { data_frame_analytics.0.source.query: {"term": { "user.keyword": "Kimchy" }} } + - match: { data_frame_analytics.0.dest.index: "old_cluster_classification_job_results" } + - match: { data_frame_analytics.0.analysis.classification.dependent_variable: "foo" } + - match: { data_frame_analytics.0.analysis.classification.training_percent: 100.0 } + - match: { data_frame_analytics.0.analysis.classification.class_assignment_objective: "maximize_minimum_recall" } + - is_true: data_frame_analytics.0.analysis.classification.randomize_seed + +--- +"Get old classification job stats": + + - do: + ml.get_data_frame_analytics_stats: + id: "old_cluster_classification_job" + - match: { count: 1 } + - match: { data_frame_analytics.0.id: "old_cluster_classification_job" } + - match: { data_frame_analytics.0.state: "stopped" } + --- "Get mixed cluster outlier_detection job": - skip: From 854fdf734bb5841364b098c79fb9a44e1b26e2eb Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Wed, 4 Mar 2020 14:25:21 +0000 Subject: [PATCH 05/11] Quotes --- docs/reference/ml/ml-shared.asciidoc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index 38b3ddf53fffc..f05add6a569ea 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -340,7 +340,7 @@ end::chunking-config[] tag::class-assignment-objective[] Defines the objective to optimize when assigning class labels. Available -objectives are maximize_accuracy and maximize_minimum_recall. When maximizing +objectives are `maximize_accuracy` and `maximize_minimum_recall`. When maximizing accuracy class labels are chosen to maximize the number of correct predictions. When maximizing minimum recall labels are chosen to maximize the minimum recall for any class. Defaults to maximize_minimum_recall. From 4e0ded77b33c9699553bc1a6e5df1216433fc3a0 Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Wed, 4 Mar 2020 15:47:04 +0000 Subject: [PATCH 06/11] Move new parameter to end of constructor --- .../client/ml/dataframe/Classification.java | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java index 538cd43edf273..4e549746f9025 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java @@ -69,9 +69,9 @@ public static Builder builder(String dependentVariable) { (Integer) a[6], (String) a[7], (Double) a[8], - (ClassAssignmentObjective) a[9], - (Integer) a[10], - (Long) a[11])); + (Integer) a[9], + (Long) a[10], + (ClassAssignmentObjective) a[11])); static { PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE); @@ -109,8 +109,8 @@ public static Builder builder(String dependentVariable) { private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta, @Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName, - @Nullable Double trainingPercent, @Nullable ClassAssignmentObjective classAssignmentObjective, - @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) { + @Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed, + @Nullable ClassAssignmentObjective classAssignmentObjective) { this.dependentVariable = Objects.requireNonNull(dependentVariable); this.lambda = lambda; this.gamma = gamma; @@ -338,8 +338,8 @@ public Builder setNumTopClasses(Integer numTopClasses) { public Classification build() { return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, - numTopFeatureImportanceValues, predictionFieldName, trainingPercent, classAssignmentObjective, - numTopClasses, randomizeSeed); + numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed, + classAssignmentObjective); } } } From a1b5e2b90b30d20cebbf9febe8a78be8f73e9404 Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Wed, 4 Mar 2020 15:52:06 +0000 Subject: [PATCH 07/11] Only in version 8.0 at present --- .../xpack/core/ml/dataframe/analyses/Classification.java | 6 +++--- .../core/ml/dataframe/analyses/ClassificationTests.java | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 013489f2350bd..e0a4fbae860b1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -135,7 +135,7 @@ public Classification(StreamInput in) throws IOException { dependentVariable = in.readString(); boostedTreeParams = new BoostedTreeParams(in); predictionFieldName = in.readOptionalString(); - if (in.getVersion().onOrAfter(Version.V_7_7_0)) { + if (in.getVersion().onOrAfter(Version.V_8_0_0)) { classAssignmentObjective = in.readEnum(ClassAssignmentObjective.class); } else { classAssignmentObjective = ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL; @@ -187,7 +187,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(dependentVariable); boostedTreeParams.writeTo(out); out.writeOptionalString(predictionFieldName); - if (out.getVersion().onOrAfter(Version.V_7_7_0)) { + if (out.getVersion().onOrAfter(Version.V_8_0_0)) { out.writeEnum(classAssignmentObjective); } out.writeOptionalVInt(numTopClasses); @@ -204,7 +204,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); boostedTreeParams.toXContent(builder, params); - if (version.onOrAfter(Version.V_7_7_0)) { + if (version.onOrAfter(Version.V_8_0_0)) { builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective); } builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index 07049019a08b2..e7b2dbbc09f95 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -67,7 +67,7 @@ public static Classification mutateForVersion(Classification instance, Version v return new Classification(instance.getDependentVariable(), BoostedTreeParamsTests.mutateForVersion(instance.getBoostedTreeParams(), version), instance.getPredictionFieldName(), - version.onOrAfter(Version.V_7_7_0) ? instance.getClassAssignmentObjective() : null, + version.onOrAfter(Version.V_8_0_0) ? instance.getClassAssignmentObjective() : null, instance.getNumTopClasses(), instance.getTrainingPercent(), instance.getRandomizeSeed()); From ed56c5058993cffd893f0adc3ed512649eac62e1 Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Thu, 5 Mar 2020 11:00:03 +0000 Subject: [PATCH 08/11] Correct order in which fields are added to the parser --- .../org/elasticsearch/client/ml/dataframe/Classification.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java index 4e549746f9025..c5ca64fd73ba3 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java @@ -83,14 +83,14 @@ public static Builder builder(String dependentVariable) { PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_FEATURE_IMPORTANCE_VALUES); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME); PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT); + PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED); PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> { if (p.currentToken() == XContentParser.Token.VALUE_STRING) { return ClassAssignmentObjective.fromString(p.text()); } throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]"); }, CLASS_ASSIGNMENT_OBJECTIVE, ObjectParser.ValueType.STRING); - PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES); - PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED); } private final String dependentVariable; From 08b2614ab1de436766f7d7717a4b43d68d799e62 Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Thu, 5 Mar 2020 11:09:48 +0000 Subject: [PATCH 09/11] Missing setup in bwc tests --- .../old_cluster/90_ml_data_frame_analytics_crud.yml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/old_cluster/90_ml_data_frame_analytics_crud.yml b/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/old_cluster/90_ml_data_frame_analytics_crud.yml index 13b1382fe3e07..923a56395e8a4 100644 --- a/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/old_cluster/90_ml_data_frame_analytics_crud.yml +++ b/x-pack/qa/rolling-upgrade/src/test/resources/rest-api-spec/test/old_cluster/90_ml_data_frame_analytics_crud.yml @@ -19,6 +19,16 @@ setup: "user": "Kimchy" } + - do: + index: + index: bwc_ml_classification_job_source + body: > + { + "numeric_field_1": 1.0, + "foo": "a", + "user": "Kimchy" + } + - do: indices.refresh: index: bwc_ml_* From 60189eeb9a3474d9ce90f9957ffb4a96863fff39 Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Thu, 5 Mar 2020 12:44:08 +0000 Subject: [PATCH 10/11] Always write to JSON --- .../xpack/core/ml/dataframe/analyses/Classification.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index e0a4fbae860b1..0f365ad671b76 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -204,9 +204,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); boostedTreeParams.toXContent(builder, params); - if (version.onOrAfter(Version.V_8_0_0)) { - builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective); - } + builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective); builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); if (predictionFieldName != null) { builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); From 47c3f98c814881ecbe82f3377aee278dd6807fa7 Mon Sep 17 00:00:00 2001 From: Tom Veasey Date: Thu, 12 Mar 2020 17:18:58 +0000 Subject: [PATCH 11/11] Test fix --- .../rest-api-spec/test/ml/data_frame_analytics_crud.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml index 9d25a05534bb0..a5a99b30391ea 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/data_frame_analytics_crud.yml @@ -1898,6 +1898,7 @@ setup: "prediction_field_name": "foo_prediction", "training_percent": 100.0, "randomize_seed": 24, + "class_assignment_objective": "maximize_minimum_recall", "num_top_classes": 2 } }}