Skip to content

Commit 83ffe96

Browse files
authored
[7.6] Do not copy mapping from dependent variable to prediction field in regression analysis (#51227) (#51289)
1 parent d9cf8fc commit 83ffe96

File tree

12 files changed

+179
-101
lines changed

12 files changed

+179
-101
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1515
import org.elasticsearch.common.xcontent.XContentBuilder;
1616
import org.elasticsearch.common.xcontent.XContentParser;
17+
import org.elasticsearch.index.mapper.FieldAliasMapper;
1718
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1819

1920
import java.io.IOException;
@@ -28,6 +29,7 @@
2829

2930
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
3031
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
32+
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
3133

3234
public class Classification implements DataFrameAnalysis {
3335

@@ -248,12 +250,32 @@ public Map<String, Long> getFieldCardinalityLimits() {
248250
return Collections.singletonMap(dependentVariable, 2L);
249251
}
250252

253+
@SuppressWarnings("unchecked")
251254
@Override
252-
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
253-
return new HashMap<String, String>() {{
254-
put(resultsFieldName + "." + predictionFieldName, dependentVariable);
255-
put(resultsFieldName + ".top_classes.class_name", dependentVariable);
256-
}};
255+
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
256+
Object dependentVariableMapping = extractMapping(dependentVariable, mappingsProperties);
257+
if ((dependentVariableMapping instanceof Map) == false) {
258+
return Collections.emptyMap();
259+
}
260+
Map<String, Object> dependentVariableMappingAsMap = (Map) dependentVariableMapping;
261+
// If the source field is an alias, fetch the concrete field that the alias points to.
262+
if (FieldAliasMapper.CONTENT_TYPE.equals(dependentVariableMappingAsMap.get("type"))) {
263+
String path = (String) dependentVariableMappingAsMap.get(FieldAliasMapper.Names.PATH);
264+
dependentVariableMapping = extractMapping(path, mappingsProperties);
265+
}
266+
// We may have updated the value of {@code dependentVariableMapping} in the "if" block above.
267+
// Hence, we need to check the "instanceof" condition again.
268+
if ((dependentVariableMapping instanceof Map) == false) {
269+
return Collections.emptyMap();
270+
}
271+
Map<String, Object> additionalProperties = new HashMap<>();
272+
additionalProperties.put(resultsFieldName + "." + predictionFieldName, dependentVariableMapping);
273+
additionalProperties.put(resultsFieldName + ".top_classes.class_name", dependentVariableMapping);
274+
return additionalProperties;
275+
}
276+
277+
private static Object extractMapping(String path, Map<String, Object> mappingsProperties) {
278+
return extractValue(String.join(".properties.", path.split("\\.")), mappingsProperties);
257279
}
258280

259281
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,13 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
4242
Map<String, Long> getFieldCardinalityLimits();
4343

4444
/**
45-
* Returns fields for which the mappings should be copied from source index to destination index.
46-
* Each entry of the returned {@link Map} is of the form:
47-
* key - field path in the destination index
48-
* value - field path in the source index from which the mapping should be taken
45+
* Returns fields for which the mappings should be either predefined or copied from the source index to the destination index.
4946
*
47+
* @param mappingsProperties mappings.properties portion of the index mappings
5048
* @param resultsFieldName name of the results field under which all the results are stored
51-
* @return {@link Map} containing fields for which the mappings should be copied from source index to destination index
49+
* @return {@link Map} containing fields for which the mappings should be handled explicitly
5250
*/
53-
Map<String, String> getExplicitlyMappedFields(String resultsFieldName);
51+
Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName);
5452

