Skip to content

Commit 873ad3f

Browse files
[7.x][ML] Add option to regression to randomize training set (#45969) (#46017)
Adds a `training_percent` parameter to regression; its default value is `100`. When the parameter is set to a value less than `100`, each row that could be used for training (i.e. each row that has a value for the dependent variable) is randomly either kept for training or held out. This splits the data into a training set and a remainder, usually called the testing, validation or holdout set, so that the model can be validated on data that was not used for training. Technically, the analytics process treats as training data every row that has a value for the dependent variable. Thus, when we decide that a training row is not going to be used for training, we simply clear that row's dependent variable.
1 parent 7b6246e commit 873ad3f

File tree

15 files changed

+622
-95
lines changed

15 files changed

+622
-95
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,23 @@ public class Regression implements DataFrameAnalysis {
3232
public static final ParseField MAXIMUM_NUMBER_TREES = new ParseField("maximum_number_trees");
3333
public static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
3434
public static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
35+
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
3536

3637
private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
3738
private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);
3839

3940
private static ConstructingObjectParser<Regression, Void> createParser(boolean lenient) {
4041
ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(NAME.getPreferredName(), lenient,
41-
a -> new Regression((String) a[0], (Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (String) a[6]));
42+
a -> new Regression((String) a[0], (Double) a[1], (Double) a[2], (Double) a[3], (Integer) a[4], (Double) a[5], (String) a[6],
43+
(Double) a[7]));
4244
parser.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
4345
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), LAMBDA);
4446
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), GAMMA);
4547
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), ETA);
4648
parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), MAXIMUM_NUMBER_TREES);
4749
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
4850
parser.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
51+
parser.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
4952
return parser;
5053
}
5154

@@ -60,9 +63,11 @@ public static Regression fromXContent(XContentParser parser, boolean ignoreUnkno
6063
private final Integer maximumNumberTrees;
6164
private final Double featureBagFraction;
6265
private final String predictionFieldName;
66+
private final double trainingPercent;
6367

6468
public Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
65-
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName) {
69+
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
70+
@Nullable Double trainingPercent) {
6671
this.dependentVariable = Objects.requireNonNull(dependentVariable);
6772

6873
if (lambda != null && lambda < 0) {
@@ -91,10 +96,15 @@ public Regression(String dependentVariable, @Nullable Double lambda, @Nullable D
9196
this.featureBagFraction = featureBagFraction;
9297

9398
this.predictionFieldName = predictionFieldName;
99+
100+
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
101+
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
102+
}
103+
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
94104
}
95105

96106
public Regression(String dependentVariable) {
97-
this(dependentVariable, null, null, null, null, null, null);
107+
this(dependentVariable, null, null, null, null, null, null, null);
98108
}
99109

100110
public Regression(StreamInput in) throws IOException {
@@ -105,6 +115,15 @@ public Regression(StreamInput in) throws IOException {
105115
maximumNumberTrees = in.readOptionalVInt();
106116
featureBagFraction = in.readOptionalDouble();
107117
predictionFieldName = in.readOptionalString();
118+
trainingPercent = in.readDouble();
119+
}
120+
121+
/** Returns the dependent variable this regression analysis was configured with. */
public String getDependentVariable() {
122+
return dependentVariable;
123+
}
124+
125+
/**
 * Returns the training percent. The public constructor accepts only values in [1, 100]
 * and substitutes 100.0 when no value is supplied.
 */
public double getTrainingPercent() {
126+
return trainingPercent;
108127
}
109128

110129
@Override
@@ -121,6 +140,7 @@ public void writeTo(StreamOutput out) throws IOException {
121140
out.writeOptionalVInt(maximumNumberTrees);
122141
out.writeOptionalDouble(featureBagFraction);
123142
out.writeOptionalString(predictionFieldName);
143+
out.writeDouble(trainingPercent);
124144
}
125145

126146
@Override
@@ -145,6 +165,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
145165
if (predictionFieldName != null) {
146166
builder.field(PREDICTION_FIELD_NAME.getPreferredName(), predictionFieldName);
147167
}
168+
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
148169
builder.endObject();
149170
return builder;
150171
}
@@ -191,7 +212,8 @@ public boolean supportsMissingValues() {
191212

192213
@Override
193214
public int hashCode() {
194-
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName);
215+
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
216+
trainingPercent);
195217
}
196218

