Skip to content

Commit 8891f4d

Browse files
[7.x][ML] Introduce randomize_seed setting for regression and classification (#49990) (#50023)
This adds a new `randomize_seed` for regression and classification. When not explicitly set, the seed is randomly generated. One can reuse the seed in a similar job in order to ensure the same docs are picked for training. Backport of #49990
1 parent ee4a8a0 commit 8891f4d

File tree

24 files changed

+465
-77
lines changed

24 files changed

+465
-77
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ public static Builder builder(String dependentVariable) {
4949
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
5050
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
5151
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
52+
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
5253

5354
private static final ConstructingObjectParser<Classification, Void> PARSER =
5455
new ConstructingObjectParser<>(
@@ -63,7 +64,8 @@ public static Builder builder(String dependentVariable) {
6364
(Double) a[5],
6465
(String) a[6],
6566
(Double) a[7],
66-
(Integer) a[8]));
67+
(Integer) a[8],
68+
(Long) a[9]));
6769

6870
static {
6971
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
@@ -75,6 +77,7 @@ public static Builder builder(String dependentVariable) {
7577
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
7678
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
7779
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
80+
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
7881
}
7982

8083
private final String dependentVariable;
@@ -86,10 +89,11 @@ public static Builder builder(String dependentVariable) {
8689
private final String predictionFieldName;
8790
private final Double trainingPercent;
8891
private final Integer numTopClasses;
92+
private final Long randomizeSeed;
8993

9094
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
9195
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
92-
@Nullable Double trainingPercent, @Nullable Integer numTopClasses) {
96+
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) {
9397
this.dependentVariable = Objects.requireNonNull(dependentVariable);
9498
this.lambda = lambda;
9599
this.gamma = gamma;
@@ -99,6 +103,7 @@ private Classification(String dependentVariable, @Nullable Double lambda, @Nulla
99103
this.predictionFieldName = predictionFieldName;
100104
this.trainingPercent = trainingPercent;
101105
this.numTopClasses = numTopClasses;
106+
this.randomizeSeed = randomizeSeed;
102107
}
103108

104109
@Override
@@ -138,6 +143,10 @@ public Double getTrainingPercent() {
138143
return trainingPercent;
139144
}
140145

146+
public Long getRandomizeSeed() {
147+
return randomizeSeed;
148+
}
149+
141150
public Integer getNumTopClasses() {
142151
return numTopClasses;
143152
}
@@ -167,6 +176,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
167176
if (trainingPercent != null) {
168177
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
169178
}
179+
if (randomizeSeed != null) {
180+
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
181+
}
170182
if (numTopClasses != null) {
171183
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
172184
}
@@ -177,7 +189,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
177189
@Override
178190
public int hashCode() {
179191
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
180-
trainingPercent, numTopClasses);
192+
trainingPercent, randomizeSeed, numTopClasses);
181193
}
182194

183195
@Override
@@ -193,6 +205,7 @@ public boolean equals(Object o) {
193205
&& Objects.equals(featureBagFraction, that.featureBagFraction)
194206
&& Objects.equals(predictionFieldName, that.predictionFieldName)
195207
&& Objects.equals(trainingPercent, that.trainingPercent)
208+
&& Objects.equals(randomizeSeed, that.randomizeSeed)
196209
&& Objects.equals(numTopClasses, that.numTopClasses);
197210
}
198211

@@ -211,6 +224,7 @@ public static class Builder {
211224
private String predictionFieldName;
212225
private Double trainingPercent;
213226
private Integer numTopClasses;
227+
private Long randomizeSeed;
214228

215229
private Builder(String dependentVariable) {
216230
this.dependentVariable = Objects.requireNonNull(dependentVariable);
@@ -251,14 +265,19 @@ public Builder setTrainingPercent(Double trainingPercent) {
251265
return this;
252266
}
253267

268+
public Builder setRandomizeSeed(Long randomizeSeed) {
269+
this.randomizeSeed = randomizeSeed;
270+
return this;
271+
}
272+
254273
public Builder setNumTopClasses(Integer numTopClasses) {
255274
this.numTopClasses = numTopClasses;
256275
return this;
257276
}
258277

259278
public Classification build() {
260279
return new Classification(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
261-
trainingPercent, numTopClasses);
280+
trainingPercent, numTopClasses, randomizeSeed);
262281
}
263282
}
264283
}

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Regression.java

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ public static Builder builder(String dependentVariable) {
4848
static final ParseField FEATURE_BAG_FRACTION = new ParseField("feature_bag_fraction");
4949
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
5050
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
51+
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
5152

5253
private static final ConstructingObjectParser<Regression, Void> PARSER =
5354
new ConstructingObjectParser<>(
@@ -61,7 +62,8 @@ public static Builder builder(String dependentVariable) {
6162
(Integer) a[4],
6263
(Double) a[5],
6364
(String) a[6],
64-
(Double) a[7]));
65+
(Double) a[7],
66+
(Long) a[8]));
6567