5553
/**
5654
* @return {@code true} if this analysis supports data frame rows with missing values

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ public Map<String, Long> getFieldCardinalityLimits() {
230230
}
231231

232232
@Override
233-
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
233+
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
234234
return Collections.emptyMap();
235235
}
236236

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,10 @@ public Map<String, Long> getFieldCardinalityLimits() {
187187
}
188188

189189
@Override
190-
public Map<String, String> getExplicitlyMappedFields(String resultsFieldName) {
191-
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, dependentVariable);
190+
public Map<String, Object> getExplicitlyMappedFields(Map<String, Object> mappingsProperties, String resultsFieldName) {
191+
// Prediction field should always be mapped as "double" rather than "float" in order to increase precision in case of
192+
// high (over 10M) values of dependent variable.
193+
return Collections.singletonMap(resultsFieldName + "." + predictionFieldName, Collections.singletonMap("type", "double"));
192194
}
193195

194196
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.Map;
2626
import java.util.Set;
2727

28+
import static org.hamcrest.Matchers.allOf;
2829
import static org.hamcrest.Matchers.anEmptyMap;
2930
import static org.hamcrest.Matchers.containsString;
3031
import static org.hamcrest.Matchers.empty;
@@ -171,8 +172,40 @@ public void testFieldCardinalityLimitsIsNonEmpty() {
171172
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(anEmptyMap())));
172173
}
173174

174-
public void testFieldMappingsToCopyIsNonEmpty() {
175-
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap())));
175+
public void testGetExplicitlyMappedFields() {
176+
assertThat(new Classification("foo").getExplicitlyMappedFields(null, "results"), is(anEmptyMap()));
177+
assertThat(new Classification("foo").getExplicitlyMappedFields(Collections.emptyMap(), "results"), is(anEmptyMap()));
178+
assertThat(
179+
new Classification("foo").getExplicitlyMappedFields(Collections.singletonMap("foo", "not_a_map"), "results"),
180+
is(anEmptyMap()));
181+
assertThat(
182+
new Classification("foo").getExplicitlyMappedFields(
183+
Collections.singletonMap("foo", Collections.singletonMap("bar", "baz")),
184+
"results"),
185+
allOf(
186+
hasEntry("results.foo_prediction", Collections.singletonMap("bar", "baz")),
187+
hasEntry("results.top_classes.class_name", Collections.singletonMap("bar", "baz"))));
188+
assertThat(
189+
new Classification("foo").getExplicitlyMappedFields(
190+
new HashMap<String, Object>() {{
191+
put("foo", new HashMap<String, String>() {{
192+
put("type", "alias");
193+
put("path", "bar");
194+
}});
195+
put("bar", Collections.singletonMap("type", "long"));
196+
}},
197+
"results"),
198+
allOf(
199+
hasEntry("results.foo_prediction", Collections.singletonMap("type", "long")),
200+
hasEntry("results.top_classes.class_name", Collections.singletonMap("type", "long"))));
201+
assertThat(
202+
new Classification("foo").getExplicitlyMappedFields(
203+
Collections.singletonMap("foo", new HashMap<String, String>() {{
204+
put("type", "alias");
205+
put("path", "missing");
206+
}}),
207+
"results"),
208+
is(anEmptyMap()));
176209
}
177210

178211
public void testToXContent_GivenVersionBeforeRandomizeSeedWasIntroduced() throws IOException {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ public void testFieldCardinalityLimitsIsEmpty() {
9292
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
9393
}
9494

95-
public void testFieldMappingsToCopyIsEmpty() {
96-
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(anEmptyMap()));
95+
public void testGetExplicitlyMappedFields() {
96+
assertThat(createTestInstance().getExplicitlyMappedFields(null, null), is(anEmptyMap()));
9797
}
9898

9999
public void testGetStateDocId() {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ protected Regression createTestInstance() {
4343
return createRandom();
4444
}
4545

46-
public static Regression createRandom() {
46+
private static Regression createRandom() {
4747
String dependentVariableName = randomAlphaOfLength(10);
4848
BoostedTreeParams boostedTreeParams = BoostedTreeParamsTests.createRandom();
4949
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
@@ -110,8 +110,10 @@ public void testFieldCardinalityLimitsIsEmpty() {
110110
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
111111
}
112112

113-
public void testFieldMappingsToCopyIsNonEmpty() {
114-
assertThat(createTestInstance().getExplicitlyMappedFields(""), is(not(anEmptyMap())));
113+
public void testGetExplicitlyMappedFields() {
114+
assertThat(
115+
new Regression("foo").getExplicitlyMappedFields(null, "results"),
116+
hasEntry("results.foo_prediction", Collections.singletonMap("type", "double")));
115117
}
116118

117119
public void testGetStateDocId() {

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

Lines changed: 7 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
// Pending fix
99
//import com.google.common.collect.Ordering;
1010
import org.elasticsearch.ElasticsearchStatusException;
11-
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
12-
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
1311
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
1412
import org.elasticsearch.action.bulk.BulkRequestBuilder;
1513
import org.elasticsearch.action.bulk.BulkResponse;
@@ -43,7 +41,6 @@
4341
import java.util.Set;
4442

4543
import static java.util.stream.Collectors.toList;
46-
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
4744
import static org.hamcrest.Matchers.allOf;
4845
import static org.hamcrest.Matchers.anyOf;
4946
import static org.hamcrest.Matchers.equalTo;
@@ -117,7 +114,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
117114
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
118115
assertModelStatePersisted(stateDocId());
119116
assertInferenceModelPersisted(jobId);
120-
assertMlResultsFieldMappings(predictedClassField, "keyword");
117+
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
121118
assertThatAuditMessagesMatch(jobId,
122119
"Created analytics with analysis type [classification]",
123120
"Estimated memory usage for this analytics to be",
@@ -158,7 +155,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
158155
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
159156
assertModelStatePersisted(stateDocId());
160157
assertInferenceModelPersisted(jobId);
161-
assertMlResultsFieldMappings(predictedClassField, "keyword");
158+
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
162159
assertThatAuditMessagesMatch(jobId,
163160
"Created analytics with analysis type [classification]",
164161
"Estimated memory usage for this analytics to be",
@@ -221,7 +218,7 @@ public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId,
221218
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
222219
assertModelStatePersisted(stateDocId());
223220
assertInferenceModelPersisted(jobId);
224-
assertMlResultsFieldMappings(predictedClassField, expectedMappingTypeForPredictedField);
221+
assertMlResultsFieldMappings(destIndex, predictedClassField, expectedMappingTypeForPredictedField);
225222
assertThatAuditMessagesMatch(jobId,
226223
"Created analytics with analysis type [classification]",
227224
"Estimated memory usage for this analytics to be",
@@ -309,7 +306,7 @@ public void testStopAndRestart() throws Exception {
309306
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
310307
assertModelStatePersisted(stateDocId());
311308
assertInferenceModelPersisted(jobId);
312-
assertMlResultsFieldMappings(predictedClassField, "keyword");
309+
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
313310
assertEvaluation(KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
314311
}
315312

@@ -366,7 +363,7 @@ public void testDependentVariableIsNested() throws Exception {
366363
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
367364
assertModelStatePersisted(stateDocId());
368365
assertInferenceModelPersisted(jobId);
369-
assertMlResultsFieldMappings(predictedClassField, "keyword");
366+
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
370367
assertEvaluation(NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
371368
}
372369

@@ -385,7 +382,7 @@ public void testDependentVariableIsAliasToKeyword() throws Exception {
385382
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
386383
assertModelStatePersisted(stateDocId());
387384
assertInferenceModelPersisted(jobId);
388-
assertMlResultsFieldMappings(predictedClassField, "keyword");
385+
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
389386
assertEvaluation(ALIAS_TO_KEYWORD_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
390387
}
391388

@@ -404,7 +401,7 @@ public void testDependentVariableIsAliasToNested() throws Exception {
404401
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
405402
assertModelStatePersisted(stateDocId());
406403
assertInferenceModelPersisted(jobId);
407-
assertMlResultsFieldMappings(predictedClassField, "keyword");
404+
assertMlResultsFieldMappings(destIndex, predictedClassField, "keyword");
408405
assertEvaluation(ALIAS_TO_NESTED_FIELD, KEYWORD_FIELD_VALUES, "ml." + predictedClassField);
409406
}
410407

@@ -565,15 +562,6 @@ private static Map<String, Object> getDestDoc(DataFrameAnalyticsConfig config, S
565562
return destDoc;
566563
}
567564

568-
/**
569-
* Wrapper around extractValue that:
570-
* - allows dots (".") in the path elements provided as arguments
571-
* - supports implicit casting to the appropriate type
572-
*/
573-
private static <T> T getFieldValue(Map<String, Object> doc, String... path) {
574-
return (T)extractValue(String.join(".", path), doc);
575-
}
576-
577565
private static <T> void assertTopClasses(Map<String, Object> resultsObject,
578566
int numTopClasses,
579567
String dependentVariable,
@@ -657,27 +645,6 @@ private <T> void assertEvaluation(String dependentVariable, List<T> dependentVar
657645
}
658646
}
659647

