Skip to content

Commit 58340c2

Browse files
authored
[ML] Adds the class_assignment_objective parameter to classification (#52763)
Adds a new parameter for classification that enables choosing whether to assign labels to maximise accuracy or to maximise the minimum class recall. Fixes #52427.
1 parent facd525 commit 58340c2

File tree

17 files changed

+250
-32
lines changed

17 files changed

+250
-32
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/Classification.java

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@
2222
import org.elasticsearch.common.ParseField;
2323
import org.elasticsearch.common.Strings;
2424
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
25+
import org.elasticsearch.common.xcontent.ObjectParser;
2526
import org.elasticsearch.common.xcontent.XContentBuilder;
2627
import org.elasticsearch.common.xcontent.XContentParser;
2728

2829
import java.io.IOException;
30+
import java.util.Locale;
2931
import java.util.Objects;
3032

3133
public class Classification implements DataFrameAnalysis {
@@ -49,6 +51,7 @@ public static Builder builder(String dependentVariable) {
4951
static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
5052
static final ParseField PREDICTION_FIELD_NAME = new ParseField("prediction_field_name");
5153
static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
54+
static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField("class_assignment_objective");
5255
static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
5356
static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
5457

@@ -67,7 +70,8 @@ public static Builder builder(String dependentVariable) {
6770
(String) a[7],
6871
(Double) a[8],
6972
(Integer) a[9],
70-
(Long) a[10]));
73+
(Long) a[10],
74+
(ClassAssignmentObjective) a[11]));
7175

7276
static {
7377
PARSER.declareString(ConstructingObjectParser.constructorArg(), DEPENDENT_VARIABLE);
@@ -81,6 +85,12 @@ public static Builder builder(String dependentVariable) {
8185
PARSER.declareDouble(ConstructingObjectParser.optionalConstructorArg(), TRAINING_PERCENT);
8286
PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), NUM_TOP_CLASSES);
8387
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), RANDOMIZE_SEED);
88+
PARSER.declareField(ConstructingObjectParser.optionalConstructorArg(), p -> {
89+
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
90+
return ClassAssignmentObjective.fromString(p.text());
91+
}
92+
throw new IllegalArgumentException("Unsupported token [" + p.currentToken() + "]");
93+
}, CLASS_ASSIGNMENT_OBJECTIVE, ObjectParser.ValueType.STRING);
8494
}
8595

8696
private final String dependentVariable;
@@ -92,13 +102,15 @@ public static Builder builder(String dependentVariable) {
92102
private final Integer numTopFeatureImportanceValues;
93103
private final String predictionFieldName;
94104
private final Double trainingPercent;
105+
private final ClassAssignmentObjective classAssignmentObjective;
95106
private final Integer numTopClasses;
96107
private final Long randomizeSeed;
97108

98109
private Classification(String dependentVariable, @Nullable Double lambda, @Nullable Double gamma, @Nullable Double eta,
99110
@Nullable Integer maxTrees, @Nullable Double featureBagFraction,
100111
@Nullable Integer numTopFeatureImportanceValues, @Nullable String predictionFieldName,
101-
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed) {
112+
@Nullable Double trainingPercent, @Nullable Integer numTopClasses, @Nullable Long randomizeSeed,
113+
@Nullable ClassAssignmentObjective classAssignmentObjective) {
102114
this.dependentVariable = Objects.requireNonNull(dependentVariable);
103115
this.lambda = lambda;
104116
this.gamma = gamma;
@@ -108,6 +120,7 @@ private Classification(String dependentVariable, @Nullable Double lambda, @Nulla
108120
this.numTopFeatureImportanceValues = numTopFeatureImportanceValues;
109121
this.predictionFieldName = predictionFieldName;
110122
this.trainingPercent = trainingPercent;
123+
this.classAssignmentObjective = classAssignmentObjective;
111124
this.numTopClasses = numTopClasses;
112125
this.randomizeSeed = randomizeSeed;
113126
}
@@ -157,6 +170,10 @@ public Long getRandomizeSeed() {
157170
return randomizeSeed;
158171
}
159172

173+
public ClassAssignmentObjective getClassAssignmentObjective() {
174+
return classAssignmentObjective;
175+
}
176+
160177
public Integer getNumTopClasses() {
161178
return numTopClasses;
162179
}
@@ -192,6 +209,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
192209
if (randomizeSeed != null) {
193210
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
194211
}
212+
if (classAssignmentObjective != null) {
213+
builder.field(CLASS_ASSIGNMENT_OBJECTIVE.getPreferredName(), classAssignmentObjective);
214+
}
195215
if (numTopClasses != null) {
196216
builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses);
197217
}
@@ -202,7 +222,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
202222
@Override
203223
public int hashCode() {
204224
return Objects.hash(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction, numTopFeatureImportanceValues,
205-
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses);
225+
predictionFieldName, trainingPercent, randomizeSeed, numTopClasses, classAssignmentObjective);
206226
}
207227

