2222import org .elasticsearch .common .ParseField ;
2323import org .elasticsearch .common .Strings ;
2424import org .elasticsearch .common .xcontent .ConstructingObjectParser ;
25+ import org .elasticsearch .common .xcontent .ObjectParser ;
2526import org .elasticsearch .common .xcontent .XContentBuilder ;
2627import org .elasticsearch .common .xcontent .XContentParser ;
2728
2829import java .io .IOException ;
30+ import java .util .Locale ;
2931import java .util .Objects ;
3032
3133public class Classification implements DataFrameAnalysis {
@@ -49,6 +51,7 @@ public static Builder builder(String dependentVariable) {
4951 static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField ("num_top_feature_importance_values" );
5052 static final ParseField PREDICTION_FIELD_NAME = new ParseField ("prediction_field_name" );
5153 static final ParseField TRAINING_PERCENT = new ParseField ("training_percent" );
54+ static final ParseField CLASS_ASSIGNMENT_OBJECTIVE = new ParseField ("class_assignment_objective" );
5255 static final ParseField NUM_TOP_CLASSES = new ParseField ("num_top_classes" );
5356 static final ParseField RANDOMIZE_SEED = new ParseField ("randomize_seed" );
5457
@@ -67,7 +70,8 @@ public static Builder builder(String dependentVariable) {
6770 (String ) a [7 ],
6871 (Double ) a [8 ],
6972 (Integer ) a [9 ],
70- (Long ) a [10 ]));
73+ (Long ) a [10 ],
74+ (ClassAssignmentObjective ) a [11 ]));
7175
7276 static {
7377 PARSER .declareString (ConstructingObjectParser .constructorArg (), DEPENDENT_VARIABLE );
@@ -81,6 +85,12 @@ public static Builder builder(String dependentVariable) {
8185 PARSER .declareDouble (ConstructingObjectParser .optionalConstructorArg (), TRAINING_PERCENT );
8286 PARSER .declareInt (ConstructingObjectParser .optionalConstructorArg (), NUM_TOP_CLASSES );
8387 PARSER .declareLong (ConstructingObjectParser .optionalConstructorArg (), RANDOMIZE_SEED );
88+ PARSER .declareField (ConstructingObjectParser .optionalConstructorArg (), p -> {
89+ if (p .currentToken () == XContentParser .Token .VALUE_STRING ) {
90+ return ClassAssignmentObjective .fromString (p .text ());
91+ }
92+ throw new IllegalArgumentException ("Unsupported token [" + p .currentToken () + "]" );
93+ }, CLASS_ASSIGNMENT_OBJECTIVE , ObjectParser .ValueType .STRING );
8494 }
8595
8696 private final String dependentVariable ;
@@ -92,13 +102,15 @@ public static Builder builder(String dependentVariable) {
92102 private final Integer numTopFeatureImportanceValues ;
93103 private final String predictionFieldName ;
94104 private final Double trainingPercent ;
105+ private final ClassAssignmentObjective classAssignmentObjective ;
95106 private final Integer numTopClasses ;
96107 private final Long randomizeSeed ;
97108
98109 private Classification (String dependentVariable , @ Nullable Double lambda , @ Nullable Double gamma , @ Nullable Double eta ,
99110 @ Nullable Integer maxTrees , @ Nullable Double featureBagFraction ,
100111 @ Nullable Integer numTopFeatureImportanceValues , @ Nullable String predictionFieldName ,
101- @ Nullable Double trainingPercent , @ Nullable Integer numTopClasses , @ Nullable Long randomizeSeed ) {
112+ @ Nullable Double trainingPercent , @ Nullable Integer numTopClasses , @ Nullable Long randomizeSeed ,
113+ @ Nullable ClassAssignmentObjective classAssignmentObjective ) {
102114 this .dependentVariable = Objects .requireNonNull (dependentVariable );
103115 this .lambda = lambda ;
104116 this .gamma = gamma ;
@@ -108,6 +120,7 @@ private Classification(String dependentVariable, @Nullable Double lambda, @Nulla
108120 this .numTopFeatureImportanceValues = numTopFeatureImportanceValues ;
109121 this .predictionFieldName = predictionFieldName ;
110122 this .trainingPercent = trainingPercent ;
123+ this .classAssignmentObjective = classAssignmentObjective ;
111124 this .numTopClasses = numTopClasses ;
112125 this .randomizeSeed = randomizeSeed ;
113126 }
@@ -157,6 +170,10 @@ public Long getRandomizeSeed() {
157170 return randomizeSeed ;
158171 }
159172
173+ public ClassAssignmentObjective getClassAssignmentObjective () {
174+ return classAssignmentObjective ;
175+ }
176+
160177 public Integer getNumTopClasses () {
161178 return numTopClasses ;
162179 }
@@ -192,6 +209,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
192209 if (randomizeSeed != null ) {
193210 builder .field (RANDOMIZE_SEED .getPreferredName (), randomizeSeed );
194211 }
212+ if (classAssignmentObjective != null ) {
213+ builder .field (CLASS_ASSIGNMENT_OBJECTIVE .getPreferredName (), classAssignmentObjective );
214+ }
195215 if (numTopClasses != null ) {
196216 builder .field (NUM_TOP_CLASSES .getPreferredName (), numTopClasses );
197217 }
@@ -202,7 +222,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
202222 @ Override
203223 public int hashCode () {
204224 return Objects .hash (dependentVariable , lambda , gamma , eta , maxTrees , featureBagFraction , numTopFeatureImportanceValues ,
205- predictionFieldName , trainingPercent , randomizeSeed , numTopClasses );
225+ predictionFieldName , trainingPercent , randomizeSeed , numTopClasses , classAssignmentObjective );
206226 }
207227
208228 @ Override
@@ -220,14 +240,28 @@ public boolean equals(Object o) {
220240 && Objects .equals (predictionFieldName , that .predictionFieldName )
221241 && Objects .equals (trainingPercent , that .trainingPercent )
222242 && Objects .equals (randomizeSeed , that .randomizeSeed )
223- && Objects .equals (numTopClasses , that .numTopClasses );
243+ && Objects .equals (numTopClasses , that .numTopClasses )
244+ && Objects .equals (classAssignmentObjective , that .classAssignmentObjective );
224245 }
225246
226247 @ Override
227248 public String toString () {
228249 return Strings .toString (this );
229250 }
230251
252+ public enum ClassAssignmentObjective {
253+ MAXIMIZE_ACCURACY , MAXIMIZE_MINIMUM_RECALL ;
254+
255+ public static ClassAssignmentObjective fromString (String value ) {
256+ return ClassAssignmentObjective .valueOf (value .toUpperCase (Locale .ROOT ));
257+ }
258+
259+ @ Override
260+ public String toString () {
261+ return name ().toLowerCase (Locale .ROOT );
262+ }
263+ }
264+
231265 public static class Builder {
232266 private String dependentVariable ;
233267 private Double lambda ;
@@ -240,6 +274,7 @@ public static class Builder {
240274 private Double trainingPercent ;
241275 private Integer numTopClasses ;
242276 private Long randomizeSeed ;
277+ private ClassAssignmentObjective classAssignmentObjective ;
243278
244279 private Builder (String dependentVariable ) {
245280 this .dependentVariable = Objects .requireNonNull (dependentVariable );
@@ -295,9 +330,15 @@ public Builder setNumTopClasses(Integer numTopClasses) {
295330 return this ;
296331 }
297332
333+ public Builder setClassAssignmentObjective (ClassAssignmentObjective classAssignmentObjective ) {
334+ this .classAssignmentObjective = classAssignmentObjective ;
335+ return this ;
336+ }
337+
298338 public Classification build () {
299339 return new Classification (dependentVariable , lambda , gamma , eta , maxTrees , featureBagFraction ,
300- numTopFeatureImportanceValues , predictionFieldName , trainingPercent , numTopClasses , randomizeSeed );
340+ numTopFeatureImportanceValues , predictionFieldName , trainingPercent , numTopClasses , randomizeSeed ,
341+ classAssignmentObjective );
301342 }
302343 }
303344}
0 commit comments