Skip to content

Commit 3873510

Browse files
[7.x][ML] Refactor DFA custom processor to cross validation splitter (#53915) (#53956)
While `CustomProcessor` is generic and allows for flexibility, there are new requirements that make cross validation a concept it's hard to abstract behind custom processor. In particular, we would like to add data_counts to the DFA jobs stats. Counting training VS. test docs would be a useful statistic. We would also want to add a different cross validation strategy for multiclass classification. This commit renames custom processors to cross validation splitters which allows for those enhancements without cryptically doing things as a side effect of the abstract custom processing. Backport of #53915
1 parent 0c010e1 commit 3873510

File tree

5 files changed

+32
-28
lines changed

5 files changed

+32
-28
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
3232
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
3333
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
34-
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessor;
35-
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessorFactory;
34+
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitter;
35+
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory;
3636
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
3737
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
3838
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
@@ -208,7 +208,8 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont
208208
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
209209
DataFrameAnalysis analysis, ProgressTracker progressTracker) throws IOException {
210210

211-
CustomProcessor customProcessor = new CustomProcessorFactory(dataExtractor.getFieldNames()).create(analysis);
211+
CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(dataExtractor.getFieldNames())
212+
.create(analysis);
212213

213214
// The extra fields are for the doc hash and the control field (should be an empty string)
214215
String[] record = new String[dataExtractor.getFieldNames().size() + 2];
@@ -226,7 +227,7 @@ private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProces
226227
String[] rowValues = row.getValues();
227228
System.arraycopy(rowValues, 0, record, 0, rowValues.length);
228229
record[record.length - 2] = String.valueOf(row.getChecksum());
229-
customProcessor.process(record);
230+
crossValidationSplitter.process(record);
230231
process.writeRecord(record);
231232
}
232233
}
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
* or more contributor license agreements. Licensed under the Elastic License;
44
* you may not use this file except in compliance with the Elastic License.
55
*/
6-
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
6+
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
77

