Skip to content

Commit 8f30228

Browse files
authored
[ML] adds new feature_processors field for data frame analytics (#60528) (#61148)
`feature_processors` allow users to create custom features from individual document fields. These `feature_processors` are the same objects as the trained model's `pre_processors`. They are passed to the native process, which then appends them to the `pre_processors` array in the inference model. Closes #59327
1 parent d1b6026 commit 8f30228

File tree

43 files changed

+1588
-192
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1588
-192
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@
1515
import org.elasticsearch.common.xcontent.XContentBuilder;
1616
import org.elasticsearch.common.xcontent.XContentParser;
1717
import org.elasticsearch.index.mapper.FieldAliasMapper;
18+
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
19+
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
20+
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
1821
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
1922
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
2023
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType;
2124
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
25+
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
2226

2327
import java.io.IOException;
2428
import java.util.Arrays;
@@ -46,6 +50,7 @@ public class Classification implements DataFrameAnalysis {
4650
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
4751
public static final ParseField TRAINING_PERCENT = new ParseField("training_percent");
4852
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
53+
public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
4954

5055
private static final String STATE_DOC_ID_SUFFIX = "_classification_state#1";
5156

@@ -59,6 +64,7 @@ public class Classification implements DataFrameAnalysis {
5964
*/
6065
public static final int MAX_DEPENDENT_VARIABLE_CARDINALITY = 30;
6166

67+
@SuppressWarnings("unchecked")
6268
private static ConstructingObjectParser<Classification, Void> createParser(boolean lenient) {
6369
ConstructingObjectParser<Classification, Void> parser = new ConstructingObjectParser<>(
6470
NAME.getPreferredName(),
@@ -70,14 +76,21 @@ private static ConstructingObjectParser<Classification, Void> createParser(boole
7076
(ClassAssignmentObjective) a[8],
7177
(Integer) a[9],
7278
(Double) a[10],
73-
(Long) a[11]));
79+
(Long) a[11],
80+
(List<PreProcessor>) a[12]));
7481
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
7582
BoostedTreeParams.declareFields(parser);
7683
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
7784
parser.declareString(optionalConstructorArg(), ClassAssignmentObjective::fromString, CLASS_ASSIGNMENT_OBJECTIVE);
7885
parser.declareInt(optionalConstructorArg(), NUM_TOP_CLASSES);
7986
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
8087
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
88+
parser.declareNamedObjects(optionalConstructorArg(),
89+
(p, c, n) -> lenient ?
90+
p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) :
91+
p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
92+
(classification) -> {/*TODO should we throw if this is not set?*/},
93+
FEATURE_PROCESSORS);
8194
return parser;
8295
}
8396

