Skip to content

Commit c149c64

Browse files
[7.x][ML] Apply source query on data frame analytics memory estimation (#49517) (#49532)
Closes #49454 Backport of #49517
1 parent a5fa86e commit c149c64

File tree

3 files changed

+91
-13
lines changed

3 files changed

+91
-13
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.ml.integration;
7+
8+
import org.elasticsearch.action.bulk.BulkRequestBuilder;
9+
import org.elasticsearch.action.bulk.BulkResponse;
10+
import org.elasticsearch.action.index.IndexRequest;
11+
import org.elasticsearch.action.support.WriteRequest;
12+
import org.elasticsearch.index.query.QueryBuilders;
13+
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
14+
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
15+
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
16+
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
17+
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
18+
19+
import java.io.IOException;
20+
21+
import static org.hamcrest.Matchers.lessThanOrEqualTo;
22+
23+
public class ExplainDataFrameAnalyticsIT extends MlNativeDataFrameAnalyticsIntegTestCase {
24+
25+
public void testSourceQueryIsApplied() throws IOException {
26+
// To test the source query is applied when we extract data,
27+
// we set up a job where we have a query which excludes all but one document.
28+
// We then assert the memory estimation is low enough.
29+
30+
String sourceIndex = "test-source-query-is-applied";
31+
32+
client().admin().indices().prepareCreate(sourceIndex)
33+
.addMapping("_doc", "numeric_1", "type=double", "numeric_2", "type=float", "categorical", "type=keyword")
34+
.get();
35+
36+
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
37+
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
38+
39+
for (int i = 0; i < 30; i++) {
40+
IndexRequest indexRequest = new IndexRequest(sourceIndex);
41+
42+
// We insert a single odd value ("only-one") among the 30 documents for the categorical feature
43+
indexRequest.source("numeric_1", 1.0, "numeric_2", 2.0, "categorical", i == 0 ? "only-one" : "normal");
44+
bulkRequestBuilder.add(indexRequest);
45+
}
46+
BulkResponse bulkResponse = bulkRequestBuilder.get();
47+
if (bulkResponse.hasFailures()) {
48+
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
49+
}
50+
51+
String id = "test_source_query_is_applied";
52+
53+
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
54+
.setId(id)
55+
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex },
56+
QueryProvider.fromParsedQuery(QueryBuilders.termQuery("categorical", "only-one"))))
57+
.setAnalysis(new Classification("categorical"))
58+
.buildForExplain();
59+
60+
ExplainDataFrameAnalyticsAction.Response explainResponse = explainDataFrame(config);
61+
62+
assertThat(explainResponse.getMemoryEstimation().getExpectedMemoryWithoutDisk().getKb(), lessThanOrEqualTo(500L));
63+
}
64+
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.search.sort.SortOrder;
2020
import org.elasticsearch.xpack.core.ml.action.DeleteDataFrameAnalyticsAction;
2121
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
22+
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
2223
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsAction;
2324
import org.elasticsearch.xpack.core.ml.action.GetDataFrameAnalyticsStatsAction;
2425
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
@@ -145,6 +146,11 @@ protected GetDataFrameAnalyticsStatsAction.Response.Stats getAnalyticsStats(Stri
145146
return stats.get(0);
146147
}
147148

149+
protected ExplainDataFrameAnalyticsAction.Response explainDataFrame(DataFrameAnalyticsConfig config) {
150+
PutDataFrameAnalyticsAction.Request request = new PutDataFrameAnalyticsAction.Request(config);
151+
return client().execute(ExplainDataFrameAnalyticsAction.INSTANCE, request).actionGet();
152+
}
153+
148154
protected EvaluateDataFrameAction.Response evaluateDataFrame(String index, Evaluation evaluation) {
149155
EvaluateDataFrameAction.Request request =
150156
new EvaluateDataFrameAction.Request()
@@ -155,12 +161,12 @@ protected EvaluateDataFrameAction.Response evaluateDataFrame(String index, Evalu
155161

156162
protected static DataFrameAnalyticsConfig buildAnalytics(String id, String sourceIndex, String destIndex,
157163
@Nullable String resultsField, DataFrameAnalysis analysis) {
158-
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();
159-
configBuilder.setId(id);
160-
configBuilder.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null));
161-
configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField));
162-
configBuilder.setAnalysis(analysis);
163-
return configBuilder.build();
164+
return new DataFrameAnalyticsConfig.Builder()
165+
.setId(id)
166+
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex }, null))
167+
.setDest(new DataFrameAnalyticsDest(destIndex, resultsField))
168+
.setAnalysis(analysis)
169+
.build();
164170
}
165171

