1515import org .elasticsearch .common .xcontent .XContentBuilder ;
1616import org .elasticsearch .common .xcontent .XContentParser ;
1717import org .elasticsearch .index .mapper .FieldAliasMapper ;
18+ import org .elasticsearch .xpack .core .ml .inference .preprocessing .LenientlyParsedPreProcessor ;
19+ import org .elasticsearch .xpack .core .ml .inference .preprocessing .PreProcessor ;
20+ import org .elasticsearch .xpack .core .ml .inference .preprocessing .StrictlyParsedPreProcessor ;
1821import org .elasticsearch .xpack .core .ml .inference .trainedmodel .ClassificationConfig ;
1922import org .elasticsearch .xpack .core .ml .inference .trainedmodel .InferenceConfig ;
2023import org .elasticsearch .xpack .core .ml .inference .trainedmodel .PredictionFieldType ;
2124import org .elasticsearch .xpack .core .ml .utils .ExceptionsHelper ;
25+ import org .elasticsearch .xpack .core .ml .utils .NamedXContentObjectHelper ;
2226
2327import java .io .IOException ;
2428import java .util .Arrays ;
@@ -46,6 +50,7 @@ public class Classification implements DataFrameAnalysis {
4650 public static final ParseField NUM_TOP_CLASSES = new ParseField ("num_top_classes" );
4751 public static final ParseField TRAINING_PERCENT = new ParseField ("training_percent" );
4852 public static final ParseField RANDOMIZE_SEED = new ParseField ("randomize_seed" );
53+ public static final ParseField FEATURE_PROCESSORS = new ParseField ("feature_processors" );
4954
5055 private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1" ;
5156
@@ -59,6 +64,7 @@ public class Classification implements DataFrameAnalysis {
5964 */
6065 public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30 ;
6166
67+ @ SuppressWarnings ("unchecked" )
6268 private static ConstructingObjectParser <Classification , Void > createParser (boolean lenient ) {
6369 ConstructingObjectParser <Classification , Void > parser = new ConstructingObjectParser <>(
6470 NAME .getPreferredName (),
@@ -70,14 +76,21 @@ private static ConstructingObjectParser<Classification, Void> createParser(boole
7076 (ClassAssignmentObjective ) a [8 ],
7177 (Integer ) a [9 ],
7278 (Double ) a [10 ],
73- (Long ) a [11 ]));
79+ (Long ) a [11 ],
80+ (List <PreProcessor >) a [12 ]));
7481 parser .declareString (constructorArg (), DEPENDENT_VARIABLE );
7582 BoostedTreeParams .declareFields (parser );
7683 parser .declareString (optionalConstructorArg (), PREDICTION_FIELD_NAME );
7784 parser .declareString (optionalConstructorArg (), ClassAssignmentObjective ::fromString , CLASS_ASSIGNMENT_OBJECTIVE );
7885 parser .declareInt (optionalConstructorArg (), NUM_TOP_CLASSES );
7986 parser .declareDouble (optionalConstructorArg (), TRAINING_PERCENT );
8087 parser .declareLong (optionalConstructorArg (), RANDOMIZE_SEED );
88+ parser .declareNamedObjects (optionalConstructorArg (),
89+ (p , c , n ) -> lenient ?
90+ p .namedObject (LenientlyParsedPreProcessor .class , n , new PreProcessor .PreProcessorParseContext (true )) :
91+ p .namedObject (StrictlyParsedPreProcessor .class , n , new PreProcessor .PreProcessorParseContext (true )),
92+ (classification ) -> {/*TODO should we throw if this is not set?*/ },
93+ FEATURE_PROCESSORS );
8194 return parser ;
8295 }
8396
@@ -119,14 +132,16 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU
119132 private final int numTopClasses ;
120133 private final double trainingPercent ;
121134 private final long randomizeSeed ;
135+ private final List <PreProcessor > featureProcessors ;
122136
123137 public Classification (String dependentVariable ,
124138 BoostedTreeParams boostedTreeParams ,
125139 @ Nullable String predictionFieldName ,
126140 @ Nullable ClassAssignmentObjective classAssignmentObjective ,
127141 @ Nullable Integer numTopClasses ,
128142 @ Nullable Double trainingPercent ,
129- @ Nullable Long randomizeSeed ) {
143+ @ Nullable Long randomizeSeed ,
144+ @ Nullable List <PreProcessor > featureProcessors ) {
130145 if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000 )) {
131146 throw ExceptionsHelper .badRequestException ("[{}] must be an integer in [0, 1000]" , NUM_TOP_CLASSES .getPreferredName ());
132147 }
@@ -141,10 +156,11 @@ public Classification(String dependentVariable,
141156 this .numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses ;
142157 this .trainingPercent = trainingPercent == null ? 100.0 : trainingPercent ;
143158 this .randomizeSeed = randomizeSeed == null ? Randomness .get ().nextLong () : randomizeSeed ;
159+ this .featureProcessors = featureProcessors == null ? Collections .emptyList () : Collections .unmodifiableList (featureProcessors );
144160 }
145161
146162 public Classification (String dependentVariable ) {
147- this (dependentVariable , BoostedTreeParams .builder ().build (), null , null , null , null , null );
163+ this (dependentVariable , BoostedTreeParams .builder ().build (), null , null , null , null , null , null );
148164 }
149165
150166 public Classification (StreamInput in ) throws IOException {
@@ -163,6 +179,11 @@ public Classification(StreamInput in) throws IOException {
163179 } else {
164180 randomizeSeed = Randomness .get ().nextLong ();
165181 }
182+ if (in .getVersion ().onOrAfter (Version .V_7_10_0 )) {
183+ featureProcessors = Collections .unmodifiableList (in .readNamedWriteableList (PreProcessor .class ));
184+ } else {
185+ featureProcessors = Collections .emptyList ();
186+ }
166187 }
167188
168189 public String getDependentVariable () {
@@ -193,6 +214,10 @@ public long getRandomizeSeed() {
193214 return randomizeSeed ;
194215 }
195216
217+ public List <PreProcessor > getFeatureProcessors () {
218+ return featureProcessors ;
219+ }
220+
196221 @ Override
197222 public String getWriteableName () {
198223 return NAME .getPreferredName ();
@@ -211,6 +236,9 @@ public void writeTo(StreamOutput out) throws IOException {
211236 if (out .getVersion ().onOrAfter (Version .V_7_6_0 )) {
212237 out .writeOptionalLong (randomizeSeed );
213238 }
239+ if (out .getVersion ().onOrAfter (Version .V_7_10_0 )) {
240+ out .writeNamedWriteableList (featureProcessors );
241+ }
214242 }
215243
216244 @ Override
@@ -229,6 +257,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
229257 if (version .onOrAfter (Version .V_7_6_0 )) {
230258 builder .field (RANDOMIZE_SEED .getPreferredName (), randomizeSeed );
231259 }
260+ if (featureProcessors .isEmpty () == false ) {
261+ NamedXContentObjectHelper .writeNamedObjects (builder , params , true , FEATURE_PROCESSORS .getPreferredName (), featureProcessors );
262+ }
232263 builder .endObject ();
233264 return builder ;
234265 }
@@ -249,6 +280,10 @@ public Map<String, Object> getParams(FieldInfo fieldInfo) {
249280 }
250281 params .put (NUM_CLASSES , fieldInfo .getCardinality (dependentVariable ));
251282 params .put (TRAINING_PERCENT .getPreferredName (), trainingPercent );
283+ if (featureProcessors .isEmpty () == false ) {
284+ params .put (FEATURE_PROCESSORS .getPreferredName (),
285+ featureProcessors .stream ().map (p -> Collections .singletonMap (p .getName (), p )).collect (Collectors .toList ()));
286+ }
252287 return params ;
253288 }
254289
@@ -390,14 +425,15 @@ public boolean equals(Object o) {
390425 && Objects .equals (predictionFieldName , that .predictionFieldName )
391426 && Objects .equals (classAssignmentObjective , that .classAssignmentObjective )
392427 && Objects .equals (numTopClasses , that .numTopClasses )
428+ && Objects .equals (featureProcessors , that .featureProcessors )
393429 && trainingPercent == that .trainingPercent
394430 && randomizeSeed == that .randomizeSeed ;
395431 }
396432
397433 @ Override
398434 public int hashCode () {
399435 return Objects .hash (dependentVariable , boostedTreeParams , predictionFieldName , classAssignmentObjective ,
400- numTopClasses , trainingPercent , randomizeSeed );
436+ numTopClasses , trainingPercent , randomizeSeed , featureProcessors );
401437 }
402438
403439 public enum ClassAssignmentObjective {
0 commit comments