88
/**
9-
* A processor to manipulate rows before writing them to the process
9+
* Processes rows in order to split the dataset in training and test subsets
1010
*/
11-
public interface CustomProcessor {
11+
public interface CrossValidationSplitter {
1212

1313
void process(String[] row);
1414
}
Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* or more contributor license agreements. Licensed under the Elastic License;
44
* you may not use this file except in compliance with the Elastic License.
55
*/
6-
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
6+
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
77

88
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
99
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
@@ -12,23 +12,23 @@
1212
import java.util.List;
1313
import java.util.Objects;
1414

15-
public class CustomProcessorFactory {
15+
public class CrossValidationSplitterFactory {
1616

1717
private final List<String> fieldNames;
1818

19-
public CustomProcessorFactory(List<String> fieldNames) {
19+
public CrossValidationSplitterFactory(List<String> fieldNames) {
2020
this.fieldNames = Objects.requireNonNull(fieldNames);
2121
}
2222

23-
public CustomProcessor create(DataFrameAnalysis analysis) {
23+
public CrossValidationSplitter create(DataFrameAnalysis analysis) {
2424
if (analysis instanceof Regression) {
2525
Regression regression = (Regression) analysis;
26-
return new DatasetSplittingCustomProcessor(
26+
return new RandomCrossValidationSplitter(
2727
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
2828
}
2929
if (analysis instanceof Classification) {
3030
Classification classification = (Classification) analysis;
31-
return new DatasetSplittingCustomProcessor(
31+
return new RandomCrossValidationSplitter(
3232
fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed());
3333
}
3434
return row -> {};
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* or more contributor license agreements. Licensed under the Elastic License;
44
* you may not use this file except in compliance with the Elastic License.
55
*/
6-
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
6+
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
77

88
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
99
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
@@ -12,19 +12,19 @@
1212
import java.util.Random;
1313

1414
/**
15-
* A processor that randomly clears the dependent variable value
16-
* in order to split the dataset in training and validation data.
15+
* A cross validation splitter that randomly clears the dependent variable value
16+
* in order to split the dataset in training and test data.
1717
* This relies on the fact that when the dependent variable field
1818
* is empty, then the row is not used for training but only to make predictions.
1919
*/
20-
class DatasetSplittingCustomProcessor implements CustomProcessor {
20+
class RandomCrossValidationSplitter implements CrossValidationSplitter {
2121

2222
private final int dependentVariableIndex;
2323
private final double trainingPercent;
2424
private final Random random;
2525
private boolean isFirstRow = true;
2626

27-
DatasetSplittingCustomProcessor(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
27+
RandomCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
2828
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
2929
this.trainingPercent = trainingPercent;
3030
this.random = new Random(randomizeSeed);
Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* or more contributor license agreements. Licensed under the Elastic License;
44
* you may not use this file except in compliance with the Elastic License.
55
*/
6-
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
6+
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;
77

88
import org.elasticsearch.test.ESTestCase;
99
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
@@ -20,7 +20,7 @@
2020
import static org.hamcrest.Matchers.is;
2121
import static org.hamcrest.Matchers.lessThan;
2222

23-
public class DatasetSplittingCustomProcessorTests extends ESTestCase {
23+
public class RandomCrossValidationSplitterTests extends ESTestCase {
2424

2525
private List<String> fields;
2626
private int dependentVariableIndex;
@@ -40,7 +40,7 @@ public void setUpTests() {
4040
}
4141

4242
public void testProcess_GivenRowsWithoutDependentVariableValue() {
43-
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0, randomizeSeed);
43+
CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(fields, dependentVariable, 50.0, randomizeSeed);
4444

4545
for (int i = 0; i < 100; i++) {
4646
String[] row = new String[fields.size()];
@@ -50,15 +50,16 @@ public void testProcess_GivenRowsWithoutDependentVariableValue() {
5050
}
5151

5252
String[] processedRow = Arrays.copyOf(row, row.length);
53-
customProcessor.process(processedRow);
53+
crossValidationSplitter.process(processedRow);
5454

5555
// As all these rows have no dependent variable value, they're not for training and should be unaffected
5656
assertThat(Arrays.equals(processedRow, row), is(true));
5757
}
5858
}
5959

6060
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
61-
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0, randomizeSeed);
61+
CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(
62+
fields, dependentVariable, 100.0, randomizeSeed);
6263

6364
for (int i = 0; i < 100; i++) {
6465
String[] row = new String[fields.size()];
@@ -68,7 +69,7 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs
6869
}
6970

7071
String[] processedRow = Arrays.copyOf(row, row.length);
71-
customProcessor.process(processedRow);
72+
crossValidationSplitter.process(processedRow);
7273

7374
// We should pick them all as training percent is 100
7475
assertThat(Arrays.equals(processedRow, row), is(true));
@@ -78,7 +79,8 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs
7879
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
7980
double trainingPercent = randomDoubleBetween(1.0, 100.0, true);
8081
double trainingFraction = trainingPercent / 100;
81-
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent, randomizeSeed);
82+
CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(
83+
fields, dependentVariable, trainingPercent, randomizeSeed);
8284

8385
int runCount = 20;
8486
int rowsCount = 1000;
@@ -92,7 +94,7 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs
9294
}
9395

9496
String[] processedRow = Arrays.copyOf(row, row.length);
95-
customProcessor.process(processedRow);
97+
crossValidationSplitter.process(processedRow);
9698

9799
for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
98100
if (fieldIndex != dependentVariableIndex) {
@@ -124,7 +126,8 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs
124126
}
125127

126128
public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
127-
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0, randomizeSeed);
129+
CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(
130+
fields, dependentVariable, 1.0, randomizeSeed);
128131

129132
// We have some non-training rows and then a training row to check
130133
// we maintain the first training row and not just the first row
@@ -139,7 +142,7 @@ public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
139142
}
140143

141144
String[] processedRow = Arrays.copyOf(row, row.length);
142-
customProcessor.process(processedRow);
145+
crossValidationSplitter.process(processedRow);
143146

144147
assertThat(Arrays.equals(processedRow, row), is(true));
145148
}

0 commit comments

Comments
 (0)