diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 1d3f6bf4f2e75..f5db9ae690a96 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -15,6 +15,8 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; @@ -228,7 +230,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty_DependentVariableI assertEvaluation(BOOLEAN_FIELD, BOOLEAN_FIELD_VALUES, "ml.boolean-field_prediction"); } - public void testDependentVariableCardinalityTooHighError() { + public void testDependentVariableCardinalityTooHighError() throws Exception { initialize("cardinality_too_high"); indexData(sourceIndex, 6, 5, KEYWORD_FIELD); // Index one more document with a class different than the two already used. @@ -246,6 +248,27 @@ public void testDependentVariableCardinalityTooHighError() { assertThat(e.getMessage(), equalTo("Field [keyword-field] must have at most [2] distinct values but there were at least [3]")); } + public void testDependentVariableCardinalityTooHighButWithQueryMakesItWithinRange() throws Exception { + initialize("cardinality_too_high_with_query"); + indexData(sourceIndex, 6, 5, KEYWORD_FIELD); + // Index one more document with a class different than the two already used. 
+ client().execute(IndexAction.INSTANCE, new IndexRequest(sourceIndex) + .source(KEYWORD_FIELD, "fox") + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)) + .actionGet(); + QueryBuilder query = QueryBuilders.boolQuery().filter(QueryBuilders.termsQuery(KEYWORD_FIELD, KEYWORD_FIELD_VALUES)); + + DataFrameAnalyticsConfig config = buildAnalytics(jobId, sourceIndex, destIndex, null, new Classification(KEYWORD_FIELD), query); + registerAnalytics(config); + putAnalytics(config); + + // Should not throw + startAnalytics(jobId); + waitUntilAnalyticsIsStopped(jobId); + + assertProgress(jobId, 100, 100, 100, 100); + } + private void initialize(String jobId) { this.jobId = jobId; this.sourceIndex = jobId + "_source_index"; diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java index 0b9e2c19961d8..29ef54d3f7524 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction; @@ -37,6 +38,7 @@ import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.notifications.AuditorField; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; +import 
org.elasticsearch.xpack.core.ml.utils.QueryProvider; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask; import org.hamcrest.Matcher; import org.hamcrest.Matchers; @@ -161,10 +163,16 @@ protected EvaluateDataFrameAction.Response evaluateDataFrame(String index, Evalu } protected static DataFrameAnalyticsConfig buildAnalytics(String id, String sourceIndex, String destIndex, - @Nullable String resultsField, DataFrameAnalysis analysis) { + @Nullable String resultsField, DataFrameAnalysis analysis) throws Exception { + return buildAnalytics(id, sourceIndex, destIndex, resultsField, analysis, QueryBuilders.matchAllQuery()); + } + + protected static DataFrameAnalyticsConfig buildAnalytics(String id, String sourceIndex, String destIndex, + @Nullable String resultsField, DataFrameAnalysis analysis, + QueryBuilder queryBuilder) throws Exception { return new DataFrameAnalyticsConfig.Builder() .setId(id) - .setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null, null)) + .setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, QueryProvider.fromParsedQuery(queryBuilder), null)) .setDest(new DataFrameAnalyticsDest(destIndex, resultsField)) .setAnalysis(analysis) .build(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorFactory.java index ea37bdf393aeb..8e6ad7a614b09 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/ExtractedFieldsDetectorFactory.java @@ -109,7 +109,7 @@ private void getCardinalitiesForFieldsWithLimit(String[] index, DataFrameAnalyti listener::onFailure ); - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().size(0); + SearchSourceBuilder searchSourceBuilder = 
new SearchSourceBuilder().size(0).query(config.getSource().getParsedQuery()); for (Map.Entry<String, Long> entry : fieldCardinalityLimits.entrySet()) { String fieldName = entry.getKey(); Long limit = entry.getValue();