1515import org .elasticsearch .common .xcontent .XContentBuilder ;
1616import org .elasticsearch .common .xcontent .XContentParser ;
1717import org .elasticsearch .index .mapper .NumberFieldMapper ;
18+ import org .elasticsearch .xpack .core .ml .inference .preprocessing .LenientlyParsedPreProcessor ;
19+ import org .elasticsearch .xpack .core .ml .inference .preprocessing .PreProcessor ;
20+ import org .elasticsearch .xpack .core .ml .inference .preprocessing .StrictlyParsedPreProcessor ;
1821import org .elasticsearch .xpack .core .ml .inference .trainedmodel .InferenceConfig ;
1922import org .elasticsearch .xpack .core .ml .inference .trainedmodel .RegressionConfig ;
2023import org .elasticsearch .xpack .core .ml .utils .ExceptionsHelper ;
24+ import org .elasticsearch .xpack .core .ml .utils .NamedXContentObjectHelper ;
2125
2226import java .io .IOException ;
2327import java .util .Arrays ;
2832import java .util .Map ;
2933import java .util .Objects ;
3034import java .util .Set ;
35+ import java .util .stream .Collectors ;
3136
3237import static org .elasticsearch .common .xcontent .ConstructingObjectParser .constructorArg ;
3338import static org .elasticsearch .common .xcontent .ConstructingObjectParser .optionalConstructorArg ;
@@ -42,12 +47,14 @@ public class Regression implements DataFrameAnalysis {
4247 public static final ParseField RANDOMIZE_SEED = new ParseField ("randomize_seed" );
4348 public static final ParseField LOSS_FUNCTION = new ParseField ("loss_function" );
4449 public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField ("loss_function_parameter" );
50+ public static final ParseField FEATURE_PROCESSORS = new ParseField ("feature_processors" );
4551
4652 private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1" ;
4753
4854 private static final ConstructingObjectParser <Regression , Void > LENIENT_PARSER = createParser (true );
4955 private static final ConstructingObjectParser <Regression , Void > STRICT_PARSER = createParser (false );
5056
57+ @ SuppressWarnings ("unchecked" )
5158 private static ConstructingObjectParser <Regression , Void > createParser (boolean lenient ) {
5259 ConstructingObjectParser <Regression , Void > parser = new ConstructingObjectParser <>(
5360 NAME .getPreferredName (),
@@ -59,14 +66,21 @@ private static ConstructingObjectParser<Regression, Void> createParser(boolean l
5966 (Double ) a [8 ],
6067 (Long ) a [9 ],
6168 (LossFunction ) a [10 ],
62- (Double ) a [11 ]));
69+ (Double ) a [11 ],
70+ (List <PreProcessor >) a [12 ]));
6371 parser .declareString (constructorArg (), DEPENDENT_VARIABLE );
6472 BoostedTreeParams .declareFields (parser );
6573 parser .declareString (optionalConstructorArg (), PREDICTION_FIELD_NAME );
6674 parser .declareDouble (optionalConstructorArg (), TRAINING_PERCENT );
6775 parser .declareLong (optionalConstructorArg (), RANDOMIZE_SEED );
6876 parser .declareString (optionalConstructorArg (), LossFunction ::fromString , LOSS_FUNCTION );
6977 parser .declareDouble (optionalConstructorArg (), LOSS_FUNCTION_PARAMETER );
78+ parser .declareNamedObjects (optionalConstructorArg (),
79+ (p , c , n ) -> lenient ?
80+ p .namedObject (LenientlyParsedPreProcessor .class , n , new PreProcessor .PreProcessorParseContext (true )) :
81+ p .namedObject (StrictlyParsedPreProcessor .class , n , new PreProcessor .PreProcessorParseContext (true )),
82+ (regression ) -> {/*TODO should we throw if this is not set?*/ },
83+ FEATURE_PROCESSORS );
7084 return parser ;
7185 }
7286
@@ -90,14 +104,16 @@ public static Regression fromXContent(XContentParser parser, boolean ignoreUnkno
90104 private final long randomizeSeed ;
91105 private final LossFunction lossFunction ;
92106 private final Double lossFunctionParameter ;
107+ private final List <PreProcessor > featureProcessors ;
93108
94109 public Regression (String dependentVariable ,
95110 BoostedTreeParams boostedTreeParams ,
96111 @ Nullable String predictionFieldName ,
97112 @ Nullable Double trainingPercent ,
98113 @ Nullable Long randomizeSeed ,
99114 @ Nullable LossFunction lossFunction ,
100- @ Nullable Double lossFunctionParameter ) {
115+ @ Nullable Double lossFunctionParameter ,
116+ @ Nullable List <PreProcessor > featureProcessors ) {
101117 if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0 )) {
102118 throw ExceptionsHelper .badRequestException ("[{}] must be a double in [1, 100]" , TRAINING_PERCENT .getPreferredName ());
103119 }
@@ -112,10 +128,11 @@ public Regression(String dependentVariable,
112128 throw ExceptionsHelper .badRequestException ("[{}] must be a positive double" , LOSS_FUNCTION_PARAMETER .getPreferredName ());
113129 }
114130 this .lossFunctionParameter = lossFunctionParameter ;
131+ this .featureProcessors = featureProcessors == null ? Collections .emptyList () : Collections .unmodifiableList (featureProcessors );
115132 }
116133
/**
 * Creates a regression analysis for the given dependent variable, using
 * the default {@link BoostedTreeParams} and leaving every optional setting
 * unset.
 *
 * @param dependentVariable name of the field the regression predicts
 */
