Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessor;
import org.elasticsearch.xpack.ml.dataframe.process.customprocessing.CustomProcessorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitter;
import org.elasticsearch.xpack.ml.dataframe.process.crossvalidation.CrossValidationSplitterFactory;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
Expand Down Expand Up @@ -208,7 +208,8 @@ private void processData(DataFrameAnalyticsTask task, ProcessContext processCont
private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProcess<AnalyticsResult> process,
DataFrameAnalysis analysis, ProgressTracker progressTracker) throws IOException {

CustomProcessor customProcessor = new CustomProcessorFactory(dataExtractor.getFieldNames()).create(analysis);
CrossValidationSplitter crossValidationSplitter = new CrossValidationSplitterFactory(dataExtractor.getFieldNames())
.create(analysis);

// The extra fields are for the doc hash and the control field (should be an empty string)
String[] record = new String[dataExtractor.getFieldNames().size() + 2];
Expand All @@ -226,7 +227,7 @@ private void writeDataRows(DataFrameDataExtractor dataExtractor, AnalyticsProces
String[] rowValues = row.getValues();
System.arraycopy(rowValues, 0, record, 0, rowValues.length);
record[record.length - 2] = String.valueOf(row.getChecksum());
customProcessor.process(record);
crossValidationSplitter.process(record);
process.writeRecord(record);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;

/**
* A processor to manipulate rows before writing them to the process
* Processes rows in order to split the dataset in training and test subsets
*/
public interface CustomProcessor {
public interface CrossValidationSplitter {

void process(String[] row);
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;

import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification;
import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis;
Expand All @@ -12,23 +12,23 @@
import java.util.List;
import java.util.Objects;

public class CustomProcessorFactory {
public class CrossValidationSplitterFactory {

private final List<String> fieldNames;

public CustomProcessorFactory(List<String> fieldNames) {
public CrossValidationSplitterFactory(List<String> fieldNames) {
this.fieldNames = Objects.requireNonNull(fieldNames);
}

public CustomProcessor create(DataFrameAnalysis analysis) {
public CrossValidationSplitter create(DataFrameAnalysis analysis) {
if (analysis instanceof Regression) {
Regression regression = (Regression) analysis;
return new DatasetSplittingCustomProcessor(
return new RandomCrossValidationSplitter(
fieldNames, regression.getDependentVariable(), regression.getTrainingPercent(), regression.getRandomizeSeed());
}
if (analysis instanceof Classification) {
Classification classification = (Classification) analysis;
return new DatasetSplittingCustomProcessor(
return new RandomCrossValidationSplitter(
fieldNames, classification.getDependentVariable(), classification.getTrainingPercent(), classification.getRandomizeSeed());
}
return row -> {};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;

import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
Expand All @@ -12,19 +12,19 @@
import java.util.Random;

/**
* A processor that randomly clears the dependent variable value
* in order to split the dataset in training and validation data.
* A cross validation splitter that randomly clears the dependent variable value
* in order to split the dataset in training and test data.
* This relies on the fact that when the dependent variable field
* is empty, then the row is not used for training but only to make predictions.
*/
class DatasetSplittingCustomProcessor implements CustomProcessor {
class RandomCrossValidationSplitter implements CrossValidationSplitter {

private final int dependentVariableIndex;
private final double trainingPercent;
private final Random random;
private boolean isFirstRow = true;

DatasetSplittingCustomProcessor(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
RandomCrossValidationSplitter(List<String> fieldNames, String dependentVariable, double trainingPercent, long randomizeSeed) {
this.dependentVariableIndex = findDependentVariableIndex(fieldNames, dependentVariable);
this.trainingPercent = trainingPercent;
this.random = new Random(randomizeSeed);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.ml.dataframe.process.customprocessing;
package org.elasticsearch.xpack.ml.dataframe.process.crossvalidation;

import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractor;
Expand All @@ -20,7 +20,7 @@
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.lessThan;

public class DatasetSplittingCustomProcessorTests extends ESTestCase {
public class RandomCrossValidationSplitterTests extends ESTestCase {

private List<String> fields;
private int dependentVariableIndex;
Expand All @@ -40,7 +40,7 @@ public void setUpTests() {
}

public void testProcess_GivenRowsWithoutDependentVariableValue() {
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 50.0, randomizeSeed);
CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(fields, dependentVariable, 50.0, randomizeSeed);

for (int i = 0; i < 100; i++) {
String[] row = new String[fields.size()];
Expand All @@ -50,15 +50,16 @@ public void testProcess_GivenRowsWithoutDependentVariableValue() {
}

String[] processedRow = Arrays.copyOf(row, row.length);
customProcessor.process(processedRow);
crossValidationSplitter.process(processedRow);

// As all these rows have no dependent variable value, they're not for training and should be unaffected
assertThat(Arrays.equals(processedRow, row), is(true));
}
}

public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsHundred() {
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 100.0, randomizeSeed);
CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(
fields, dependentVariable, 100.0, randomizeSeed);

for (int i = 0; i < 100; i++) {
String[] row = new String[fields.size()];
Expand All @@ -68,7 +69,7 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs
}

String[] processedRow = Arrays.copyOf(row, row.length);
customProcessor.process(processedRow);
crossValidationSplitter.process(processedRow);

// We should pick them all as training percent is 100
assertThat(Arrays.equals(processedRow, row), is(true));
Expand All @@ -78,7 +79,8 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs
public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIsRandom() {
double trainingPercent = randomDoubleBetween(1.0, 100.0, true);
double trainingFraction = trainingPercent / 100;
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, trainingPercent, randomizeSeed);
CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(
fields, dependentVariable, trainingPercent, randomizeSeed);

int runCount = 20;
int rowsCount = 1000;
Expand All @@ -92,7 +94,7 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs
}

String[] processedRow = Arrays.copyOf(row, row.length);
customProcessor.process(processedRow);
crossValidationSplitter.process(processedRow);

for (int fieldIndex = 0; fieldIndex < fields.size(); fieldIndex++) {
if (fieldIndex != dependentVariableIndex) {
Expand Down Expand Up @@ -124,7 +126,8 @@ public void testProcess_GivenRowsWithDependentVariableValue_AndTrainingPercentIs
}

public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
CustomProcessor customProcessor = new DatasetSplittingCustomProcessor(fields, dependentVariable, 1.0, randomizeSeed);
CrossValidationSplitter crossValidationSplitter = new RandomCrossValidationSplitter(
fields, dependentVariable, 1.0, randomizeSeed);

// We have some non-training rows and then a training row to check
// we maintain the first training row and not just the first row
Expand All @@ -139,7 +142,7 @@ public void testProcess_ShouldHaveAtLeastOneTrainingRow() {
}

String[] processedRow = Arrays.copyOf(row, row.length);
customProcessor.process(processedRow);
crossValidationSplitter.process(processedRow);

assertThat(Arrays.equals(processedRow, row), is(true));
}
Expand Down