6668
static {
6769
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
@@ -72,6 +74,7 @@ public static Builder builder(String dependentVariable) {
7274
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), FEATURE_BAG_FRACTION);
7375
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PREDICTION_FIELD_NAME);
7476
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
77+
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
7578
}
7679

7780
private final String dependentVariable;
@@ -82,10 +85,11 @@ public static Builder builder(String dependentVariable) {
8285
private final Double featureBagFraction;
8386
private final String predictionFieldName;
8487
private final Double trainingPercent;
88+
private final Long randomizeSeed;
8589

8690
private Regression(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
8791
@Nullable Integer maximumNumberTrees, @Nullable Double featureBagFraction, @Nullable String predictionFieldName,
88-
@Nullable Double trainingPercent) {
92+
@Nullable Double trainingPercent, @Nullable Long randomizeSeed) {
8993
this.dependentVariable = Objects.requireNonNull(dependentVariable);
9094
this.lambda = lambda;
9195
this.gamma = gamma;
@@ -94,6 +98,7 @@ private Regression(String dependentVariable, @Nullable Double lambda, @Nullable
9498
this.featureBagFraction = featureBagFraction;
9599
this.predictionFieldName = predictionFieldName;
96100
this.trainingPercent = trainingPercent;
101+
this.randomizeSeed = randomizeSeed;
97102
}
98103

99104
@Override
@@ -133,6 +138,10 @@ public Double getTrainingPercent() {
133138
return trainingPercent;
134139
}
135140

141+
public Long getRandomizeSeed() {
142+
return randomizeSeed;
143+
}
144+
136145
@Override
137146
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
138147
builder.startObject();
@@ -158,14 +167,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
158167
if (trainingPercent != null) {
159168
builder.field(TRAINING_PERCENT.getPreferredName(), trainingPercent);
160169
}
170+
if (randomizeSeed != null) {
171+
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
172+
}
161173
builder.endObject();
162174
return builder;
163175
}
164176

165177
@Override
166178
public int hashCode() {
167179
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
168-
trainingPercent);
180+
trainingPercent, randomizeSeed);
169181
}
170182

171183
@Override
@@ -180,7 +192,8 @@ public boolean equals(Object o) {
180192
&& Objects.equals(maximumNumberTrees, that.maximumNumberTrees)
181193
&& Objects.equals(featureBagFraction, that.featureBagFraction)
182194
&& Objects.equals(predictionFieldName, that.predictionFieldName)
183-
&& Objects.equals(trainingPercent, that.trainingPercent);
195+
&& Objects.equals(trainingPercent, that.trainingPercent)
196+
&& Objects.equals(randomizeSeed, that.randomizeSeed);
184197
}
185198

