Skip to content

Commit 59687a9

Browse files
[7.x][ML] Validate classification dependent_variable cardinality is at lea… (#51232) (#51309)
Data frame analytics classification currently only supports 2 classes for the dependent variable. We were checking that the field's cardinality is not higher than 2 but we should also check it is not less than that as otherwise the process fails. Backport of #51232
1 parent 2a73e84 commit 59687a9

File tree

16 files changed

+197
-70
lines changed

16 files changed

+197
-70
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,9 @@ public List<RequiredField> getRequiredFields() {
245245
}
246246

247247
@Override
248-
public Map<String, Long> getFieldCardinalityLimits() {
248+
public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
249249
// This restriction is due to the fact that currently the C++ backend only supports binomial classification.
250-
return Collections.singletonMap(dependentVariable, 2L);
250+
return Collections.singletonList(FieldCardinalityConstraint.between(dependentVariable, 2, 2));
251251
}
252252

253253
@SuppressWarnings("unchecked")

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
3737
List<RequiredField> getRequiredFields();
3838

3939
/**
40-
* @return {@link Map} containing cardinality limits for the selected (analysis-specific) fields
40+
* @return {@link List} containing cardinality constraints for the selected (analysis-specific) fields
4141
*/
42-
Map<String, Long> getFieldCardinalityLimits();
42+
List<FieldCardinalityConstraint> getFieldCardinalityConstraints();
4343

4444
/**
4545
* Returns fields for which the mappings should be either predefined or copied from source index to destination index.
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
7+
8+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
9+
10+
import java.util.Objects;
11+
12+
/**
13+
* Allows checking a field's cardinality against given lower and upper bounds
14+
*/
15+
public class FieldCardinalityConstraint {
16+
17+
private final String field;
18+
private final long lowerBound;
19+
private final long upperBound;
20+
21+
public static FieldCardinalityConstraint between(String field, long lowerBound, long upperBound) {
22+
return new FieldCardinalityConstraint(field, lowerBound, upperBound);
23+
}
24+
25+
private FieldCardinalityConstraint(String field, long lowerBound, long upperBound) {
26+
this.field = Objects.requireNonNull(field);
27+
this.lowerBound = lowerBound;
28+
this.upperBound = upperBound;
29+
}
30+
31+
public String getField() {
32+
return field;
33+
}
34+
35+
public long getLowerBound() {
36+
return lowerBound;
37+
}
38+
39+
public long getUpperBound() {
40+
return upperBound;
41+
}
42+
43+
public void check(long fieldCardinality) {
44+
if (fieldCardinality < lowerBound) {
45+
throw ExceptionsHelper.badRequestException(
46+
"Field [{}] must have at least [{}] distinct values but there were [{}]",
47+
field, lowerBound, fieldCardinality);
48+
}
49+
if (fieldCardinality > upperBound) {
50+
throw ExceptionsHelper.badRequestException(
51+
"Field [{}] must have at most [{}] distinct values but there were at least [{}]",
52+
field, upperBound, fieldCardinality);
53+
}
54+
}
55+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,8 @@ public List<RequiredField> getRequiredFields() {
225225
}
226226

227227
@Override
228-
public Map<String, Long> getFieldCardinalityLimits() {
229-
return Collections.emptyMap();
228+
public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
229+
return Collections.emptyList();
230230
}
231231

232232
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ public List<RequiredField> getRequiredFields() {
182182
}
183183

184184
@Override
185-
public Map<String, Long> getFieldCardinalityLimits() {
186-
return Collections.emptyMap();
185+
public List<FieldCardinalityConstraint> getFieldCardinalityConstraints() {
186+
return Collections.emptyList();
187187
}
188188

189189
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.io.IOException;
2323
import java.util.Collections;
2424
import java.util.HashMap;
25+
import java.util.List;
2526
import java.util.Map;
2627
import java.util.Set;
2728

@@ -169,7 +170,13 @@ public void testRequiredFieldsIsNonEmpty() {
169170
}
170171

171172
public void testFieldCardinalityLimitsIsNonEmpty() {
172-
assertThat(createTestInstance().getFieldCardinalityLimits(), is(not(anEmptyMap())));
173+
Classification classification = createTestInstance();
174+
List<FieldCardinalityConstraint> constraints = classification.getFieldCardinalityConstraints();
175+
176+
assertThat(constraints.size(), equalTo(1));
177+
assertThat(constraints.get(0).getField(), equalTo(classification.getDependentVariable()));
178+
assertThat(constraints.get(0).getLowerBound(), equalTo(2L));
179+
assertThat(constraints.get(0).getUpperBound(), equalTo(2L));
173180
}
174181

175182
public void testGetExplicitlyMappedFields() {
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.dataframe.analyses;
7+
8+
import org.elasticsearch.ElasticsearchStatusException;
9+
import org.elasticsearch.rest.RestStatus;
10+
import org.elasticsearch.test.ESTestCase;
11+
12+
import static org.hamcrest.Matchers.equalTo;
13+
14+
public class FieldCardinalityConstraintTests extends ESTestCase {
15+
16+
public void testBetween_GivenWithinLimits() {
17+
FieldCardinalityConstraint constraint = FieldCardinalityConstraint.between("foo", 3, 6);
18+
19+
constraint.check(3);
20+
constraint.check(4);
21+
constraint.check(5);
22+
constraint.check(6);
23+
}
24+
25+
public void testBetween_GivenLessThanLowerBound() {
26+
FieldCardinalityConstraint constraint = FieldCardinalityConstraint.between("foo", 3, 6);
27+
28+
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> constraint.check(2L));
29+
assertThat(e.getMessage(), equalTo("Field [foo] must have at least [3] distinct values but there were [2]"));
30+
assertThat(e.status(), equalTo(RestStatus.BAD_REQUEST));
31+
}
32+
33+
public void testBetween_GivenGreaterThanUpperBound() {
34+
FieldCardinalityConstraint constraint = FieldCardinalityConstraint.between("foo", 3, 6);
35+
36+
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> constraint.check(7L));
37+
assertThat(e.getMessage(), equalTo("Field [foo] must have at most [6] distinct values but there were at least [7]"));
38+
assertThat(e.status(), equalTo(RestStatus.BAD_REQUEST));
39+
}
40+
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ public void testRequiredFieldsIsEmpty() {
8989
}
9090

9191
public void testFieldCardinalityLimitsIsEmpty() {
92-
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
92+
assertThat(createTestInstance().getFieldCardinalityConstraints(), is(empty()));
9393
}
9494

9595
public void testGetExplicitlyMappedFields() {

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import java.util.Collections;
2020

2121
import static org.hamcrest.Matchers.allOf;
22-
import static org.hamcrest.Matchers.anEmptyMap;
2322
import static org.hamcrest.Matchers.containsString;
2423
import static org.hamcrest.Matchers.empty;
2524
import static org.hamcrest.Matchers.equalTo;
@@ -107,7 +106,7 @@ public void testRequiredFieldsIsNonEmpty() {
107106
}
108107

109108
public void testFieldCardinalityLimitsIsEmpty() {
110-
assertThat(createTestInstance().getFieldCardinalityLimits(), is(anEmptyMap()));
109+
assertThat(createTestInstance().getFieldCardinalityConstraints(), is(empty()));
111110
}
112111

113112
public void testGetExplicitlyMappedFields() {

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,23 @@ public void testSourceQueryIsApplied() throws IOException {
4343
String sourceIndex = "test-source-query-is-applied";
4444

4545
client().admin().indices().prepareCreate(sourceIndex)
46-
.addMapping("_doc", "numeric_1", "type=double", "numeric_2", "type=float", "categorical", "type=keyword")
46+
.addMapping("_doc",
47+
"numeric_1", "type=double",
48+
"numeric_2", "type=float",
49+
"categorical", "type=keyword",
50+
"filtered_field", "type=keyword")
4751
.get();
4852

4953
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
5054
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
5155

5256
for (int i = 0; i < 30; i++) {
5357
IndexRequest indexRequest = new IndexRequest(sourceIndex);
54-
55-
// We insert one odd value out of 5 for one feature
56-
indexRequest.source("numeric_1", 1.0, "numeric_2", 2.0, "categorical", i == 0 ? "only-one" : "normal");
58+
indexRequest.source(
59+
"numeric_1", 1.0,
60+
"numeric_2", 2.0,
61+
"categorical", i % 2 == 0 ? "class_1" : "class_2",
62+
"filtered_field", i < 2 ? "bingo" : "rest"); // We tag bingo on the first two docs to ensure we have 2 classes
5763
bulkRequestBuilder.add(indexRequest);
5864
}
5965
BulkResponse bulkResponse = bulkRequestBuilder.get();
@@ -66,7 +72,7 @@ public void testSourceQueryIsApplied() throws IOException {
6672
DataFrameAnalyticsConfig config = new DataFrameAnalyticsConfig.Builder()
6773
.setId(id)
6874
.setSource(new DataFrameAnalyticsSource(new String[] { sourceIndex },
69-
QueryProvider.fromParsedQuery(QueryBuilders.termQuery("categorical", "only-one")),
75+
QueryProvider.fromParsedQuery(QueryBuilders.termQuery("filtered_field", "bingo")),
7076
null))
7177
.setAnalysis(new Classification("categorical"))
7278
.buildForExplain();

0 commit comments

Comments
 (0)