@@ -119,14 +132,16 @@ public static Classification fromXContent(XContentParser parser, boolean ignoreU
119132
private final int numTopClasses;
120133
private final double trainingPercent;
121134
private final long randomizeSeed;
135+
private final List<PreProcessor> featureProcessors;
122136

123137
public Classification(String dependentVariable,
124138
BoostedTreeParams boostedTreeParams,
125139
@Nullable String predictionFieldName,
126140
@Nullable ClassAssignmentObjective classAssignmentObjective,
127141
@Nullable Integer numTopClasses,
128142
@Nullable Double trainingPercent,
129-
@Nullable Long randomizeSeed) {
143+
@Nullable Long randomizeSeed,
144+
@Nullable List<PreProcessor> featureProcessors) {
130145
if (numTopClasses != null && (numTopClasses < 0 || numTopClasses > 1000)) {
131146
throw ExceptionsHelper.badRequestException("[{}] must be an integer in [0, 1000]", NUM_TOP_CLASSES.getPreferredName());
132147
}
@@ -141,10 +156,11 @@ public Classification(String dependentVariable,
141156
this.numTopClasses = numTopClasses == null ? DEFAULT_NUM_TOP_CLASSES : numTopClasses;
142157
this.trainingPercent = trainingPercent == null ? 100.0 : trainingPercent;
143158
this.randomizeSeed = randomizeSeed == null ? Randomness.get().nextLong() : randomizeSeed;
159+
this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
144160
}
145161

146162
public Classification(String dependentVariable) {
147-
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
163+
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
148164
}
149165

150166
public Classification(StreamInput in) throws IOException {
@@ -163,6 +179,11 @@ public Classification(StreamInput in) throws IOException {
163179
} else {
164180
randomizeSeed = Randomness.get().nextLong();
165181
}
182+
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
183+
featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
184+
} else {
185+
featureProcessors = Collections.emptyList();
186+
}
166187
}
167188

168189
public String getDependentVariable() {
@@ -193,6 +214,10 @@ public long getRandomizeSeed() {
193214
return randomizeSeed;
194215
}
195216

217+
public List<PreProcessor> getFeatureProcessors() {
218+
return featureProcessors;
219+
}
220+
196221
@Override
197222
public String getWriteableName() {
198223
return NAME.getPreferredName();
@@ -211,6 +236,9 @@ public void writeTo(StreamOutput out) throws IOException {
211236
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
212237
out.writeOptionalLong(randomizeSeed);
213238
}
239+
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
240+
out.writeNamedWriteableList(featureProcessors);
241+
}
214242
}
215243

216244
@Override
@@ -229,6 +257,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
229257
if (version.onOrAfter(Version.V_7_6_0)) {
230258
builder.field(RANDOMIZE_SEED.getPreferredName(), randomizeSeed);
231259
}
260+
if (featureProcessors.isEmpty() == false) {
261+
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
262+
}
232263
builder.endObject();
233264
return builder;
234265
}
@@ -249,6 +280,10 @@ public Map<String, Object> getParams(FieldInfo fieldInfo) {
249280
}
250281
params.put(NUM_CLASSES, fieldInfo.getCardinality(dependentVariable));
251282
params.put(TRAINING_PERCENT.getPreferredName(), trainingPercent);
283+
if (featureProcessors.isEmpty() == false) {
284+
params.put(FEATURE_PROCESSORS.getPreferredName(),
285+
featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
286+
}
252287
return params;
253288
}
254289

@@ -390,14 +425,15 @@ public boolean equals(Object o) {
390425
&& Objects.equals(predictionFieldName, that.predictionFieldName)
391426
&& Objects.equals(classAssignmentObjective, that.classAssignmentObjective)
392427
&& Objects.equals(numTopClasses, that.numTopClasses)
428+
&& Objects.equals(featureProcessors, that.featureProcessors)
393429
&& trainingPercent == that.trainingPercent
394430
&& randomizeSeed == that.randomizeSeed;
395431
}
396432

397433
@Override
398434
public int hashCode() {
399435
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, classAssignmentObjective,
400-
numTopClasses, trainingPercent, randomizeSeed);
436+
numTopClasses, trainingPercent, randomizeSeed, featureProcessors);
401437
}
402438

403439
public enum ClassAssignmentObjective {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@
1515
import org.elasticsearch.common.xcontent.XContentBuilder;
1616
import org.elasticsearch.common.xcontent.XContentParser;
1717
import org.elasticsearch.index.mapper.NumberFieldMapper;
18+
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
19+
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
20+
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
1821
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
1922
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
2023
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
24+
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
2125

2226
import java.io.IOException;
2327
import java.util.Arrays;
@@ -28,6 +32,7 @@
2832
import java.util.Map;
2933
import java.util.Objects;
3034
import java.util.Set;
35+
import java.util.stream.Collectors;
3136

3237
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
3338
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
@@ -42,12 +47,14 @@ public class Regression implements DataFrameAnalysis {
4247
public static final ParseField RANDOMIZE_SEED = new ParseField("randomize_seed");
4348
public static final ParseField LOSS_FUNCTION = new ParseField("loss_function");
4449
public static final ParseField LOSS_FUNCTION_PARAMETER = new ParseField("loss_function_parameter");
50+
public static final ParseField FEATURE_PROCESSORS = new ParseField("feature_processors");
4551

4652
private static final String STATE_DOC_ID_SUFFIX = "_regression_state#1";
4753

4854
private static final ConstructingObjectParser<Regression, Void> LENIENT_PARSER = createParser(true);
4955
private static final ConstructingObjectParser<Regression, Void> STRICT_PARSER = createParser(false);
5056

57+
@SuppressWarnings("unchecked")
5158
private static ConstructingObjectParser<Regression, Void> createParser(boolean lenient) {
5259
ConstructingObjectParser<Regression, Void> parser = new ConstructingObjectParser<>(
5360
NAME.getPreferredName(),
@@ -59,14 +66,21 @@ private static ConstructingObjectParser<Regression, Void> createParser(boolean l
5966
(Double) a[8],
6067
(Long) a[9],
6168
(LossFunction) a[10],
62-
(Double) a[11]));
69+
(Double) a[11],
70+
(List<PreProcessor>) a[12]));
6371
parser.declareString(constructorArg(), DEPENDENT_VARIABLE);
6472
BoostedTreeParams.declareFields(parser);
6573
parser.declareString(optionalConstructorArg(), PREDICTION_FIELD_NAME);
6674
parser.declareDouble(optionalConstructorArg(), TRAINING_PERCENT);
6775
parser.declareLong(optionalConstructorArg(), RANDOMIZE_SEED);
6876
parser.declareString(optionalConstructorArg(), LossFunction::fromString, LOSS_FUNCTION);
6977
parser.declareDouble(optionalConstructorArg(), LOSS_FUNCTION_PARAMETER);
78+
parser.declareNamedObjects(optionalConstructorArg(),
79+
(p, c, n) -> lenient ?
80+
p.namedObject(LenientlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)) :
81+
p.namedObject(StrictlyParsedPreProcessor.class, n, new PreProcessor.PreProcessorParseContext(true)),
82+
(regression) -> {/*TODO should we throw if this is not set?*/},
83+
FEATURE_PROCESSORS);
7084
return parser;
7185
}
7286

