@@ -48,6 +48,7 @@ public static Builder builder(String dependentVariable) {
4848 static final ParseField FEATURE_BAG_FRACTION = new ParseField ("feature_bag_fraction" );
4949 static final ParseField PREDICTION_FIELD_NAME = new ParseField ("prediction_field_name" );
5050 static final ParseField TRAINING_PERCENT = new ParseField ("training_percent" );
51+ static final ParseField NUM_TOP_CLASSES = new ParseField ("num_top_classes" );
5152
5253 private static final ConstructingObjectParser <Classification , Void > PARSER =
5354 new ConstructingObjectParser <>(
@@ -61,7 +62,8 @@ public static Builder builder(String dependentVariable) {
6162 (Integer ) a [4 ],
6263 (Double ) a [5 ],
6364 (String ) a [6 ],
64- (Double ) a [7 ]));
65+ (Double ) a [7 ],
66+ (Integer ) a [8 ]));
6567
6668 static {
6769 PARSER .declareString (ConstructingObjectParser .constructorArg (), DEPENDENT_VARIABLE );
@@ -72,6 +74,7 @@ public static Builder builder(String dependentVariable) {
7274 PARSER .declareDouble (ConstructingObjectParser .optionalConstructorArg (), FEATURE_BAG_FRACTION );
7375 PARSER .declareString (ConstructingObjectParser .optionalConstructorArg (), PREDICTION_FIELD_NAME );
7476 PARSER .declareDouble (ConstructingObjectParser .optionalConstructorArg (), TRAINING_PERCENT );
77+ PARSER .declareInt (ConstructingObjectParser .optionalConstructorArg (), NUM_TOP_CLASSES );
7578 }
7679
7780 private final String dependentVariable ;
@@ -82,10 +85,11 @@ public static Builder builder(String dependentVariable) {
8285 private final Double featureBagFraction ;
8386 private final String predictionFieldName ;
8487 private final Double trainingPercent ;
88+ private final Integer numTopClasses ;
8589
8690 private Classification (String dependentVariable , @ Nullable Double lambda , @ Nullable Double gamma , @ Nullable Double eta ,
8791 @ Nullable Integer maximumNumberTrees , @ Nullable Double featureBagFraction , @ Nullable String predictionFieldName ,
88- @ Nullable Double trainingPercent ) {
92+ @ Nullable Double trainingPercent , @ Nullable Integer numTopClasses ) {
8993 this .dependentVariable = Objects .requireNonNull (dependentVariable );
9094 this .lambda = lambda ;
9195 this .gamma = gamma ;
@@ -94,6 +98,7 @@ private Classification(String dependentVariable, @Nullable Double lambda, @Nulla
9498 this .featureBagFraction = featureBagFraction ;
9599 this .predictionFieldName = predictionFieldName ;
96100 this .trainingPercent = trainingPercent ;
101+ this .numTopClasses = numTopClasses ;
97102 }
98103
99104 @ Override
@@ -133,6 +138,10 @@ public Double getTrainingPercent() {
133138 return trainingPercent ;
134139 }
135140
141+ public Integer getNumTopClasses () {
142+ return numTopClasses ;
143+ }
144+
136145 @ Override
137146 public XContentBuilder toXContent (XContentBuilder builder , Params params ) throws IOException {
138147 builder .startObject ();
@@ -158,14 +167,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
158167 if (trainingPercent != null ) {
159168 builder .field (TRAINING_PERCENT .getPreferredName (), trainingPercent );
160169 }
170+ if (numTopClasses != null ) {
171+ builder .field (NUM_TOP_CLASSES .getPreferredName (), numTopClasses );
172+ }
161173 builder .endObject ();
162174 return builder ;
163175 }
164176
165177 @ Override
166178 public int hashCode () {
167179 return Objects .hash (dependentVariable , lambda , gamma , eta , maximumNumberTrees , featureBagFraction , predictionFieldName ,
168- trainingPercent );
180+ trainingPercent , numTopClasses );
169181 }
170182
171183 @ Override
@@ -180,7 +192,8 @@ public boolean equals(Object o) {
180192 && Objects .equals (maximumNumberTrees , that .maximumNumberTrees )
181193 && Objects .equals (featureBagFraction , that .featureBagFraction )
182194 && Objects .equals (predictionFieldName , that .predictionFieldName )
183- && Objects .equals (trainingPercent , that .trainingPercent );
195+ && Objects .equals (trainingPercent , that .trainingPercent )
196+ && Objects .equals (numTopClasses , that .numTopClasses );
184197 }
185198
186199 @ Override
@@ -197,6 +210,7 @@ public static class Builder {
197210 private Double featureBagFraction ;
198211 private String predictionFieldName ;
199212 private Double trainingPercent ;
213+ private Integer numTopClasses ;
200214
201215 private Builder (String dependentVariable ) {
202216 this .dependentVariable = Objects .requireNonNull (dependentVariable );
@@ -237,9 +251,14 @@ public Builder setTrainingPercent(Double trainingPercent) {
237251 return this ;
238252 }
239253
254+ public Builder setNumTopClasses (Integer numTopClasses ) {
255+ this .numTopClasses = numTopClasses ;
256+ return this ;
257+ }
258+
240259 public Classification build () {
241260 return new Classification (dependentVariable , lambda , gamma , eta , maximumNumberTrees , featureBagFraction , predictionFieldName ,
242- trainingPercent );
261+ trainingPercent , numTopClasses );
243262 }
244263 }
245264}
0 commit comments