Skip to content

Commit c141c1d

Browse files
[7.x][ML] Stratified cross validation split for classification (#54087) (#54104)
As classification now works for multiple classes, randomly picking training/test data frame rows is not good enough. This commit introduces a stratified cross validation splitter that maintains the proportion of the each class in the dataset in the sample that is used for training the model. Backport of #54087
1 parent e006d1f commit c141c1d

File tree

6 files changed

+403
-24
lines changed

6 files changed

+403
-24
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public class Classification implements DataFrameAnalysis {
5454
/**
5555
* The max number of classes classification supports
5656
*/
57-
private static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
57+
public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
5858

5959
private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
6060
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import org.elasticsearch.xpack.core.ClientHelper;
2626
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
2727
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
28-
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
2928
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
3029
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
3130
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
@@ -162,7 +161,7 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont
162161
AnalyticsResultProcessor resultProcessor = processContext.resultProcessor.get();
163162
try {
164163
writeHeaderRecord(dataExtractor, process);
165-
writeDataRows(dataExtractor, process, config.getAnalysis(), task.getStatsHolder().getProgressTracker(),
164+
writeDataRows(dataExtractor, process, config, task.getStatsHolder().getProgressTracker(),
166165
task.getStatsHolder().getDataCountsTracker());
167166
processContext.statsPersister.persistWithRetry(task.getStatsHolder().getDataCountsTracker().report(config.getId()),
168167
DataCounts::documentId);
@@ -214,11 +213,12 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont
214213
}
215214
}
216215

217-
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process, DataFrameAnalysis analysis,
218-
ProgressTracker progressTracker, DataCountsTracker dataCountsTracker) throws IOException {
216+
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
217+
DataFrameAnalyticsConfig config, ProgressTracker progressTracker, DataCountsTracker dataCountsTracker)
218+
throws IOException {
219219

220-
CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(dataExtractor.getFieldNames())
221-
.create(analysis);
220+
CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(client, config, dataExtractor.getFieldNames())
221+
.create();
222222

223223
// The extra fields are for the doc hash and the control field (should be an empty string)
224224
String[] record = new String[dataExtractor.getFieldNames().size() + 2];
@@ -324,7 +324,8 @@ private void refreshIndices(String jobId) {
324324
);
325325
refreshRequest.indicesOptions(IndicesOptions.lenientExpandOpen());
326326

327-
LOGGER.debug("[{}] Refreshing indices {}", jobId, Arrays.toString(refreshRequest.indices()));
327+
LOGGER.debug(() -> new ParameterizedMessage("[{}] Refreshing indices {}",
328+
jobId, Arrays.toString(refreshRequest.indices())));
328329

329330
try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
330331
client.admin().indices().refresh(refreshRequest).actionGet();

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,32 +5,81 @@
55
*/
66
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
77

8+
import org.apache.logging.log4j.LogManager;
9+
import org.apache.logging.log4j.Logger;
10+
import org.apache.logging.log4j.message.ParameterizedMessage;
11+
import org.elasticsearch.ElasticsearchException;
12+
import org.elasticsearch.action.search.SearchRequestBuilder;
13+
import org.elasticsearch.action.search.SearchResponse;
14+
import org.elasticsearch.client.Client;
15+
import org.elasticsearch.search.aggregations.AggregationBuilders;
16+
import org.elasticsearch.search.aggregations.Aggregations;
17+
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
18+
import org.elasticsearch.xpack.core.ClientHelper;
19+
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
820
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
9-
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
1021
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
1122

23+
import java.util.HashMap;
1224
import java.util.List;
25+
import java.util.Map;
1326
import java.util.Objects;
1427

1528
public class CrossValidationSplitterFactory {
1629

30+
private static final Logger LOGGER = LogManager.getLogger(CrossValidationSplitterFactory.class);
31+
32+
private final Client client;
33+
private final DataFrameAnalyticsConfig config;
1734
private final List<String> fieldNames;
1835

19-
public CrossValidationSplitterFactory(List<String> fieldNames) {
36+
public CrossValidationSplitterFactory(Client client, DataFrameAnalyticsConfig config, List<String> fieldNames) {
37+
this.client = Objects.requireNonNull(client);
38+
this.config = Objects.requireNonNull(config);
2039
this.fieldNames = Objects.requireNonNull(fieldNames);
2140
}
2241

23-
public CrossValidationSplitter create(DataFrameAnalysis analysis) {
24-
if (analysis instanceof Regression) {
25-
Regression regression = (Regression) analysis;
26-
return new RandomCrossValidationSplitter(
27-
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
42+
public CrossValidationSplitter create() {
43+
if (config.getAnalysis() instanceof Regression) {
44+
return createRandomSplitter();
2845
}
29-
if (analysis instanceof Classification) {
30-
Classification classification = (Classification) analysis;
31-
return new RandomCrossValidationSplitter(
32-
fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed());
46+
if (config.getAnalysis() instanceof Classification) {
47+
return createStratifiedSplitter((Classification) config.getAnalysis());
3348
}
3449
return (row, incrementTrainingDocs, incrementTestDocs) -> incrementTrainingDocs.run();
3550
}
51+
52+
private CrossValidationSplitter createRandomSplitter() {
53+
Regression regression = (Regression) config.getAnalysis();
54+
return new RandomCrossValidationSplitter(
55+
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
56+
}
57+
58+
private CrossValidationSplitter createStratifiedSplitter(Classification classification) {
59+
String aggName = "dependent_variable_terms";
60+
SearchRequestBuilder searchRequestBuilder = client.prepareSearch(config.getDest().getIndex())
61+
.setSize(0)
62+
.setAllowPartialSearchResults(false)
63+
.addAggregation(AggregationBuilders.terms(aggName)
64+
.field(classification.getDependentVariable())
65+
.size(Classification.MAX_DEPENDENT_VARIABLE_CARDINALITY));
66+
67+
try {
68+
SearchResponse searchResponse = ClientHelper.executeWithHeaders(config.getHeaders(), ClientHelper.ML_ORIGIN, client,
69+
searchRequestBuilder::get);
70+
Aggregations aggs = searchResponse.getAggregations();
71+
Terms terms = aggs.get(aggName);
72+
Map<String, Long> classCardinalities = new HashMap<>();
73+
for (Terms.Bucket bucket : terms.getBuckets()) {
74+
classCardinalities.put(String.valueOf(bucket.getKey()), bucket.getDocCount());
75+
}
76+
77+
return new StratifiedCrossValidationSplitter(fieldNames, classification.getDependentVariable(), classCardinalities,
78+
classification.getTrainingPercent(), classification.getRandomizeSeed());
79+
} catch (Exception e) {
80+
ParameterizedMessage msg = new ParameterizedMessage("[{}] Dependent variable terms search failed", config.getId());
81+
LOGGER.error(msg, e);
82+
throw new ElasticsearchException(msg.getFormattedMessage(), e);
83+
}
84+
}
3685
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,18 @@ class RandomCrossValidationSplitter implements CrossValidationSplitter {
2525
private boolean isFirstRow = true;
2626

2727
RandomCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
28+
assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
2829
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
2930
this.trainingPercent = trainingPercent;
3031
this.random = new Random(randomizeSeed);
3132
}
3233

3334
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
34-
for (int i = 0; i < fieldNames.size(); i++) {
35-
if (fieldNames.get(i).equals(dependentVariable)) {
36-
return i;
37-
}
35+
int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
36+
if (dependentVariableIndex < 0) {
37+
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
3838
}
39-
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
39+
return dependentVariableIndex;
4040
}
4141

4242
@Override
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
7+
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
8+
9+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
10+
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
11+
12+
import java.util.HashMap;
13+
import java.util.List;
14+
import java.util.Map;
15+
import java.util.Random;
16+
17+
/**
18+
* Given a dependent variable, randomly splits the dataset trying
19+
* to preserve the proportion of each class in the training sample.
20+
*/
21+
public class StratifiedCrossValidationSplitter implements CrossValidationSplitter {
22+
23+
private final int dependentVariableIndex;
24+
private final double samplingRatio;
25+
private final Random random;
26+
private final Map<String, ClassSample> classSamples;
27+
28+
public StratifiedCrossValidationSplitter(List<String> fieldNames, String dependentVariable, Map<String, Long> classCardinalities,
29+
double trainingPercent, long randomizeSeed) {
30+
assert trainingPercent >= 1.0 && trainingPercent <= 100.0;
31+
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
32+
this.samplingRatio = trainingPercent / 100.0;
33+
this.random = new Random(randomizeSeed);
34+
this.classSamples = new HashMap<>();
35+
classCardinalities.entrySet().forEach(entry -> classSamples.put(entry.getKey(), new ClassSample(entry.getValue())));
36+
}
37+
38+
private static int findDependentVariableIndex(List<String> fieldNames, String dependentVariable) {
39+
int dependentVariableIndex = fieldNames.indexOf(dependentVariable);
40+
if (dependentVariableIndex < 0) {
41+
throw ExceptionsHelper.serverError("Could not find dependent variable [" + dependentVariable + "] in fields " + fieldNames);
42+
}
43+
return dependentVariableIndex;
44+
}
45+
46+
@Override
47+
public void process(String[] row, Runnable incrementTrainingDocs, Runnable incrementTestDocs) {
48+
49+
if (canBeUsedForTraining(row) == false) {
50+
incrementTestDocs.run();
51+
return;
52+
}
53+
54+
String classValue = row[dependentVariableIndex];
55+
ClassSample sample = classSamples.get(classValue);
56+
if (sample == null) {
57+
throw new IllegalStateException("Unknown class [" + classValue + "]; expected one of " + classSamples.keySet());
58+
}
59+
60+
// The idea here is that the probability increases as the chances we have to get the target proportion
61+
// for a class decreases.
62+
double p = (samplingRatio * sample.cardinality - sample.training) / (sample.cardinality - sample.observed);
63+
64+
boolean isTraining = random.nextDouble() <= p;
65+
66+
sample.observed++;
67+
if (isTraining) {
68+
sample.training++;
69+
incrementTrainingDocs.run();
70+
} else {
71+
row[dependentVariableIndex] = DataFrameDataExtractor.NULL_VALUE;
72+
incrementTestDocs.run();
73+
}
74+
}
75+
76+
private boolean canBeUsedForTraining(String[] row) {
77+
return row[dependentVariableIndex] != DataFrameDataExtractor.NULL_VALUE;
78+
}
79+
80+
private static class ClassSample {
81+
82+
private final long cardinality;
83+
private long training;
84+
private long observed;
85+
86+
private ClassSample(long cardinality) {
87+
this.cardinality = cardinality;
88+
}
89+
}
90+
}

0 commit comments

Comments
 (0)