Skip to content

Commit 269425b

Browse files
[ML] Introduce randomize_seed setting for regression and classification (#49990)
This adds a new `randomize_seed` for regression and classification. When not explicitly set, the seed is randomly generated. One can reuse the seed in a similar job in order to ensure the same docs are picked for training.
1 parent a6351d6 commit 269425b

File tree

24 files changed

+460
-76
lines changed

24 files changed

+460
-76
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ public static Builder builder(String dependentVariable) {
4949
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
5050
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
5151
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
52+
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
5253

5354
private static final ConstructingObjectParser<Classification, Void> PARSER =
5455
new ConstructingObjectParser<>(
@@ -63,7 +64,8 @@ public static Builder builder(String dependentVariable) {
6364
(Double) a[5],
6465
(String) a[6],
6566
(Double) a[7],
66-
(Integer) a[8]));
67+
(Integer) a[8],
68+
(Long) a[9]));
6769

6870
static {
6971
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
@@ -75,6 +77,7 @@ public static Builder builder(String dependentVariable) {
7577
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
7678
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
7779
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
80+
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
7881
}
7982

8083
private final String dependentVariable;
@@ -86,10 +89,11 @@ public static Builder builder(String dependentVariable) {
8689
private final String predictionFieldName;
8790
private final Double trainingPercent;
8891
private final Integer numTopClasses;
92+
private final Long randomizeSeed;
8993

9094
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
9195
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
92-
@Nullable Double trainingPercent, @Nullable Integer numTopClasses) {
96+
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) {
9397
this.dependentVariable = Objects.requireNonNull(dependentVariable);
9498
this.lambda = lambda;
9599
this.gamma = gamma;
@@ -99,6 +103,7 @@ private Classification(String dependentVariable, @Nullable Double lambda, @Nulla
99103
this.predictionFieldName = predictionFieldName;
100104
this.trainingPercent = trainingPercent;
101105
this.numTopClasses = numTopClasses;
106+
this.randomizeSeed = randomizeSeed;
102107
}
103108

104109
@Override
@@ -138,6 +143,10 @@ public Double getTrainingPercent() {
138143
return trainingPercent;
139144
}
140145

146+
public Long getRandomizeSeed() {
147+
return randomizeSeed;
148+
}
149+
141150
public Integer getNumTopClasses() {
142151
return numTopClasses;
143152
}
@@ -167,6 +176,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
167176
if (trainingPercent != null) {
168177
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
169178
}
179+
if (randomizeSeed != null) {
180+
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
181+
}
170182
if (numTopClasses != null) {
171183
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
172184
}
@@ -177,7 +189,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
177189
@Override
178190
public int hashCode() {
179191
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
180-
trainingPercent, numTopClasses);
192+
trainingPercent, randomizeSeed, numTopClasses);
181193
}
182194

183195
@Override
@@ -193,6 +205,7 @@ public boolean equals(Object o) {
193205
&& Objects.equals(featureBagFraction, that.featureBagFraction)
194206
&& Objects.equals(predictionFieldName, that.predictionFieldName)
195207
&& Objects.equals(trainingPercent, that.trainingPercent)
208+
&& Objects.equals(randomizeSeed, that.randomizeSeed)
196209
&& Objects.equals(numTopClasses, that.numTopClasses);
197210
}
198211

@@ -211,6 +224,7 @@ public static class Builder {
211224
private String predictionFieldName;
212225
private Double trainingPercent;
213226
private Integer numTopClasses;
227+
private Long randomizeSeed;
214228

215229
private Builder(String dependentVariable) {
216230
this.dependentVariable = Objects.requireNonNull(dependentVariable);
@@ -251,14 +265,19 @@ public Builder setTrainingPercent(Double trainingPercent) {
251265
return this;
252266
}
253267

268+
public Builder setRandomizeSeed(Long randomizeSeed) {
269+
this.randomizeSeed = randomizeSeed;
270+
return this;
271+
}
272+
254273
public Builder setNumTopClasses(Integer numTopClasses) {
255274
this.numTopClasses = numTopClasses;
256275
return this;
257276
}
258277

259278
public Classification build() {
260279
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
261-
trainingPercent, numTopClasses);
280+
trainingPercent, numTopClasses, randomizeSeed);
262281
}
263282
}
264283
}

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ public static Builder builder(String dependentVariable) {
4848
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
4949
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
5050
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
51+
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
5152

5253
private static final ConstructingObjectParser<Regression, Void> PARSER =
5354
new ConstructingObjectParser<>(
@@ -61,7 +62,8 @@ public static Builder builder(String dependentVariable) {
6162
(Integer) a[4],
6263
(Double) a[5],
6364
(String) a[6],
64-
(Double) a[7]));
65+
(Double) a[7],
66+
(Long) a[8]));
6567