166172
protected void assertIsStopped(String id) {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractorFactory.java

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,18 @@ public class DataFrameDataExtractorFactory {
2525
private final Client client;
2626
private final String analyticsId;
2727
private final List<String> indices;
28+
private final QueryBuilder sourceQuery;
2829
private final ExtractedFields extractedFields;
2930
private final Map<String, String> headers;
3031
private final boolean includeRowsWithMissingValues;
3132

32-
public DataFrameDataExtractorFactory(Client client, String analyticsId, List<String> indices, ExtractedFields extractedFields,
33-
Map<String, String> headers, boolean includeRowsWithMissingValues) {
33+
private DataFrameDataExtractorFactory(Client client, String analyticsId, List<String> indices, QueryBuilder sourceQuery,
34+
ExtractedFields extractedFields, Map<String, String> headers,
35+
boolean includeRowsWithMissingValues) {
3436
this.client = Objects.requireNonNull(client);
3537
this.analyticsId = Objects.requireNonNull(analyticsId);
3638
this.indices = Objects.requireNonNull(indices);
39+
this.sourceQuery = Objects.requireNonNull(sourceQuery);
3740
this.extractedFields = Objects.requireNonNull(extractedFields);
3841
this.headers = headers;
3942
this.includeRowsWithMissingValues = includeRowsWithMissingValues;
@@ -54,7 +57,12 @@ public DataFrameDataExtractor newExtractor(boolean includeSource) {
5457
}
5558

5659
private QueryBuilder createQuery() {
57-
return includeRowsWithMissingValues ? QueryBuilders.matchAllQuery() : allExtractedFieldsExistQuery();
60+
BoolQueryBuilder query = QueryBuilders.boolQuery();
61+
query.filter(sourceQuery);
62+
if (includeRowsWithMissingValues == false) {
63+
query.filter(allExtractedFieldsExistQuery());
64+
}
65+
return query;
5866
}
5967

6068
private QueryBuilder allExtractedFieldsExistQuery() {
@@ -77,8 +85,8 @@ private QueryBuilder allExtractedFieldsExistQuery() {
7785
*/
7886
public static DataFrameDataExtractorFactory createForSourceIndices(Client client, String taskId, DataFrameAnalyticsConfig config,
7987
ExtractedFields extractedFields) {
80-
return new DataFrameDataExtractorFactory(client, taskId, Arrays.asList(config.getSource().getIndex()), extractedFields,
81-
config.getHeaders(), config.getAnalysis().supportsMissingValues());
88+
return new DataFrameDataExtractorFactory(client, taskId, Arrays.asList(config.getSource().getIndex()),
89+
config.getSource().getParsedQuery(), extractedFields, config.getHeaders(), config.getAnalysis().supportsMissingValues());
8290
}
8391

8492
/**
@@ -100,8 +108,8 @@ public static void createForDestinationIndex(Client client,
100108
extractedFieldsDetector -> {
101109
ExtractedFields extractedFields = extractedFieldsDetector.detect().v1();
102110
DataFrameDataExtractorFactory extractorFactory = new DataFrameDataExtractorFactory(client, config.getId(),
103-
Collections.singletonList(config.getDest().getIndex()), extractedFields, config.getHeaders(),
104-
config.getAnalysis().supportsMissingValues());
111+
Collections.singletonList(config.getDest().getIndex()), config.getSource().getParsedQuery(), extractedFields,
112+
config.getHeaders(), config.getAnalysis().supportsMissingValues());
105113
listener.onResponse(extractorFactory);
106114
},
107115
listener::onFailure

0 commit comments

Comments
 (0)