197219
@Override
@@ -205,6 +227,7 @@ public boolean equals(Object o) {
205227
&& Objects.equals(eta, that.eta)
206228
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
207229
&& Objects.equals(featureBagFraction, that.featureBagFraction)
208-
&& Objects.equals(predictionFieldName, that.predictionFieldName);
230+
&& Objects.equals(predictionFieldName, that.predictionFieldName)
231+
&& trainingPercent == that.trainingPercent;
209232
}
210233
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/persistence/ElasticsearchMappings.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,9 @@ public static void addDataFrameAnalyticsFields(XContentBuilder builder) throws I
467467
.startObject(Regression.PREDICTION_FIELD_NAME.getPreferredName())
468468
.field(TYPE, KEYWORD)
469469
.endObject()
470+
.startObject(Regression.TRAINING_PERCENT.getPreferredName())
471+
.field(TYPE, DOUBLE)
472+
.endObject()
470473
.endObject()
471474
.endObject()
472475
.endObject()

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/results/ReservedFieldNames.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ public final class ReservedFieldNames {
309309
Regression.MAXIMUM_NUMBER_TREES.getPreferredName(),
310310
Regression.FEATURE_BAG_FRACTION.getPreferredName(),
311311
Regression.PREDICTION_FIELD_NAME.getPreferredName(),
312+
Regression.TRAINING_PERCENT.getPreferredName(),
312313

313314
ElasticsearchMappings.CONFIG_TYPE,
314315

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ public static Regression createRandom() {
3333
Integer maximumNumberTrees = randomBoolean() ? null : randomIntBetween(1, 2000);
3434
Double featureBagFraction = randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false);
3535
String predictionFieldName = randomBoolean() ? null : randomAlphaOfLength(10);
36+
// training_percent must lie in [1, 100]; a lower bound of 0.0 could draw values the Regression constructor rejects.
Double trainingPercent = randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true);
3637
return new Regression(randomAlphaOfLength(10), lambda, gamma, eta, maximumNumberTrees, featureBagFraction,
37-
predictionFieldName);
38+
predictionFieldName, trainingPercent);
3839
}
3940

4041
@Override
@@ -44,57 +45,83 @@ protected Writeable.Reader<Regression> instanceReader() {
4445

4546
public void testRegression_GivenNegativeLambda() {
4647
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
47-
() -> new Regression("foo", -0.00001, 0.0, 0.5, 500, 0.3, "result"));
48+
() -> new Regression("foo", -0.00001, 0.0, 0.5, 500, 0.3, "result", 100.0));
4849

4950
assertThat(e.getMessage(), equalTo("[lambda] must be a non-negative double"));
5051
}
5152

5253
public void testRegression_GivenNegativeGamma() {
5354
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
54-
() -> new Regression("foo", 0.0, -0.00001, 0.5, 500, 0.3, "result"));
55+
() -> new Regression("foo", 0.0, -0.00001, 0.5, 500, 0.3, "result", 100.0));
5556

5657
assertThat(e.getMessage(), equalTo("[gamma] must be a non-negative double"));
5758
}
5859

5960
public void testRegression_GivenEtaIsZero() {
6061
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
61-
() -> new Regression("foo", 0.0, 0.0, 0.0, 500, 0.3, "result"));
62+
() -> new Regression("foo", 0.0, 0.0, 0.0, 500, 0.3, "result", 100.0));
6263

6364
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
6465
}
6566

6667
public void testRegression_GivenEtaIsGreaterThanOne() {
6768
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
68-
() -> new Regression("foo", 0.0, 0.0, 1.00001, 500, 0.3, "result"));
69+
() -> new Regression("foo", 0.0, 0.0, 1.00001, 500, 0.3, "result", 100.0));
6970

7071
assertThat(e.getMessage(), equalTo("[eta] must be a double in [0.001, 1]"));
7172
}
7273

