From 367cfd654f5c32105c818b9a782a5ee8c3ad6619 Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Thu, 12 Mar 2020 15:31:14 +0200 Subject: [PATCH 1/3] [ML] Extend classification to support multiple classes Prepares classification analysis to support more than just two classes. It introduces a new parameter to the process config which dictates the `num_classes` to the process. It also changes the max classes limit to `30` provisionally. --- .../ml/dataframe/analyses/Classification.java | 9 +- .../dataframe/analyses/DataFrameAnalysis.java | 28 ++- .../dataframe/analyses/OutlierDetection.java | 2 +- .../ml/dataframe/analyses/Regression.java | 2 +- .../analyses/ClassificationTests.java | 47 ++++- .../ml/integration/ClassificationIT.java | 2 + .../scroll/TimeBasedExtractedFields.java | 3 +- .../extractor/ExtractedFieldsDetector.java | 17 +- .../process/AnalyticsProcessConfig.java | 30 +++- .../xpack/ml/extractor/ExtractedFields.java | 14 +- .../DataFrameDataExtractorTests.java | 6 +- .../ExtractedFieldsDetectorTests.java | 4 +- .../process/AnalyticsProcessConfigTests.java | 170 ++++++++++++++++++ .../ml/extractor/ExtractedFieldsTests.java | 8 +- 14 files changed, 299 insertions(+), 43 deletions(-) create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfigTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 0f365ad671b76..381f574982db2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -46,6 +46,8 @@ public class Classification implements DataFrameAnalysis { private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1"; + private static final String NUM_CLASSES = "num_classes"; + private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); private static final ConstructingObjectParser STRICT_PARSER = createParser(false); @@ -218,7 +220,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } @Override - public Map getParams(Map> extractedFields) { + public Map getParams(FieldInfo fieldInfo) { Map params = new HashMap<>(); params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); params.putAll(boostedTreeParams.getParams()); @@ -227,10 +229,11 @@ public Map getParams(Map> extractedFields) { if (predictionFieldName != null) { params.put(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName); } - String predictionFieldType = getPredictionFieldType(extractedFields.get(dependentVariable)); + String predictionFieldType = getPredictionFieldType(fieldInfo.getTypes(dependentVariable)); if (predictionFieldType != null) { params.put(PREDICTION_FIELD_TYPE, predictionFieldType); } + params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable)); return params; } @@ -272,7 +275,7 @@ public List getRequiredFields() { @Override public List getFieldCardinalityConstraints() { // This restriction is due to the fact that currently the C++ backend only supports binomial classification. - return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, 2)); + return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, 30)); } @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java index 664b38e4fc05d..941224dc30a67 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.core.ml.dataframe.analyses; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.common.xcontent.ToXContentObject; @@ -16,9 +17,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { /** * @return The analysis parameters as a map - * @param extractedFields map of (name, types) for all the extracted fields + * @param fieldInfo Information about the fields like types and cardinalities */ - Map getParams(Map> extractedFields); + Map getParams(FieldInfo fieldInfo); /** * @return {@code true} if this analysis supports fields with categorical values (i.e. text, keyword, ip) @@ -64,4 +65,27 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable { * Returns the document id for the analysis state */ String getStateDocId(String jobId); + + /** + * Summarizes information about the fields that is necessary for analysis to generate + * the parameters needed for the process configuration. + */ + interface FieldInfo { + + /** + * Returns the types for the given field or {@code null} if the field is unknown + * @param field the field whose types to return + * @return the types for the given field or {@code null} if the field is unknown + */ + @Nullable + Set getTypes(String field); + + /** + * Returns the cardinality of the given field or {@code null} if there is no cardinality for that field + * @param field the field whose cardinality to get + * @return the cardinality of the given field or {@code null} if there is no cardinality for that field + */ + @Nullable + Long getCardinality(String field); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java index 654d5ba4d1a29..2c83afa87808d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java @@ -192,7 +192,7 @@ public int hashCode() { } @Override - public Map getParams(Map> extractedFields) { + public Map getParams(FieldInfo fieldInfo) { Map params = new HashMap<>(); if (nNeighbors != null) { params.put(N_NEIGHBORS.getPreferredName(), nNeighbors); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java index 86f8039090c74..d8c490ddcb804 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java @@ -155,7 +155,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } @Override - public Map getParams(Map> extractedFields) { + public Map getParams(FieldInfo fieldInfo) { Map params = new HashMap<>(); params.put(DEPENDENT_VARIABLE.getPreferredName(), dependentVariable); params.putAll(boostedTreeParams.getParams()); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java index e7b2dbbc09f95..7429fe08d180d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java @@ -187,38 +187,46 @@ public void testGetTrainingPercent() { } public void testGetParams() { - Map> extractedFields = + DataFrameAnalysis.FieldInfo fieldInfo = new TestFieldInfo( Map.of( "foo", Set.of(BooleanFieldMapper.CONTENT_TYPE), "bar", Set.of(NumberFieldMapper.NumberType.LONG.typeName()), - "baz", Set.of(KeywordFieldMapper.CONTENT_TYPE)); + "baz", Set.of(KeywordFieldMapper.CONTENT_TYPE)), + Map.of( + "foo", 10L, + "bar", 20L, + "baz", 30L) + ); assertThat( - new Classification("foo").getParams(extractedFields), + new Classification("foo").getParams(fieldInfo), equalTo( Map.of( "dependent_variable", "foo", "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, "num_top_classes", 2, "prediction_field_name", "foo_prediction", - "prediction_field_type", "bool"))); + "prediction_field_type", "bool", + "num_classes", 10L))); assertThat( - new Classification("bar").getParams(extractedFields), + new Classification("bar").getParams(fieldInfo), equalTo( Map.of( "dependent_variable", "bar", "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, "num_top_classes", 2, "prediction_field_name", "bar_prediction", - "prediction_field_type", "int"))); + "prediction_field_type", "int", + "num_classes", 20L))); assertThat( - new Classification("baz").getParams(extractedFields), + new Classification("baz").getParams(fieldInfo), equalTo( Map.of( "dependent_variable", "baz", "class_assignment_objective", Classification.ClassAssignmentObjective.MAXIMIZE_MINIMUM_RECALL, "num_top_classes", 2, "prediction_field_name", "baz_prediction", - "prediction_field_type", "string"))); + "prediction_field_type", "string", + "num_classes", 30L))); } public void testRequiredFieldsIsNonEmpty() { @@ -232,7 +240,7 @@ public void testFieldCardinalityLimitsIsNonEmpty() { assertThat(constraints.size(), equalTo(1)); assertThat(constraints.get(0).getField(), equalTo(classification.getDependentVariable())); assertThat(constraints.get(0).getLowerBound(), equalTo(2L)); - assertThat(constraints.get(0).getUpperBound(), equalTo(2L)); + assertThat(constraints.get(0).getUpperBound(), equalTo(30L)); } public void testGetExplicitlyMappedFields() { @@ -331,4 +339,25 @@ public void testExtractJobIdFromStateDoc() { protected Classification mutateInstanceForVersion(Classification instance, Version version) { return mutateForVersion(instance, version); } + + private static class TestFieldInfo implements DataFrameAnalysis.FieldInfo { + + private final Map> fieldTypes; + private final Map fieldCardinalities; + + private TestFieldInfo(Map> fieldTypes, Map fieldCardinalities) { + this.fieldTypes = fieldTypes; + this.fieldCardinalities = fieldCardinalities; + } + + @Override + public Set getTypes(String field) { + return fieldTypes.get(field); + } + + @Override + public Long getCardinality(String field) { + return fieldCardinalities.get(field); + } + } } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index d2f90eed635ea..a58418bbd35ef 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -322,9 +322,11 @@ public void testStopAndRestart() throws Exception { assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField); } + @AwaitsFix(bugUrl = "Muted until ml-cpp supports multiple classes") public void testDependentVariableCardinalityTooHighError() throws Exception { initialize("cardinality_too_high"); indexData(sourceIndex, 6, 5, KEYWORD_FIELD); + // Index one more document with a class different than the two already used. client().execute(IndexAction.INSTANCE, new IndexRequest(sourceIndex) .source(KEYWORD_FIELD, "fox") diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java index 8202c0ef3d28e..eb13f2395ef8d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/datafeed/extractor/scroll/TimeBasedExtractedFields.java @@ -14,6 +14,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Set; @@ -27,7 +28,7 @@ public class TimeBasedExtractedFields extends ExtractedFields { private final ExtractedField timeField; public TimeBasedExtractedFields(ExtractedField timeField, List allFields) { - super(allFields); + super(allFields, Collections.emptyMap()); if (!allFields.contains(timeField)) { throw new IllegalArgumentException("timeField should also be contained in allFields"); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java index 0c885806d730c..2254a3de0a60c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetector.java @@ -58,15 +58,15 @@ public class ExtractedFieldsDetector { private final DataFrameAnalyticsConfig config; private final int docValueFieldsLimit; private final FieldCapabilitiesResponse fieldCapabilitiesResponse; - private final Map fieldCardinalities; + private final Map cardinalitiesForFieldsWithConstraints; ExtractedFieldsDetector(String[] index, DataFrameAnalyticsConfig config, int docValueFieldsLimit, - FieldCapabilitiesResponse fieldCapabilitiesResponse, Map fieldCardinalities) { + FieldCapabilitiesResponse fieldCapabilitiesResponse, Map cardinalitiesForFieldsWithConstraints) { this.index = Objects.requireNonNull(index); this.config = Objects.requireNonNull(config); this.docValueFieldsLimit = docValueFieldsLimit; this.fieldCapabilitiesResponse = Objects.requireNonNull(fieldCapabilitiesResponse); - this.fieldCardinalities = Objects.requireNonNull(fieldCardinalities); + this.cardinalitiesForFieldsWithConstraints = Objects.requireNonNull(cardinalitiesForFieldsWithConstraints); } public Tuple> detect() { @@ -286,12 +286,13 @@ private void checkRequiredFields(Set fields) { private void checkFieldsWithCardinalityLimit() { for (FieldCardinalityConstraint constraint : config.getAnalysis().getFieldCardinalityConstraints()) { - constraint.check(fieldCardinalities.get(constraint.getField())); + constraint.check(cardinalitiesForFieldsWithConstraints.get(constraint.getField())); } } private ExtractedFields detectExtractedFields(Set fields, Set fieldSelection) { - ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse); + ExtractedFields extractedFields = ExtractedFields.build(fields, Collections.emptySet(), fieldCapabilitiesResponse, + cardinalitiesForFieldsWithConstraints); boolean preferSource = extractedFields.getDocValueFields().size() > docValueFieldsLimit; extractedFields = deduplicateMultiFields(extractedFields, preferSource, fieldSelection); if (preferSource) { @@ -321,7 +322,7 @@ private ExtractedFields deduplicateMultiFields(ExtractedFields extractedFields, chooseMultiFieldOrParent(preferSource, requiredFields, parent, multiField, fieldSelection)); } } - return new ExtractedFields(new ArrayList<>(nameOrParentToField.values())); + return new ExtractedFields(new ArrayList<>(nameOrParentToField.values()), cardinalitiesForFieldsWithConstraints); } private ExtractedField chooseMultiFieldOrParent(boolean preferSource, Set requiredFields, ExtractedField parent, @@ -372,7 +373,7 @@ private ExtractedFields fetchFromSourceIfSupported(ExtractedFields extractedFiel for (ExtractedField field : extractedFields.getAllFields()) { adjusted.add(field.supportsFromSource() ? field.newFromSource() : field); } - return new ExtractedFields(adjusted); + return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints); } private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFields) { @@ -389,7 +390,7 @@ private ExtractedFields fetchBooleanFieldsAsIntegers(ExtractedFields extractedFi adjusted.add(field); } } - return new ExtractedFields(adjusted); + return new ExtractedFields(adjusted, cardinalitiesForFieldsWithConstraints); } private void addIncludedFields(ExtractedFields extractedFields, Set fieldSelection) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java index 714f63091801f..0daec5365ff48 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfig.java @@ -14,10 +14,9 @@ import java.io.IOException; import java.util.Objects; +import java.util.Optional; import java.util.Set; -import static java.util.stream.Collectors.toMap; - public class AnalyticsProcessConfig implements ToXContentObject { private static final String JOB_ID = "job_id"; @@ -93,12 +92,31 @@ private DataFrameAnalysisWrapper(DataFrameAnalysis analysis, ExtractedFields ext public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field("name", analysis.getWriteableName()); - builder.field( - "parameters", - analysis.getParams( - extractedFields.getAllFields().stream().collect(toMap(ExtractedField::getName, ExtractedField::getTypes)))); + builder.field("parameters", analysis.getParams(new AnalysisFieldInfo(extractedFields))); builder.endObject(); return builder; } } + + private static class AnalysisFieldInfo implements DataFrameAnalysis.FieldInfo { + + private final ExtractedFields extractedFields; + + AnalysisFieldInfo(ExtractedFields extractedFields) { + this.extractedFields = Objects.requireNonNull(extractedFields); + } + + @Override + public Set getTypes(String field) { + Optional extractedField = extractedFields.getAllFields().stream() + .filter(f -> f.getName().equals(field)) + .findAny(); + return extractedField.isPresent() ? extractedField.get().getTypes() : null; + } + + @Override + public Long getCardinality(String field) { + return extractedFields.getCardinalitiesForFieldsWithConstraints().get(field); + } + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java index 3a36bb7ff76d0..ab314a5d21851 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/extractor/ExtractedFields.java @@ -28,12 +28,14 @@ public class ExtractedFields { private final List allFields; private final List docValueFields; private final String[] sourceFields; + private final Map cardinalitiesForFieldsWithConstraints; - public ExtractedFields(List allFields) { + public ExtractedFields(List allFields, Map cardinalitiesForFieldsWithConstraints) { this.allFields = Collections.unmodifiableList(allFields); this.docValueFields = filterFields(ExtractedField.Method.DOC_VALUE, allFields); this.sourceFields = filterFields(ExtractedField.Method.SOURCE, allFields).stream().map(ExtractedField::getSearchField) .toArray(String[]::new); + this.cardinalitiesForFieldsWithConstraints = Collections.unmodifiableMap(cardinalitiesForFieldsWithConstraints); } public List getAllFields() { @@ -48,14 +50,20 @@ public List getDocValueFields() { return docValueFields; } + public Map getCardinalitiesForFieldsWithConstraints() { + return cardinalitiesForFieldsWithConstraints; + } + private static List filterFields(ExtractedField.Method method, List fields) { return fields.stream().filter(field -> field.getMethod() == method).collect(Collectors.toList()); } public static ExtractedFields build(Collection allFields, Set scriptFields, - FieldCapabilitiesResponse fieldsCapabilities) { + FieldCapabilitiesResponse fieldsCapabilities, + Map cardinalitiesForFieldsWithConstraints) { ExtractionMethodDetector extractionMethodDetector = new ExtractionMethodDetector(scriptFields, fieldsCapabilities); - return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList())); + return new ExtractedFields(allFields.stream().map(field -> extractionMethodDetector.detect(field)).collect(Collectors.toList()), + cardinalitiesForFieldsWithConstraints); } public static TimeField newTimeField(String name, ExtractedField.Method method) { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java index d54997f1dbdc3..b75392e03c2bb 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorTests.java @@ -81,7 +81,7 @@ public void setUpTests() { query = QueryBuilders.matchAllQuery(); extractedFields = new ExtractedFields(Arrays.asList( new DocValueField("field_1", Collections.singleton("keyword")), - new DocValueField("field_2", Collections.singleton("keyword")))); + new DocValueField("field_2", Collections.singleton("keyword"))), Collections.emptyMap()); scrollSize = 1000; headers = Collections.emptyMap(); @@ -299,7 +299,7 @@ public void testIncludeSourceIsFalseAndAtLeastOneSourceField() throws IOExceptio // Explicit cast of ExtractedField args necessary for Eclipse due to https://bugs.eclipse.org/bugs/show_bug.cgi?id=530915 extractedFields = new ExtractedFields(Arrays.asList( (ExtractedField) new DocValueField("field_1", Collections.singleton("keyword")), - (ExtractedField) new SourceField("field_2", Collections.singleton("text")))); + (ExtractedField) new SourceField("field_2", Collections.singleton("text"))), Collections.emptyMap()); TestExtractor dataExtractor = createExtractor(false, false); @@ -404,7 +404,7 @@ public void testGetCategoricalFields() { (ExtractedField) new DocValueField("field_integer", Collections.singleton("integer")), (ExtractedField) new DocValueField("field_long", Collections.singleton("long")), (ExtractedField) new DocValueField("field_keyword", Collections.singleton("keyword")), - (ExtractedField) new SourceField("field_text", Collections.singleton("text")))); + (ExtractedField) new SourceField("field_text", Collections.singleton("text"))), Collections.emptyMap()); TestExtractor dataExtractor = createExtractor(true, true); assertThat(dataExtractor.getCategoricalFields(OutlierDetectionTests.createRandom()), empty()); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java index 49a302a498b82..a7a5784c452f5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorTests.java @@ -294,10 +294,10 @@ public void testDetect_GivenClassificationAndDependentVariableHasInvalidCardinal .build(); ExtractedFieldsDetector extractedFieldsDetector = new ExtractedFieldsDetector(SOURCE_INDEX, - buildClassificationConfig("some_keyword"), 100, fieldCapabilities, Collections.singletonMap("some_keyword", 3L)); + buildClassificationConfig("some_keyword"), 100, fieldCapabilities, Collections.singletonMap("some_keyword", 31L)); ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, extractedFieldsDetector::detect); - assertThat(e.getMessage(), equalTo("Field [some_keyword] must have at most [2] distinct values but there were at least [3]")); + assertThat(e.getMessage(), equalTo("Field [some_keyword] must have at most [30] distinct values but there were at least [31]")); } public void testDetect_GivenIgnoredField() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfigTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfigTests.java new file mode 100644 index 0000000000000..a4db8de032af5 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessConfigTests.java @@ -0,0 +1,170 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.dataframe.process; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.OutlierDetection; +import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.ml.extractor.DocValueField; +import org.elasticsearch.xpack.ml.extractor.ExtractedFields; +import org.junit.Before; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.hasKey; + +public class AnalyticsProcessConfigTests extends ESTestCase { + + private String jobId; + private long rows; + private int cols; + private ByteSizeValue memoryLimit; + private int threads; + private String resultsField; + private Set categoricalFields; + + @Before + public void setUpConfigParams() { + jobId = randomAlphaOfLength(10); + rows = randomNonNegativeLong(); + cols = randomIntBetween(1, 42000); + memoryLimit = new ByteSizeValue(randomNonNegativeLong(), ByteSizeUnit.BYTES); + threads = randomIntBetween(1, 8); + resultsField = randomAlphaOfLength(10); + + int categoricalFieldsSize = randomIntBetween(0, 5); + categoricalFields = new HashSet<>(); + for (int i = 0; i < categoricalFieldsSize; i++) { + categoricalFields.add(randomAlphaOfLength(10)); + } + } + + @SuppressWarnings("unchecked") + public void testToXContent_GivenOutlierDetection() throws IOException { + ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( + new DocValueField("field_1", Collections.singleton("double")), + new DocValueField("field_2", Collections.singleton("float"))), Collections.emptyMap()); + DataFrameAnalysis analysis = new OutlierDetection.Builder().build(); + + AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields); + Map asMap = toMap(processConfig); + + assertRandomizedFields(asMap); + + assertThat(asMap, hasKey("analysis")); + Map analysisAsMap = (Map) asMap.get("analysis"); + assertThat(analysisAsMap, hasEntry("name", "outlier_detection")); + assertThat(analysisAsMap, hasKey("parameters")); + } + + @SuppressWarnings("unchecked") + public void testToXContent_GivenRegression() throws IOException { + ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( + new DocValueField("field_1", Collections.singleton("double")), + new DocValueField("field_2", Collections.singleton("float")), + new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.emptyMap()); + DataFrameAnalysis analysis = new Regression("test_dep_var"); + + AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields); + Map asMap = toMap(processConfig); + + assertRandomizedFields(asMap); + + assertThat(asMap, hasKey("analysis")); + Map analysisAsMap = (Map) asMap.get("analysis"); + assertThat(analysisAsMap, hasEntry("name", "regression")); + assertThat(analysisAsMap, hasKey("parameters")); + Map paramsAsMap = (Map) analysisAsMap.get("parameters"); + assertThat(paramsAsMap, hasEntry("dependent_variable", "test_dep_var")); + } + + @SuppressWarnings("unchecked") + public void testToXContent_GivenClassificationAndDepVarIsKeyword() throws IOException { + ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( + new DocValueField("field_1", Collections.singleton("double")), + new DocValueField("field_2", Collections.singleton("float")), + new DocValueField("test_dep_var", Collections.singleton("keyword"))), Collections.singletonMap("test_dep_var", 5L)); + DataFrameAnalysis analysis = new Classification("test_dep_var"); + + AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields); + Map asMap = toMap(processConfig); + + assertRandomizedFields(asMap); + + assertThat(asMap, hasKey("analysis")); + Map analysisAsMap = (Map) asMap.get("analysis"); + assertThat(analysisAsMap, hasEntry("name", "classification")); + assertThat(analysisAsMap, hasKey("parameters")); + Map paramsAsMap = (Map) analysisAsMap.get("parameters"); + assertThat(paramsAsMap, hasEntry("dependent_variable", "test_dep_var")); + assertThat(paramsAsMap, hasEntry("prediction_field_type", "string")); + assertThat(paramsAsMap, hasEntry("num_classes", 5)); + } + + @SuppressWarnings("unchecked") + public void testToXContent_GivenClassificationAndDepVarIsInteger() throws IOException { + ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( + new DocValueField("field_1", Collections.singleton("double")), + new DocValueField("field_2", Collections.singleton("float")), + new DocValueField("test_dep_var", Collections.singleton("integer"))), Collections.singletonMap("test_dep_var", 8L)); + DataFrameAnalysis analysis = new Classification("test_dep_var"); + + AnalyticsProcessConfig processConfig = createProcessConfig(analysis, extractedFields); + Map asMap = toMap(processConfig); + + assertRandomizedFields(asMap); + + assertThat(asMap, hasKey("analysis")); + Map analysisAsMap = (Map) asMap.get("analysis"); + assertThat(analysisAsMap, hasEntry("name", "classification")); + assertThat(analysisAsMap, hasKey("parameters")); + Map paramsAsMap = (Map) analysisAsMap.get("parameters"); + assertThat(paramsAsMap, hasEntry("dependent_variable", "test_dep_var")); + assertThat(paramsAsMap, hasEntry("prediction_field_type", "int")); + assertThat(paramsAsMap, hasEntry("num_classes", 8)); + } + + private AnalyticsProcessConfig createProcessConfig(DataFrameAnalysis analysis, ExtractedFields extractedFields) { + return new AnalyticsProcessConfig(jobId, rows, cols, memoryLimit, threads, resultsField, categoricalFields, analysis, + extractedFields); + } + + private static Map toMap(AnalyticsProcessConfig config) throws IOException { + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + config.toXContent(builder, ToXContent.EMPTY_PARAMS); + return XContentHelper.convertToMap(JsonXContent.jsonXContent, Strings.toString(builder), false); + } + } + + @SuppressWarnings("unchecked") + private void assertRandomizedFields(Map configAsMap) { + assertThat(configAsMap, hasEntry("job_id", jobId)); + assertThat(configAsMap, hasEntry("rows", rows)); + assertThat(configAsMap, hasEntry("cols", cols)); + assertThat(configAsMap, hasEntry("memory_limit", memoryLimit.getBytes())); + assertThat(configAsMap, hasEntry("threads", threads)); + assertThat(configAsMap, hasEntry("results_field", resultsField)); + assertThat(configAsMap, hasKey("categorical_fields")); + assertThat((List) configAsMap.get("categorical_fields"), containsInAnyOrder(categoricalFields.toArray())); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java index 5ac983e7d505b..a51eafd1d8b3d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/extractor/ExtractedFieldsTests.java @@ -32,7 +32,7 @@ public void testAllTypesOfFields() { ExtractedField sourceField1 = new SourceField("src1", Collections.singleton("text")); ExtractedField sourceField2 = new SourceField("src2", Collections.singleton("text")); ExtractedFields extractedFields = new ExtractedFields(Arrays.asList( - docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2)); + docValue1, docValue2, scriptField1, scriptField2, sourceField1, sourceField2), Collections.emptyMap()); assertThat(extractedFields.getAllFields().size(), equalTo(6)); assertThat(extractedFields.getDocValueFields().stream().map(ExtractedField::getName).toArray(String[]::new), @@ -54,7 +54,7 @@ public void testBuildGivenMixtureOfTypes() { when(fieldCapabilitiesResponse.getField("airline")).thenReturn(airlineCaps); ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("time", "value", "airline", "airport"), - new HashSet<>(Collections.singletonList("airport")), fieldCapabilitiesResponse); + new HashSet<>(Collections.singletonList("airport")), fieldCapabilitiesResponse, Collections.emptyMap()); assertThat(extractedFields.getDocValueFields().size(), equalTo(2)); assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("time")); @@ -77,7 +77,7 @@ public void testBuildGivenMultiFields() { when(fieldCapabilitiesResponse.getField("airport.keyword")).thenReturn(keyword); ExtractedFields extractedFields = ExtractedFields.build(Arrays.asList("airline.text", "airport.keyword"), - Collections.emptySet(), fieldCapabilitiesResponse); + Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap()); assertThat(extractedFields.getDocValueFields().size(), equalTo(1)); assertThat(extractedFields.getDocValueFields().get(0).getName(), equalTo("airport.keyword")); @@ -119,7 +119,7 @@ public void testBuildGivenFieldWithoutMappings() { FieldCapabilitiesResponse fieldCapabilitiesResponse = mock(FieldCapabilitiesResponse.class); IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> ExtractedFields.build( - Collections.singletonList("value"), Collections.emptySet(), fieldCapabilitiesResponse)); + Collections.singletonList("value"), Collections.emptySet(), fieldCapabilitiesResponse, Collections.emptyMap())); assertThat(e.getMessage(), equalTo("cannot retrieve field [value] because it has no mappings")); } From b4d0c5cc7ccb2c017cdaad6565250415d1e0f9ef Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Fri, 13 Mar 2020 17:11:59 +0200 Subject: [PATCH 2/3] We can't test cardinality is too high in the YML tests anymore --- .../test/ml/start_data_frame_analytics.yml | 21 +------------------ 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml index 81077b1e69a8c..ab6cfac9515e5 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/start_data_frame_analytics.yml @@ -142,7 +142,7 @@ id: "start_given_empty_dest_index" --- -"Test start classification analysis when the dependent variable cardinality is too low or too high": +"Test start classification analysis when the dependent variable cardinality is too low": - do: indices.create: index: index-with-dep-var-with-too-high-card @@ -179,22 +179,3 @@ catch: /Field \[keyword_field\] must have at least \[2\] distinct values but there were \[1\]/ ml.start_data_frame_analytics: id: "classification-cardinality-limits" - - - do: - index: - index: index-with-dep-var-with-too-high-card - body: { numeric_field: 2.0, keyword_field: "class_b" } - - - do: - index: - index: index-with-dep-var-with-too-high-card - body: { numeric_field: 3.0, keyword_field: "class_c" } - - - do: - indices.refresh: - index: index-with-dep-var-with-too-high-card - - - do: - catch: /Field \[keyword_field\] must have at most \[2\] distinct values but there were at least \[3\]/ - ml.start_data_frame_analytics: - id: "classification-cardinality-limits" From 150dd0e6eda3cd7f04e60c9dc1e11199be41cfaf Mon Sep 17 00:00:00 2001 From: Dimitris Athanasiou Date: Mon, 16 Mar 2020 11:20:43 +0200 Subject: [PATCH 3/3] Extract max number of classes in a constant --- .../xpack/core/ml/dataframe/analyses/Classification.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java index 381f574982db2..c0965f8d7ebb5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java @@ -51,6 +51,11 @@ public class Classification implements DataFrameAnalysis { private static final ConstructingObjectParser LENIENT_PARSER = createParser(true); private static final ConstructingObjectParser STRICT_PARSER = createParser(false); + /** + * The max number of classes classification supports + */ + private static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30; + private static ConstructingObjectParser createParser(boolean lenient) { ConstructingObjectParser parser = new ConstructingObjectParser<>( NAME.getPreferredName(), @@ -275,7 +280,7 @@ public List getRequiredFields() { @Override public List getFieldCardinalityConstraints() { // This restriction is due to the fact that currently the C++ backend only supports binomial classification. - return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, 30)); + return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, MAX_DEPENDENT_VARIABLE_CARDINALITY)); } @SuppressWarnings("unchecked")