diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 255ecfd77c021..c7baad202d2f6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -31,8 +31,8 @@ import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory; -import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessor; -import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessorFactory; +import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitter; +import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; @@ -208,7 +208,8 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess process, DataFrameAnalysis analysis, ProgressTracker progressTracker) throws IOException { - CustomProcessor customProcessor = new CustomProcessorFactory(dataExtractor.getFieldNames()).create(analysis); + CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(dataExtractor.getFieldNames()) + .create(analysis); // The extra fields are for the doc hash and the control field (should be an empty string) String[] record = new String[dataExtractor.getFieldNames().size() + 2]; @@ -226,7 +227,7 @@ private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProces String[] rowValues = row.getValues(); System.arraycopy(rowValues, 0, record, 0, rowValues.length); record[record.length - 2] = String.valueOf(row.getChecksum()); - customProcessor.process(record); + crossValidationSplitter.process(record); process.writeRecord(record); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java similarity index 60% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessor.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java index 518aee13b8c57..5d12a2a81a607 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitter.java @@ -3,12 +3,12 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.dataframe.process.customprocessing; +package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; /** - * A processor to manipulate rows before writing them to the process + * Processes rows in order to split the dataset in training and test subsets */ -public interface CustomProcessor { +public interface CrossValidationSplitter { void process(String[] row); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java similarity index 76% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java index 77f0b127a2638..47c052dd0bf84 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/CustomProcessorFactory.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/CrossValidationSplitterFactory.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.dataframe.process.customprocessing; +package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; @@ -12,23 +12,23 @@ import java.util.List; import java.util.Objects; -public class CustomProcessorFactory { +public class CrossValidationSplitterFactory { private final List fieldNames; - public CustomProcessorFactory(List fieldNames) { + public CrossValidationSplitterFactory(List fieldNames) { this.fieldNames = Objects.requireNonNull(fieldNames); } - public CustomProcessor create(DataFrameAnalysis analysis) { + public CrossValidationSplitter create(DataFrameAnalysis analysis) { if (analysis instanceof Regression) { Regression regression = (Regression) analysis; - return new DatasetSplittingCustomProcessor( + return new RandomCrossValidationSplitter( fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed()); } if (analysis instanceof Classification) { Classification classification = (Classification) analysis; - return new DatasetSplittingCustomProcessor( + return new RandomCrossValidationSplitter( fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed()); } return row -> {}; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java similarity index 82% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessor.java rename to x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java index 6e6acfb271e3f..0afc59628e7de 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitter.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.dataframe.process.customprocessing; +package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; @@ -12,19 +12,19 @@ import java.util.Random; /** - * A processor that randomly clears the dependent variable value - * in order to split the dataset in training and validation data. + * A cross validation splitter that randomly clears the dependent variable value + * in order to split the dataset in training and test data. * This relies on the fact that when the dependent variable field * is empty, then the row is not used for training but only to make predictions. */ -class DatasetSplittingCustomProcessor implements CustomProcessor { +class RandomCrossValidationSplitter implements CrossValidationSplitter { private final int dependentVariableIndex; private final double trainingPercent; private final Random random; private boolean isFirstRow = true; - DatasetSplittingCustomProcessor(List fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) { + RandomCrossValidationSplitter(List fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) { this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable); this.trainingPercent = trainingPercent; this.random = new Random(randomizeSeed); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java similarity index 85% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessorTests.java rename to x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java index ac897413a4eac..eea102e673893 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/customprocessing/DatasetSplittingCustomProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/crossvalidation/RandomCrossValidationSplitterTests.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.dataframe.process.customprocessing; +package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor; @@ -20,7 +20,7 @@ import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; -public class DatasetSplittingCustomProcessorTests extends ESTestCase { +public class RandomCrossValidationSplitterTests extends ESTestCase { private List fields; private int dependentVariableIndex; @@ -40,7 +40,7 @@ public void setUpTests() { } public void testProcess_GivenRowsWithoutDependentVariableValue() { - CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(fields, dependentVariable, 50.0, randomizeSeed); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; @@ -50,7 +50,7 @@ public void testProcess_GivenRowsWithoutDependentVariableValue() { } String[] processedRow = Arrays.copyOf(row, row.length); - customProcessor.process(processedRow); + crossValidationSplitter.process(processedRow); // As all these rows have no dependent variable value, they're not for training and should be unaffected assertThat(Arrays.equals(processedRow, row), is(true)); @@ -58,7 +58,8 @@ public void testProcess_GivenRowsWithoutDependentVariableValue() { } public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() { - CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter( + fields, dependentVariable, 100.0, randomizeSeed); for (int i = 0; i < 100; i++) { String[] row = new String[fields.size()]; @@ -68,7 +69,7 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs } String[] processedRow = Arrays.copyOf(row, row.length); - customProcessor.process(processedRow); + crossValidationSplitter.process(processedRow); // We should pick them all as training percent is 100 assertThat(Arrays.equals(processedRow, row), is(true)); @@ -78,7 +79,8 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() { double trainingPercent = randomDoubleBetween(1.0, 100.0, true); double trainingFraction = trainingPercent / 100; - CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter( + fields, dependentVariable, trainingPercent, randomizeSeed); int runCount = 20; int rowsCount = 1000; @@ -92,7 +94,7 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs } String[] processedRow = Arrays.copyOf(row, row.length); - customProcessor.process(processedRow); + crossValidationSplitter.process(processedRow); for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) { if (fieldIndex != dependentVariableIndex) { @@ -124,7 +126,8 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs } public void testProcess_ShouldHaveAtLeastOneTrainingRow() { - CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0, randomizeSeed); + CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter( + fields, dependentVariable, 1.0, randomizeSeed); // We have some non-training rows and then a training row to check // we maintain the first training row and not just the first row @@ -139,7 +142,7 @@ public void testProcess_ShouldHaveAtLeastOneTrainingRow() { } String[] processedRow = Arrays.copyOf(row, row.length); - customProcessor.process(processedRow); + crossValidationSplitter.process(processedRow); assertThat(Arrays.equals(processedRow, row), is(true)); }