660-
private void assertMlResultsFieldMappings(String predictedClassField, String expectedType) {
661-
Map<String, Object> mappings =
662-
client()
663-
.execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(destIndex))
664-
.actionGet()
665-
.mappings()
666-
.get(destIndex)
667-
.get("_doc")
668-
.sourceAsMap();
669-
assertThat(
670-
mappings.toString(),
671-
getFieldValue(
672-
mappings,
673-
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
674-
equalTo(expectedType));
675-
assertThat(
676-
mappings.toString(),
677-
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
678-
equalTo(expectedType));
679-
}
680-
681648
private String stateDocId() {
682649
return jobId + "_classification_state#1";
683650
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
*/
66
package org.elasticsearch.xpack.ml.integration;
77

8+
import org.elasticsearch.action.admin.indices.get.GetIndexAction;
9+
import org.elasticsearch.action.admin.indices.get.GetIndexRequest;
810
import org.elasticsearch.action.admin.indices.refresh.RefreshAction;
911
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
1012
import org.elasticsearch.action.admin.indices.refresh.RefreshResponse;
@@ -53,6 +55,7 @@
5355
import java.util.concurrent.TimeUnit;
5456
import java.util.stream.Collectors;
5557

58+
import static org.elasticsearch.common.xcontent.support.XContentMapValues.extractValue;
5659
import static org.hamcrest.Matchers.anyOf;
5760
import static org.hamcrest.Matchers.arrayWithSize;
5861
import static org.hamcrest.Matchers.equalTo;
@@ -281,4 +284,36 @@ protected static void assertModelStatePersisted(String stateDocId) {
281284
.get();
282285
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
283286
}
287+
288+
protected static void assertMlResultsFieldMappings(String index, String predictedClassField, String expectedType) {
289+
Map<String, Object> mappings =
290+
client()
291+
.execute(GetIndexAction.INSTANCE, new GetIndexRequest().indices(index))
292+
.actionGet()
293+
.mappings()
294+
.get(index)
295+
.get("_doc")
296+
.sourceAsMap();
297+
assertThat(
298+
mappings.toString(),
299+
getFieldValue(
300+
mappings,
301+
"properties", "ml", "properties", String.join(".properties.", predictedClassField.split("\\.")), "type"),
302+
equalTo(expectedType));
303+
if (getFieldValue(mappings, "properties", "ml", "properties", "top_classes") != null) {
304+
assertThat(
305+
mappings.toString(),
306+
getFieldValue(mappings, "properties", "ml", "properties", "top_classes", "properties", "class_name", "type"),
307+
equalTo(expectedType));
308+
}
309+
}
310+
311+
/**
312+
* Wrapper around extractValue that:
313+
* - allows dots (".") in the path elements provided as arguments
314+
* - supports implicit casting to the appropriate type
315+
*/
316+
protected static <T> T getFieldValue(Map<String, Object> doc, String... path) {
317+
return (T)extractValue(String.join(".", path), doc);
318+
}
284319
}

0 commit comments

Comments
 (0)