@@ -90,14 +104,16 @@ public static Regression fromXContent(XContentParser parser, boolean ignoreUnkno
90104
private final long randomizeSeed;
91105
private final LossFunction lossFunction;
92106
private final Double lossFunctionParameter;
107+
private final List<PreProcessor> featureProcessors;
93108

94109
public Regression(String dependentVariable,
95110
BoostedTreeParams boostedTreeParams,
96111
@Nullable String predictionFieldName,
97112
@Nullable Double trainingPercent,
98113
@Nullable Long randomizeSeed,
99114
@Nullable LossFunction lossFunction,
100-
@Nullable Double lossFunctionParameter) {
115+
@Nullable Double lossFunctionParameter,
116+
@Nullable List<PreProcessor> featureProcessors) {
101117
if (trainingPercent != null && (trainingPercent < 1.0 || trainingPercent > 100.0)) {
102118
throw ExceptionsHelper.badRequestException("[{}] must be a double in [1, 100]", TRAINING_PERCENT.getPreferredName());
103119
}
@@ -112,10 +128,11 @@ public Regression(String dependentVariable,
112128
throw ExceptionsHelper.badRequestException("[{}] must be a positive double", LOSS_FUNCTION_PARAMETER.getPreferredName());
113129
}
114130
this.lossFunctionParameter = lossFunctionParameter;
131+
this.featureProcessors = featureProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(featureProcessors);
115132
}
116133

117134
public Regression(String dependentVariable) {
118-
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null);
135+
this(dependentVariable, BoostedTreeParams.builder().build(), null, null, null, null, null, null);
119136
}
120137

121138
public Regression(StreamInput in) throws IOException {
@@ -136,6 +153,11 @@ public Regression(StreamInput in) throws IOException {
136153
lossFunction = LossFunction.MSE;
137154
lossFunctionParameter = null;
138155
}
156+
if (in.getVersion().onOrAfter(Version.V_7_10_0)) {
157+
featureProcessors = Collections.unmodifiableList(in.readNamedWriteableList(PreProcessor.class));
158+
} else {
159+
featureProcessors = Collections.emptyList();
160+
}
139161
}
140162