186199
@Override
@@ -197,6 +210,7 @@ public static class Builder {
197210
private Double featureBagFraction;
198211
private String predictionFieldName;
199212
private Double trainingPercent;
213+
private Long randomizeSeed;
200214

201215
private Builder(String dependentVariable) {
202216
this.dependentVariable = Objects.requireNonNull(dependentVariable);
@@ -237,9 +251,14 @@ public Builder setTrainingPercent(Double trainingPercent) {
237251
return this;
238252
}
239253

254+
public Builder setRandomizeSeed(Long randomizeSeed) {
255+
this.randomizeSeed = randomizeSeed;
256+
return this;
257+
}
258+
240259
public Regression build() {
241260
return new Regression(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,
242-
trainingPercent);
261+
trainingPercent, randomizeSeed);
243262
}
244263
}
245264
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,6 +1321,7 @@ public void testPutDataFrameAnalyticsConfig_GivenRegression() throws Exception {
13211321
.setAnalysis(org.elasticsearch.client.ml.dataframe.Regression.builder("my_dependent_variable")
13221322
.setPredictionFieldName("my_dependent_variable_prediction")
13231323
.setTrainingPercent(80.0)
1324+
.setRandomizeSeed(42L)
13241325
.build())
13251326
.setDescription("this is a regression")
13261327
.build();
@@ -1356,6 +1357,7 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti
13561357
.setAnalysis(org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable")
13571358
.setPredictionFieldName("my_dependent_variable_prediction")
13581359
.setTrainingPercent(80.0)
1360+
.setRandomizeSeed(42L)
13591361
.setNumTopClasses(1)
13601362
.build())
13611363
.setDescription("this is a classification")

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2975,7 +2975,8 @@ public void testPutDataFrameAnalytics() throws Exception {
29752975
.setFeatureBagFraction(0.4) // <6>
29762976
.setPredictionFieldName("my_prediction_field_name") // <7>
29772977
.setTrainingPercent(50.0) // <8>
2978-
.setNumTopClasses(1) // <9>
2978+
.setRandomizeSeed(1234L) // <9>
2979+
.setNumTopClasses(1) // <10>
29792980
.build();
29802981
// end::put-data-frame-analytics-classification
29812982

@@ -2988,6 +2989,7 @@ public void testPutDataFrameAnalytics() throws Exception {
29882989
.setFeatureBagFraction(0.4) // <6>
29892990
.setPredictionFieldName("my_prediction_field_name") // <7>
29902991
.setTrainingPercent(50.0) // <8>
2992+
.setRandomizeSeed(1234L) // <9>
29912993
.build();
29922994
// end::put-data-frame-analytics-regression
29932995

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public static Classification randomClassification() {
3434
.setFeatureBagFraction(randomBoolean() ? null : randomDoubleBetween(0.0, 1.0, false))
3535
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
3636
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
37+
.setRandomizeSeed(randomBoolean() ? null : randomLong())
3738
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
3839
.build();
3940
}

docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ include-tagged::{doc-tests-file}[{api}-classification]
119119
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
120120
<7> The name of the prediction field in the results object.
121121
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
122-
<9> The number of top classes to be reported in the results. Defaults to 2.
122+
<9> The seed to be used by the random generator that picks which rows are used in training.
123+
<10> The number of top classes to be reported in the results. Defaults to 2.
123124

124125
===== Regression
125126

@@ -138,6 +139,7 @@ include-tagged::{doc-tests-file}[{api}-regression]
138139
<6> The fraction of features which will be used when selecting a random bag for each candidate split. A double in (0, 1].
139140
<7> The name of the prediction field in the results object.
140141
<8> The percentage of training-eligible rows to be used in training. Defaults to 100%.
142+
<9> The seed to be used by the random generator that picks which rows are used in training.
141143

142144
==== Analyzed fields
143145

docs/reference/ml/df-analytics/apis/dfanalyticsresources.asciidoc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction_field_name]
204204

205205
include::{docdir}/ml/ml-shared.asciidoc[tag=training_percent]
206206

207+
include::{docdir}/ml/ml-shared.asciidoc[tag=randomize_seed]
208+
207209

208210
[float]
209211
[[regression-resources-advanced]]
@@ -252,6 +254,8 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=prediction_field_name]
252254

253255
include::{docdir}/ml/ml-shared.asciidoc[tag=training_percent]
254256

257+
include::{docdir}/ml/ml-shared.asciidoc[tag=randomize_seed]
258+
255259

256260
[float]
257261
[[classification-resources-advanced]]

docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,8 @@ PUT _ml/data_frame/analytics/student_performance_mathematics_0.3
402402
{
403403
"regression": {
404404
"dependent_variable": "G3",
405-
"training_percent": 70 <1>
405+
"training_percent": 70, <1>
406+
"randomize_seed": 19673948271 <2>
406407
}
407408
}
408409
}
@@ -411,6 +412,7 @@ PUT _ml/data_frame/analytics/student_performance_mathematics_0.3
411412

412413
<1> The `training_percent` defines the percentage of the data set that will be used
413414
for training the model.
415+
<2> The `randomize_seed` is the seed used to randomly pick which data is used for training.
414416

415417

416418
[[ml-put-dfanalytics-example-c]]

docs/reference/ml/ml-shared.asciidoc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,4 +67,18 @@ tag::training_percent[]
6767
be used for training. Documents that are ignored by the analysis (for example
6868
those that contain arrays) won’t be included in the calculation for used
6969
percentage. Defaults to `100`.
70-
end::training_percent[]
70+
end::training_percent[]
71+
72+
tag::randomize_seed[]
73+
`randomize_seed`::
74+
(Optional, long) Defines the seed to the random generator that is used to pick
75+
which documents will be used for training. By default it is randomly generated.
76+
Set it to a specific value to ensure the same documents are used for training
77+
assuming other related parameters (e.g. `source`, `analyzed_fields`, etc.) are the same.
78+
end::randomize_seed[]
79+
80+
81+
tag::use-null[]
82+
Defines whether a new series is used as the null series when there is no value
83+
for the by or partition fields. The default value is `false`.
84+
end::use-null[]

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/DataFrameAnalyticsConfig.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
225225
builder.field(DEST.getPreferredName(), dest);
226226

227227
builder.startObject(ANALYSIS.getPreferredName());
228-
builder.field(analysis.getWriteableName(), analysis);
228+
builder.field(analysis.getWriteableName(), analysis,
229+
new MapParams(Collections.singletonMap(VERSION.getPreferredName(), version == null ? null : version.toString())));
229230
builder.endObject();
230231

231232
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {

0 commit comments

Comments
 (0)