208228
@Override
@@ -220,14 +240,28 @@ public boolean equals(Object o) {
220240
&& Objects.equals(predictionFieldName, that.predictionFieldName)
221241
&& Objects.equals(trainingPercent, that.trainingPercent)
222242
&& Objects.equals(randomizeSeed, that.randomizeSeed)
223-
&& Objects.equals(numTopClasses, that.numTopClasses);
243+
&& Objects.equals(numTopClasses, that.numTopClasses)
244+
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective);
224245
}
225246

226247
@Override
227248
public String toString() {
228249
return Strings.toString(this);
229250
}
230251

252+
public enum ClassAssignmentObjective {
253+
MAXIMIZE_ACCURACY, MAXIMIZE_MINIMUM_RECALL;
254+
255+
public static ClassAssignmentObjective fromString(String value) {
256+
return ClassAssignmentObjective.valueOf(value.toUpperCase(Locale.ROOT));
257+
}
258+
259+
@Override
260+
public String toString() {
261+
return name().toLowerCase(Locale.ROOT);
262+
}
263+
}
264+
231265
public static class Builder {
232266
private String dependentVariable;
233267
private Double lambda;
@@ -240,6 +274,7 @@ public static class Builder {
240274
private Double trainingPercent;
241275
private Integer numTopClasses;
242276
private Long randomizeSeed;
277+
private ClassAssignmentObjective classAssignmentObjective;
243278

244279
private Builder(String dependentVariable) {
245280
this.dependentVariable = Objects.requireNonNull(dependentVariable);
@@ -295,9 +330,15 @@ public Builder setNumTopClasses(Integer numTopClasses) {
295330
return this;
296331
}
297332

333+
public Builder setClassAssignmentObjective(ClassAssignmentObjective classAssignmentObjective) {
334+
this.classAssignmentObjective = classAssignmentObjective;
335+
return this;
336+
}
337+
298338
public Classification build() {
299339
return new Classification(dependentVariable, lambda, gamma, eta, maxTrees, featureBagFraction,
300-
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed);
340+
numTopFeatureImportanceValues, predictionFieldName, trainingPercent, numTopClasses, randomizeSeed,
341+
classAssignmentObjective);
301342
}
302343
}
303344
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1336,6 +1336,8 @@ public void testPutDataFrameAnalyticsConfig_GivenClassification() throws Excepti
13361336
.setPredictionFieldName("my_dependent_variable_prediction")
13371337
.setTrainingPercent(80.0)
13381338
.setRandomizeSeed(42L)
1339+
.setClassAssignmentObjective(
1340+
org.elasticsearch.client.ml.dataframe.Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY)
13391341
.setNumTopClasses(1)
13401342
.setLambda(1.0)
13411343
.setGamma(1.0)

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@
139139
import org.elasticsearch.client.ml.datafeed.DatafeedStats;
140140
import org.elasticsearch.client.ml.datafeed.DatafeedUpdate;
141141
import org.elasticsearch.client.ml.datafeed.DelayedDataCheckConfig;
142+
import org.elasticsearch.client.ml.dataframe.Classification;
142143
import org.elasticsearch.client.ml.dataframe.DataFrameAnalysis;
143144
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsConfig;
144145
import org.elasticsearch.client.ml.dataframe.DataFrameAnalyticsDest;
@@ -2969,7 +2970,7 @@ public void testPutDataFrameAnalytics() throws Exception {
29692970
// end::put-data-frame-analytics-outlier-detection-customized
29702971

29712972
// tag::put-data-frame-analytics-classification
2972-
DataFrameAnalysis classification = org.elasticsearch.client.ml.dataframe.Classification.builder("my_dependent_variable") // <1>
2973+
DataFrameAnalysis classification = Classification.builder("my_dependent_variable") // <1>
29732974
.setLambda(1.0) // <2>
29742975
.setGamma(5.5) // <3>
29752976
.setEta(5.5) // <4>
@@ -2979,7 +2980,8 @@ public void testPutDataFrameAnalytics() throws Exception {
29792980
.setPredictionFieldName("my_prediction_field_name") // <8>
29802981
.setTrainingPercent(50.0) // <9>
29812982
.setRandomizeSeed(1234L) // <10>
2982-
.setNumTopClasses(1) // <11>
2983+
.setClassAssignmentObjective(Classification.ClassAssignmentObjective.MAXIMIZE_ACCURACY) // <11>
2984+
.setNumTopClasses(1) // <12>
29832985
.build();
29842986
// end::put-data-frame-analytics-classification
29852987

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/ClassificationTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ public static Classification randomClassification() {
3636
.setPredictionFieldName(randomBoolean() ? null : randomAlphaOfLength(10))
3737
.setTrainingPercent(randomBoolean() ? null : randomDoubleBetween(1.0, 100.0, true))
3838
.setRandomizeSeed(randomBoolean() ? null : randomLong())
39+
.setClassAssignmentObjective(randomBoolean() ? null : randomFrom(Classification.ClassAssignmentObjective.values()))
3940
.setNumTopClasses(randomBoolean() ? null : randomIntBetween(0, 10))
4041
.build();
4142
}

docs/java-rest/high-level/ml/put-data-frame-analytics.asciidoc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ include-tagged::{doc-tests-file}[{api}-classification]
121121
<8> The name of the prediction field in the results object.
122122
<9> The percentage of training-eligible rows to be used in training. Defaults to 100%.
123123
<10> The seed to be used by the random generator that picks which rows are used in training.
124-
<11> The number of top classes to be reported in the results. Defaults to 2.
124+
<11> The optimization objective to target when assigning class labels. Defaults to maximize_minimum_recall.
125+
<12> The number of top classes to be reported in the results. Defaults to 2.
125126

126127
===== Regression
127128

docs/reference/ml/df-analytics/apis/put-dfanalytics.asciidoc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=gamma]
136136
(Optional, double)
137137
include::{docdir}/ml/ml-shared.asciidoc[tag=lambda]
138138

139+
`analysis`.`classification`.`class_assignment_objective`::::
140+
(Optional, string)
141+
include::{docdir}/ml/ml-shared.asciidoc[tag=class-assignment-objective]
142+
139143
`analysis`.`classification`.`num_top_classes`::::
140144
(Optional, integer)
141145
include::{docdir}/ml/ml-shared.asciidoc[tag=num-top-classes]

docs/reference/ml/ml-shared.asciidoc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,14 @@ include::{docdir}/ml/ml-shared.asciidoc[tag=mode]
339339
include::{docdir}/ml/ml-shared.asciidoc[tag=time-span]
340340
end::chunking-config[]
341341

342+
tag::class-assignment-objective[]
343+
Defines the objective to optimize when assigning class labels. Available
344+
objectives are `maximize_accuracy` and `maximize_minimum_recall`. When maximizing
345+
accuracy, class labels are chosen to maximize the number of correct predictions.
346+
When maximizing minimum recall, labels are chosen to maximize the minimum recall
347+
for any class. Defaults to `maximize_minimum_recall`.
348+
end::class-assignment-objective[]
349+
342350
tag::custom-rules[]
343351
An array of custom rule objects, which enable you to customize the way detectors
344352
operate. For example, a rule may dictate to the detector conditions under which

0 commit comments

Comments
 (0)