141163
public String getDependentVariable() {
@@ -166,6 +188,10 @@ public Double getLossFunctionParameter() {
166188
return lossFunctionParameter;
167189
}
168190

191+
public List<PreProcessor> getFeatureProcessors() {
192+
return featureProcessors;
193+
}
194+
169195
@Override
170196
public String getWriteableName() {
171197
return NAME.getPreferredName();
@@ -184,6 +210,9 @@ public void writeTo(StreamOutput out) throws IOException {
184210
out.writeEnum(lossFunction);
185211
out.writeOptionalDouble(lossFunctionParameter);
186212
}
213+
if (out.getVersion().onOrAfter(Version.V_7_10_0)) {
214+
out.writeNamedWriteableList(featureProcessors);
215+
}
187216
}
188217

189218
@Override
@@ -204,6 +233,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
204233
if (lossFunctionParameter != null) {
205234
builder.field(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
206235
}
236+
if (featureProcessors.isEmpty() == false) {
237+
NamedXContentObjectHelper.writeNamedObjects(builder, params, true, FEATURE_PROCESSORS.getPreferredName(), featureProcessors);
238+
}
207239
builder.endObject();
208240
return builder;
209241
}
@@ -221,6 +253,10 @@ public Map<String, Object> getParams(FieldInfo fieldInfo) {
221253
if (lossFunctionParameter != null) {
222254
params.put(LOSS_FUNCTION_PARAMETER.getPreferredName(), lossFunctionParameter);
223255
}
256+
if (featureProcessors.isEmpty() == false) {
257+
params.put(FEATURE_PROCESSORS.getPreferredName(),
258+
featureProcessors.stream().map(p -> Collections.singletonMap(p.getName(), p)).collect(Collectors.toList()));
259+
}
224260
return params;
225261
}
226262

@@ -304,13 +340,14 @@ public boolean equals(Object o) {
304340
&& trainingPercent == that.trainingPercent
305341
&& randomizeSeed == that.randomizeSeed
306342
&& lossFunction == that.lossFunction
343+
&& Objects.equals(featureProcessors, that.featureProcessors)
307344
&& Objects.equals(lossFunctionParameter, that.lossFunctionParameter);
308345
}
309346

310347
@Override
311348
public int hashCode() {
312349
return Objects.hash(dependentVariable, boostedTreeParams, predictionFieldName, trainingPercent, randomizeSeed, lossFunction,
313-
lossFunctionParameter);
350+
lossFunctionParameter, featureProcessors);
314351
}
315352

316353
public enum LossFunction {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,23 +57,23 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
5757

5858
// PreProcessing Lenient
5959
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, OneHotEncoding.NAME,
60-
OneHotEncoding::fromXContentLenient));
60+
(p, c) -> OneHotEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
6161
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
62-
TargetMeanEncoding::fromXContentLenient));
62+
(p, c) -> TargetMeanEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
6363
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, FrequencyEncoding.NAME,
64-
FrequencyEncoding::fromXContentLenient));
64+
(p, c) -> FrequencyEncoding.fromXContentLenient(p, (PreProcessor.PreProcessorParseContext) c)));
6565
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
66-
CustomWordEmbedding::fromXContentLenient));
66+
(p, c) -> CustomWordEmbedding.fromXContentLenient(p)));
6767

6868
// PreProcessing Strict
6969
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, OneHotEncoding.NAME,
70-
OneHotEncoding::fromXContentStrict));
70+
(p, c) -> OneHotEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
7171
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, TargetMeanEncoding.NAME,
72-
TargetMeanEncoding::fromXContentStrict));
72+
(p, c) -> TargetMeanEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
7373
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, FrequencyEncoding.NAME,
74-
FrequencyEncoding::fromXContentStrict));
74+
(p, c) -> FrequencyEncoding.fromXContentStrict(p, (PreProcessor.PreProcessorParseContext) c)));
7575
namedXContent.add(new NamedXContentRegistry.Entry(StrictlyParsedPreProcessor.class, CustomWordEmbedding.NAME,
76-
CustomWordEmbedding::fromXContentStrict));
76+
(p, c) -> CustomWordEmbedding.fromXContentStrict(p)));
7777

7878
// Model Lenient
7979
namedXContent.add(new NamedXContentRegistry.Entry(LenientlyParsedTrainedModel.class, Tree.NAME, Tree::fromXContentLenient));

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ private static ObjectParser<TrainedModelDefinition.Builder, Void> createParser(b
5656
TRAINED_MODEL);
5757
parser.declareNamedObjects(TrainedModelDefinition.Builder::setPreProcessors,
5858
(p, c, n) -> ignoreUnknownFields ?
59-
p.namedObject(LenientlyParsedPreProcessor.class, n, null) :
60-
p.namedObject(StrictlyParsedPreProcessor.class, n, null),
59+
p.namedObject(LenientlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT) :
60+
p.namedObject(StrictlyParsedPreProcessor.class, n, PreProcessor.PreProcessorParseContext.DEFAULT),
6161
(trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true),
6262
PREPROCESSORS);
6363
return parser;

0 commit comments

Comments
 (0)