6668
static {
6769
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
@@ -72,6 +74,7 @@ public static Builder builder(String dependentVariable) {
7274
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
7375
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
7476
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
77+
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
7578
}
7679

7780
private final String dependentVariable;
@@ -82,10 +85,11 @@ public static Builder builder(String dependentVariable) {
8285
private final Double featureBagFraction;
8386
private final String predictionFieldName;
8487
private final Double trainingPercent;
88+
private final Long randomizeSeed;
8589

8690
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
8791
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
88-
@Nullable Double trainingPercent) {
92+
@Nullable Double trainingPercent, @Nullable Long randomizeSeed) {
8993
this.dependentVariable = Objects.requireNonNull(dependentVariable);
9094
this.lambda = lambda;
9195
this.gamma = gamma;
@@ -94,6 +98,7 @@ private Regression(String dependentVariable, @Nullable Double lambda, @Nullable
9498
this.featureBagFraction = featureBagFraction;
9599
this.predictionFieldName = predictionFieldName;
96100
this.trainingPercent = trainingPercent;
101+
this.randomizeSeed = randomizeSeed;
97102
}
98103

99104
@Override
@@ -133,6 +138,10 @@ public Double getTrainingPercent() {
133138
return trainingPercent;
134139
}
135140

141+
public Long getRandomizeSeed() {
142+
return randomizeSeed;
143+
}
144+
136145
@Override
137146
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
138147
builder.startObject();
@@ -158,14 +167,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
158167
if (trainingPercent != null) {
159168
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
160169
}
170+
if (randomizeSeed != null) {
171+
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
172+
}
161173
builder.endObject();
162174
return builder;
163175
}
164176

165177
@Override
166178
public int hashCode() {
167179
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
168-
trainingPercent);
180+
trainingPercent, randomizeSeed);
169181
}
170182

171183
@Override
@@ -180,7 +192,8 @@ public boolean equals(Object o) {
180192
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
181193
&& Objects.equals(featureBagFraction, that.featureBagFraction)
182194
&& Objects.equals(predictionFieldName, that.predictionFieldName)
183-
&& Objects.equals(trainingPercent, that.trainingPercent);
195+
&& Objects.equals(trainingPercent, that.trainingPercent)
196+
&& Objects.equals(randomizeSeed, that.randomizeSeed);
184197
}
185198