public Regression(String dependentVariable) {
    this(
        dependentVariable,
        BoostedTreeParams.builder().build(),
        null,  // predictionFieldName
        null,  // trainingPercent
        null,  // randomizeSeed
        null,  // lossFunction
        null,  // lossFunctionParameter
        null); // featureProcessors: the delegate constructor maps null to an empty list
}
120137
121138 public Regression (StreamInput in ) throws IOException {
@@ -126,6 +143,11 @@ public Regression(StreamInput in) throws IOException {
126143 randomizeSeed = in .readOptionalLong ();
127144 lossFunction = in .readEnum (LossFunction .class );
128145 lossFunctionParameter = in .readOptionalDouble ();
146+ if (in .getVersion ().onOrAfter (Version .V_8_0_0 )) {
147+ featureProcessors = Collections .unmodifiableList (in .readNamedWriteableList (PreProcessor .class ));
148+ } else {
149+ featureProcessors = Collections .emptyList ();
150+ }
129151 }
130152
131153 public String getDependentVariable () {
@@ -156,6 +178,10 @@ public Double getLossFunctionParameter() {
156178 return lossFunctionParameter ;
157179 }
158180
181+ public List <PreProcessor > getFeatureProcessors () {
182+ return featureProcessors ;
183+ }
184+
159185 @ Override
160186 public String getWriteableName () {
161187 return NAME .getPreferredName ();
@@ -170,6 +196,9 @@ public void writeTo(StreamOutput out) throws IOException {
170196 out .writeOptionalLong (randomizeSeed );
171197 out .writeEnum (lossFunction );
172198 out .writeOptionalDouble (lossFunctionParameter );
199+ if (out .getVersion ().onOrAfter (Version .V_8_0_0 )) {
200+ out .writeNamedWriteableList (featureProcessors );
201+ }
173202 }
174203
175204 @ Override
@@ -190,6 +219,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
190219 if (lossFunctionParameter != null ) {
191220 builder .field (LOSS_FUNCTION_PARAMETER .getPreferredName (), lossFunctionParameter );
192221 }
222+ if (featureProcessors .isEmpty () == false ) {
223+ NamedXContentObjectHelper .writeNamedObjects (builder , params , true , FEATURE_PROCESSORS .getPreferredName (), featureProcessors );
224+ }
193225 builder .endObject ();
194226 return builder ;
195227 }
@@ -207,6 +239,10 @@ public Map<String, Object> getParams(FieldInfo fieldInfo) {
207239 if (lossFunctionParameter != null ) {
208240 params .put (LOSS_FUNCTION_PARAMETER .getPreferredName (), lossFunctionParameter );
209241 }
242+ if (featureProcessors .isEmpty () == false ) {
243+ params .put (FEATURE_PROCESSORS .getPreferredName (),
244+ featureProcessors .stream ().map (p -> Collections .singletonMap (p .getName (), p )).collect (Collectors .toList ()));
245+ }
210246 return params ;
211247 }
212248
@@ -290,13 +326,14 @@ public boolean equals(Object o) {
290326 && trainingPercent == that .trainingPercent
291327 && randomizeSeed == that .randomizeSeed
292328 && lossFunction == that .lossFunction
329+ && Objects .equals (featureProcessors , that .featureProcessors )
293330 && Objects .equals (lossFunctionParameter , that .lossFunctionParameter );
294331 }
295332
296333 @ Override
297334 public int hashCode () {
298335 return Objects .hash (dependentVariable , boostedTreeParams , predictionFieldName , trainingPercent , randomizeSeed , lossFunction ,
299- lossFunctionParameter );
336+ lossFunctionParameter , featureProcessors );
300337 }
301338
302339 public enum LossFunction {
0 commit comments