7374
public void testRegression_GivenMaximumNumberTreesIsZero() {
7475
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
75-
() -> new Regression("foo", 0.0, 0.0, 0.5, 0, 0.3, "result"));
76+
() -> new Regression("foo", 0.0, 0.0, 0.5, 0, 0.3, "result", 100.0));
7677

7778
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
7879
}
7980

8081
public void testRegression_GivenMaximumNumberTreesIsGreaterThan2k() {
8182
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
82-
() -> new Regression("foo", 0.0, 0.0, 0.5, 2001, 0.3, "result"));
83+
() -> new Regression("foo", 0.0, 0.0, 0.5, 2001, 0.3, "result", 100.0));
8384

8485
assertThat(e.getMessage(), equalTo("[maximum_number_trees] must be an integer in [1, 2000]"));
8586
}
8687

8788
public void testRegression_GivenFeatureBagFractionIsLessThanZero() {
8889
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
89-
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, -0.00001, "result"));
90+
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, -0.00001, "result", 100.0));
9091

9192
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
9293
}
9394

9495
public void testRegression_GivenFeatureBagFractionIsGreaterThanOne() {
9596
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
96-
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.00001, "result"));
97+
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.00001, "result", 100.0));
9798

9899
assertThat(e.getMessage(), equalTo("[feature_bag_fraction] must be a double in (0, 1]"));
99100
}
101+
102+
// A null training percent must fall back to the default of 100.0.
public void testRegression_GivenTrainingPercentIsNull() {
103+
Regression regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", null);
104+
assertThat(regression.getTrainingPercent(), equalTo(100.0));
105+
}
106+
107+
// Both ends of the allowed range, 1.0 and 100.0, are inclusive and must be accepted unchanged.
public void testRegression_GivenTrainingPercentIsBoundary() {
108+
Regression regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 1.0);
109+
assertThat(regression.getTrainingPercent(), equalTo(1.0));
110+
regression = new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 100.0);
111+
assertThat(regression.getTrainingPercent(), equalTo(100.0));
112+
}
113+
114+
// A value just below the lower bound (0.999) must be rejected with the range error message.
public void testRegression_GivenTrainingPercentIsLessThanOne() {
115+
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
116+
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 0.999));
117+
118+
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
119+
}
120+
121+
// A value just above the upper bound (100.0001) must be rejected with the range error message.
public void testRegression_GivenTrainingPercentIsGreaterThan100() {
122+
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class,
123+
() -> new Regression("foo", 0.0, 0.0, 0.5, 500, 1.0, "result", 100.0001));
124+
125+
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
126+
}
100127
}

x-pack/plugin/ml/qa/ml-with-security/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ integTest.runner {
7171
'ml/data_frame_analytics_crud/Test put regression given maximum_number_trees is greater than 2k',
7272
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is negative',
7373
'ml/data_frame_analytics_crud/Test put regression given feature_bag_fraction is greater than one',
74+
'ml/data_frame_analytics_crud/Test put regression given training_percent is less than one',
75+
'ml/data_frame_analytics_crud/Test put regression given training_percent is greater than hundred',
7476
'ml/evaluate_data_frame/Test given missing index',
7577
'ml/evaluate_data_frame/Test given index does not exist',
7678
'ml/evaluate_data_frame/Test given missing evaluation',

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,12 @@ protected SearchResponse searchStoredProgress(String id) {
143143
}
144144

145145
protected static DataFrameAnalyticsConfig buildRegressionAnalytics(String id, String[] sourceIndex, String destIndex,
146-
@Nullable String resultsField, String dependentVariable) {
146+
@Nullable String resultsField, Regression regression) {
147147
DataFrameAnalyticsConfig.Builder configBuilder = new DataFrameAnalyticsConfig.Builder();
148148
configBuilder.setId(id);
149149
configBuilder.setSource(new DataFrameAnalyticsSource(sourceIndex, null));
150150
configBuilder.setDest(new DataFrameAnalyticsDest(destIndex, resultsField));
151-
configBuilder.setAnalysis(new Regression(dependentVariable));
151+
configBuilder.setAnalysis(regression);
152152
return configBuilder.build();
153153
}
154154
}

0 commit comments

Comments
 (0)