186199
@Override
@@ -197,6 +210,7 @@ public static class Builder {
197210
private Double featureBagFraction;
198211
private String predictionFieldName;
199212
private Double trainingPercent;
213+
private Long randomizeSeed;
200214

201215
private Builder(String dependentVariable) {
202216
this.dependentVariable = Objects.requireNonNull(dependentVariable);
@@ -237,9 +251,14 @@ public Builder setTrainingPercent(Double trainingPercent) {
237251
return this;
238252
}
239253

254+
public Builder setRandomizeSeed(Long randomizeSeed) {
255+
this.randomizeSeed = randomizeSeed;
256+
return this;
257+
}
258+
240259
public Regression build() {
241260
return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
242-
trainingPercent);
261+
trainingPercent, randomizeSeed);
243262
}
244263
}
245264
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,6 +1291,7 @@ public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception {
12911291
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable")
12921292
.setPredictionFieldName("my_dependent_variable_prediction")
12931293
.setTrainingPercent(80.0)
1294+
.setRandomizeSeed(42L)
12941295
.build())
12951296
.setDescription("this is a regression")
12961297
.build();
@@ -1326,6 +1327,7 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti
13261327
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable")
13271328
.setPredictionFieldName("my_dependent_variable_prediction")
13281329
.setTrainingPercent(80.0)
1330+
.setRandomizeSeed(42L)
13291331
.setNumTopClasses(1)
13301332
.build())
13311333
.setDescription("this is a classification")

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2975,7 +2975,8 @@ public void testPutDataFrameAnalytics() throws Exception {
29752975
.setFeatureBagFraction(0.4) // <6>
29762976
.setPredictionFieldName("my_prediction_field_name") // <7>
29772977
.setTrainingPercent(50.0) // <8>
2978-
.setNumTopClasses(1) // <9>
2978+
.setRandomizeSeed(1234L) // <9>
2979+
.setNumTopClasses(1) // <10>
29792980
.build();
29802981
// end::put-data-frame-analytics-classification
29812982

@@ -2988,6 +2989,7 @@ public void testPutDataFrameAnalytics() throws Exception {
29882989
.setFeatureBagFraction(0.4) // <6>
29892990
.setPredictionFieldName("my_prediction_field_name") // <7>
29902991
.setTrainingPercent(50.0) // <8>
2992+
.setRandomizeSeed(1234L) // <9>
29912993
.build();
29922994
// end::put-data-frame-analytics-regression
29932995

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public static Classification randomClassification() {
3434
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
3535
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
3636
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
37+
.setRandomizeSeed(randomBoolean() ? null : randomLong())
3738
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
3839
.build();
3940
}

docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ include-tagged::{doc-tests-file}[{api}-classification]
119119
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
120120
<7> The name of the prediction field in the results object.
121121
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
122-
<9> The number of top classes to be reported in the results. Defaults to 2.
122+
<9> The seed to be used by the random generator that picks which rows are used in training.
123+
<10> The number of top classes to be reported in the results. Defaults to 2.
123124

124125
===== Regression
125126

@@ -138,6 +139,7 @@ include-tagged::{doc-tests-file}[{api}-regression]
138139
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
139140
<7> The name of the prediction field in the results object.
140141
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
142+
<9> The seed to be used by the random generator that picks which rows are used in training.
141143

142144
==== Analyzed fields
143145

docs/reference/ml/df-analytics/apis/dfanalyticsresources.asciidoc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction_field_name]
204204

205205
include::{docdir}/ml/ml-shared.asciidoc[tag=training_percent]
206206

207+
include::{docdir}/ml/ml-shared.asciidoc[tag=randomize_seed]
208+
207209

208210
[float]
209211
[[regression-resources-advanced]]
@@ -252,6 +254,8 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction_field_name]
252254

253255
include::{docdir}/ml/ml-shared.asciidoc[tag=training_percent]
254256

257+
include::{docdir}/ml/ml-shared.asciidoc[tag=randomize_seed]
258+
255259

256260
[float]
257261
[[classification-resources-advanced]]

docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,8 @@ PUT _ml/data_frame/analytics/student_performance_mathematics_0.3
397397
{
398398
"regression": {
399399
"dependent_variable": "G3",
400-
"training_percent": 70 <1>
400+
"training_percent": 70, <1>
401+
"randomize_seed": 19673948271 <2>
401402
}
402403
}
403404
}
@@ -406,6 +407,7 @@ PUT _ml/data_frame/analytics/student_performance_mathematics_0.3
406407

407408
<1> The `training_percent` defines the percentage of the data set that will be used
408409
for training the model.
410+
<2> The `randomize_seed` is the seed used to randomly pick which data is used for training.
409411

410412

411413
[[ml-put-dfanalytics-example-c]]

docs/reference/ml/ml-shared.asciidoc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -681,6 +681,15 @@ those that contain arrays) won’t be included in the calculation for used
681681
percentage. Defaults to `100`.
682682
end::training_percent[]
683683

684+
tag::randomize_seed[]
685+
`randomize_seed`::
686+
(Optional, long) Defines the seed to the random generator that is used to pick
687+
which documents will be used for training. By default it is randomly generated.
688+
Set it to a specific value to ensure the same documents are used for training
689+
assuming other related parameters (e.g. `source`, `analyzed_fields`, etc.) are the same.
690+
end::randomize_seed[]
691+
692+
684693
tag::use-null[]
685694
Defines whether a new series is used as the null series when there is no value
686695
for the by or partition fields. The default value is `false`.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
225225
builder.field(DEST.getPreferredName(), dest);
226226

227227
builder.startObject(ANALYSIS.getPreferredName());
228-
builder.field(analysis.getWriteableName(), analysis);
228+
builder.field(analysis.getWriteableName(), analysis,
229+
new MapParams(Collections.singletonMap(VERSION.getPreferredName(), version == null ? null : version.toString())));
229230
builder.endObject();
230231

231232
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {

0 commit comments

Comments
 (0)