From 77f11ca26918bce823837542be24f1f08f808399 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 8 Oct 2019 15:43:58 -0400 Subject: [PATCH 1/3] [ML][Inference] Adjust inference configuration option API --- .../xpack/core/XPackClientPlugin.java | 6 ++ .../core/ml/action/InferModelAction.java | 27 ++++---- .../MlInferenceNamedXContentProvider.java | 7 ++ .../ml/inference/TrainedModelDefinition.java | 6 +- ...eParams.java => ClassificationConfig.java} | 32 +++++++--- .../trainedmodel/InferenceConfig.java | 16 +++++ .../trainedmodel/RegressionConfig.java | 64 +++++++++++++++++++ .../inference/trainedmodel/TrainedModel.java | 3 +- .../trainedmodel/ensemble/Ensemble.java | 32 ++++++---- .../ensemble/NullInferenceConfig.java | 47 ++++++++++++++ .../ensemble/OutputAggregator.java | 3 +- .../trainedmodel/ensemble/WeightedMode.java | 10 +-- .../trainedmodel/ensemble/WeightedSum.java | 5 +- .../ml/inference/trainedmodel/tree/Tree.java | 25 +++++--- .../action/InferModelActionRequestTests.java | 22 ++++++- .../ClassificationConfigTests.java | 27 ++++++++ ...sTests.java => RegressionConfigTests.java} | 14 ++-- .../trainedmodel/ensemble/EnsembleTests.java | 51 +++++++++++---- .../ensemble/WeightedModeTests.java | 8 +++ .../ensemble/WeightedSumTests.java | 8 +++ .../trainedmodel/tree/TreeTests.java | 17 ++--- .../inference/loadingservice/LocalModel.java | 6 +- .../ml/inference/loadingservice/Model.java | 4 +- .../loadingservice/LocalModelTests.java | 26 ++++---- .../integration/ModelInferenceActionIT.java | 6 +- 25 files changed, 366 insertions(+), 106 deletions(-) rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/{InferenceParams.java => ClassificationConfig.java} (67%) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java rename x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/{InferenceParamsTests.java => RegressionConfigTests.java} (50%) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 2c8553d68fe98..6b7da4b0ff57b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -143,6 +143,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator; @@ -471,6 +474,9 @@ public List getNamedWriteables() { new NamedWriteableRegistry.Entry(InferenceResults.class, RegressionInferenceResults.NAME, RegressionInferenceResults::new), + // ML - Inference Configuration + new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new), + new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new), // monitoring new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java index 248d6180d3256..67e3a75283d67 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -13,7 +13,8 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -37,24 +38,24 @@ public static class Request extends ActionRequest { private final String modelId; private final long modelVersion; private final List> objectsToInfer; - private final InferenceParams params; + private final InferenceConfig config; public Request(String modelId, long modelVersion) { - this(modelId, modelVersion, Collections.emptyList(), InferenceParams.EMPTY_PARAMS); + this(modelId, modelVersion, Collections.emptyList(), new RegressionConfig()); } - public Request(String modelId, long modelVersion, List> objectsToInfer, InferenceParams inferenceParams) { + public Request(String modelId, long modelVersion, List> objectsToInfer, InferenceConfig inferenceConfig) { this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); this.modelVersion = modelVersion; this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer")); - this.params = inferenceParams == null ? InferenceParams.EMPTY_PARAMS : inferenceParams; + this.config = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config"); } - public Request(String modelId, long modelVersion, Map objectToInfer, InferenceParams params) { + public Request(String modelId, long modelVersion, Map objectToInfer, InferenceConfig config) { this(modelId, modelVersion, Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")), - params); + config); } public Request(StreamInput in) throws IOException { @@ -62,7 +63,7 @@ public Request(StreamInput in) throws IOException { this.modelId = in.readString(); this.modelVersion = in.readVLong(); this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap)); - this.params = new InferenceParams(in); + this.config = in.readNamedWriteable(InferenceConfig.class); } public String getModelId() { @@ -77,8 +78,8 @@ public List> getObjectsToInfer() { return objectsToInfer; } - public InferenceParams getParams() { - return params; + public InferenceConfig getConfig() { + return config; } @Override @@ -92,7 +93,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); out.writeVLong(modelVersion); out.writeCollection(objectsToInfer, StreamOutput::writeMap); - params.writeTo(out); + out.writeNamedWriteable(config); } @Override @@ -102,13 +103,13 @@ public boolean equals(Object o) { InferModelAction.Request that = (InferModelAction.Request) o; return Objects.equals(modelId, that.modelId) && Objects.equals(modelVersion, that.modelVersion) - && Objects.equals(params, that.params) + && Objects.equals(config, that.config) && Objects.equals(objectsToInfer, that.objectsToInfer); } @Override public int hashCode() { - return Objects.hash(modelId, modelVersion, objectsToInfer, params); + return Objects.hash(modelId, modelVersion, objectsToInfer, config); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 7b56a4c3b4da3..352271b6f27ab 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -11,7 +11,10 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; @@ -111,6 +114,10 @@ public List getNamedWriteables() { RegressionInferenceResults.NAME, RegressionInferenceResults::new)); + // Inference Configs + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new)); + return namedWriteables; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java index e936d60bf87b7..0798e721ed17f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java @@ -19,7 +19,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; @@ -125,9 +125,9 @@ private void preProcess(Map fields) { preProcessors.forEach(preProcessor -> preProcessor.process(fields)); } - public InferenceResults infer(Map fields, InferenceParams params) { + public InferenceResults infer(Map fields, InferenceConfig config) { preProcess(fields); - return trainedModel.infer(fields, params); + return trainedModel.infer(fields, config); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java similarity index 67% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java index 150bf3d483f26..5aa0403d94753 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java @@ -8,26 +8,26 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import java.io.IOException; import java.util.Objects; -public class InferenceParams implements ToXContentObject, Writeable { +public class ClassificationConfig implements InferenceConfig { - public static ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + public static final String NAME = "classification"; - public static InferenceParams EMPTY_PARAMS = new InferenceParams(0); + public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + + public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0); private final int numTopClasses; - public InferenceParams(Integer numTopClasses) { + public ClassificationConfig(Integer numTopClasses) { this.numTopClasses = numTopClasses == null ? 0 : numTopClasses; } - public InferenceParams(StreamInput in) throws IOException { + public ClassificationConfig(StreamInput in) throws IOException { this.numTopClasses = in.readInt(); } @@ -44,7 +44,7 @@ public void writeTo(StreamOutput out) throws IOException { public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - InferenceParams that = (InferenceParams) o; + ClassificationConfig that = (ClassificationConfig) o; return Objects.equals(numTopClasses, that.numTopClasses); } @@ -62,4 +62,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public String getName() { + return NAME; + } + + @Override + public boolean isTargetTypeSupported(TargetType targetType) { + return TargetType.CLASSIFICATION.equals(targetType); + } + } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java new file mode 100644 index 0000000000000..6129d71d5ff95 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; + + +public interface InferenceConfig extends NamedXContentObject, NamedWriteable { + + boolean isTargetTypeSupported(TargetType targetType); + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java new file mode 100644 index 0000000000000..bb7f772f86ba4 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class RegressionConfig implements InferenceConfig { + + public static final String NAME = "regression"; + + public RegressionConfig() { + } + + public RegressionConfig(StreamInput in) { + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + } + + @Override + public String getName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + return true; + } + + @Override + public int hashCode() { + return Objects.hash(NAME); + } + + @Override + public boolean isTargetTypeSupported(TargetType targetType) { + return TargetType.REGRESSION.equals(targetType); + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java index a6c6f1eff011d..d1215943cbe12 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java @@ -24,10 +24,11 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable { * Infer against the provided fields * * @param fields The fields and their values to infer against + * @param config The configuration options for inference * @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0). * For regression this is continuous. */ - InferenceResults infer(Map fields, InferenceParams params); + InferenceResults infer(Map fields, InferenceConfig config); /** * @return {@link TargetType} for the model. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 09e418cec916d..ff03a621d99fa 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -16,8 +16,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; @@ -114,21 +115,18 @@ public List getFeatureNames() { } @Override - public InferenceResults infer(Map fields, InferenceParams params) { - if (params.getNumTopClasses() != 0 && - (targetType != TargetType.CLASSIFICATION || outputAggregator.providesProbabilities() == false)) { + public InferenceResults infer(Map fields, InferenceConfig config) { + if (config.isTargetTypeSupported(targetType) == false) { throw ExceptionsHelper.badRequestException( - "Cannot return top classes for target_type [{}] and aggregate_output [{}]", - targetType, - outputAggregator.getName()); + "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString()); } List inferenceResults = this.models.stream().map(model -> { - InferenceResults results = model.infer(fields, InferenceParams.EMPTY_PARAMS); + InferenceResults results = model.infer(fields, NullInferenceConfig.INSTANCE); assert results instanceof SingleValueInferenceResults; return ((SingleValueInferenceResults)results).value(); }).collect(Collectors.toList()); List processed = outputAggregator.processValues(inferenceResults); - return buildResults(processed, params); + return buildResults(processed, config); } @Override @@ -136,13 +134,16 @@ public TargetType targetType() { return targetType; } - private InferenceResults buildResults(List processedInferences, InferenceParams params) { + private InferenceResults buildResults(List processedInferences, InferenceConfig config) { switch(targetType) { case REGRESSION: return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences)); case CLASSIFICATION: - List topClasses = - InferenceHelpers.topClasses(processedInferences, classificationLabels, params.getNumTopClasses()); + ClassificationConfig classificationConfig = (ClassificationConfig) config; + List topClasses = InferenceHelpers.topClasses( + processedInferences, + classificationLabels, + classificationConfig.getNumTopClasses()); double value = outputAggregator.aggregate(processedInferences); return new ClassificationInferenceResults(outputAggregator.aggregate(processedInferences), classificationLabel(value, classificationLabels), @@ -216,6 +217,13 @@ public int hashCode() { @Override public void validate() { + if (outputAggregator.compatibleWith(targetType) == false) { + throw ExceptionsHelper.badRequestException( + "aggregate_output [{}] is not compatible with target_type [{}]", + this.targetType, + outputAggregator.getName() + ); + } if (outputAggregator.expectedValueSize() != null && outputAggregator.expectedValueSize() != models.size()) { throw ExceptionsHelper.badRequestException( diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java new file mode 100644 index 0000000000000..7628d0beec25f --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; + +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; + +import java.io.IOException; + +/** + * Used by ensemble to pass into sub-models. + */ +class NullInferenceConfig implements InferenceConfig { + + public static final NullInferenceConfig INSTANCE = new NullInferenceConfig(); + + private NullInferenceConfig() { } + + @Override + public boolean isTargetTypeSupported(TargetType targetType) { + return true; + } + + @Override + public String getWriteableName() { + return "null"; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + } + + @Override + public String getName() { + return "null"; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java index 012f474ab0618..f19ae376f0e96 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; import java.util.List; @@ -45,5 +46,5 @@ public interface OutputAggregator extends NamedXContentObject, NamedWriteable { */ String getName(); - boolean providesProbabilities(); + boolean compatibleWith(TargetType targetType); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java index 0689d748b0ccb..a872565ad20b5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; import java.util.ArrayList; @@ -123,6 +124,11 @@ public String getName() { return NAME.getPreferredName(); } + @Override + public boolean compatibleWith(TargetType targetType) { + return true; + } + @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -159,8 +165,4 @@ public int hashCode() { return Objects.hash(weights); } - @Override - public boolean providesProbabilities() { - return true; - } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java index 9c5c2bf582e54..db70346c0849e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; import java.util.Collections; @@ -137,7 +138,7 @@ public Integer expectedValueSize() { } @Override - public boolean providesProbabilities() { - return false; + public boolean compatibleWith(TargetType targetType) { + return TargetType.REGRESSION.equals(targetType); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index bce6b08b6ed4b..a48cca3873117 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -16,8 +16,9 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; @@ -112,30 +113,34 @@ public List getNodes() { } @Override - public InferenceResults infer(Map fields, InferenceParams params) { - if (targetType != TargetType.CLASSIFICATION && params.getNumTopClasses() != 0) { + public InferenceResults infer(Map fields, InferenceConfig config) { + if (config.isTargetTypeSupported(targetType) == false) { throw ExceptionsHelper.badRequestException( - "Cannot return top classes for target_type [{}]", targetType.toString()); + "Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString()); } + List features = featureNames.stream().map(f -> fields.get(f) instanceof Number ? ((Number) fields.get(f)).doubleValue() : null ).collect(Collectors.toList()); - return infer(features, params); + return infer(features, config); } - private InferenceResults infer(List features, InferenceParams params) { + private InferenceResults infer(List features, InferenceConfig config) { TreeNode node = nodes.get(0); while(node.isLeaf() == false) { node = nodes.get(node.compare(features)); } - return buildResult(node.getLeafValue(), params); + return buildResult(node.getLeafValue(), config); } - private InferenceResults buildResult(Double value, InferenceParams params) { + private InferenceResults buildResult(Double value, InferenceConfig config) { switch (targetType) { case CLASSIFICATION: - List topClasses = - InferenceHelpers.topClasses(classificationProbability(value), classificationLabels, params.getNumTopClasses()); + ClassificationConfig classificationConfig = (ClassificationConfig) config; + List topClasses = InferenceHelpers.topClasses( + classificationProbability(value), + classificationLabels, + classificationConfig.getNumTopClasses()); return new ClassificationInferenceResults(value, classificationLabel(value, classificationLabels), topClasses); case REGRESSION: return new RegressionInferenceResults(value); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java index a49643d081957..ed782a04e0c86 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java @@ -5,16 +5,22 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigTests; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; -import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParamsTests.randomInferenceParams; public class InferModelActionRequestTests extends AbstractWireSerializingTestCase { @@ -25,12 +31,16 @@ protected Request createTestInstance() { randomAlphaOfLength(10), randomLongBetween(1, 100), Stream.generate(InferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()), - randomBoolean() ? null : randomInferenceParams()) : + randomInferenceConfig()) : new Request( randomAlphaOfLength(10), randomLongBetween(1, 100), randomMap(), - randomBoolean() ? null : randomInferenceParams()); + randomInferenceConfig()); + } + + private static InferenceConfig randomInferenceConfig() { + return randomFrom(RegressionConfigTests.randomRegressionConfig(), ClassificationConfigTests.randomClassificationConfig()); } private static Map randomMap() { @@ -44,4 +54,10 @@ protected Writeable.Reader instanceReader() { return Request::new; } + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java new file mode 100644 index 0000000000000..4df3263215f63 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfigTests.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +public class ClassificationConfigTests extends AbstractWireSerializingTestCase { + + public static ClassificationConfig randomClassificationConfig() { + return new ClassificationConfig(randomBoolean() ? null : randomIntBetween(-1, 10)); + } + + @Override + protected ClassificationConfig createTestInstance() { + return randomClassificationConfig(); + } + + @Override + protected Writeable.Reader instanceReader() { + return ClassificationConfig::new; + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java similarity index 50% rename from x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java index 2586cdd75de47..57efcdd15009a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfigTests.java @@ -8,20 +8,20 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -public class InferenceParamsTests extends AbstractWireSerializingTestCase { +public class RegressionConfigTests extends AbstractWireSerializingTestCase { - public static InferenceParams randomInferenceParams() { - return randomBoolean() ? InferenceParams.EMPTY_PARAMS : new InferenceParams(randomIntBetween(-1, 100)); + public static RegressionConfig randomRegressionConfig() { + return new RegressionConfig(); } @Override - protected InferenceParams createTestInstance() { - return randomInferenceParams(); + protected RegressionConfig createTestInstance() { + return randomRegressionConfig(); } @Override - protected Writeable.Reader instanceReader() { - return InferenceParams::new; + protected Writeable.Reader instanceReader() { + return RegressionConfig::new; } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 8da0c15718f24..816fabf7b0b67 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -17,7 +17,8 @@ import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; @@ -159,6 +160,27 @@ public void testEnsembleWithInvalidModel() { }); } + public void testEnsembleWithAggregatorOutputNotSupportingTargetType() { + List featureNames = Arrays.asList("foo", "bar"); + ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { + Ensemble.builder() + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList( + Tree.builder() + .setNodes(TreeNode.builder(0) + .setLeftChild(1) + .setSplitFeature(1) + .setThreshold(randomDouble())) + .setFeatureNames(featureNames) + .build())) + .setClassificationLabels(Arrays.asList("label1", "label2")) + .setTargetType(TargetType.CLASSIFICATION) + .setOutputAggregator(new WeightedSum()) + .build() + .validate(); + }); + } + public void testEnsembleWithTargetTypeAndLabelsMismatch() { List featureNames = Arrays.asList("foo", "bar"); String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa"; @@ -190,6 +212,7 @@ public void testEnsembleWithTargetTypeAndLabelsMismatch() { .setFeatureNames(featureNames) .build())) .setTargetType(TargetType.CLASSIFICATION) + .setOutputAggregator(new WeightedMode()) .build() .validate(); }); @@ -245,7 +268,7 @@ public void testClassificationProbability() { List expected = Arrays.asList(0.768524783, 0.231475216); double eps = 0.000001; List probabilities = - ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); + ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } @@ -254,7 +277,7 @@ public void testClassificationProbability() { featureMap = zipObjMap(featureNames, featureVector); expected = Arrays.asList(0.689974481, 0.3100255188); probabilities = - ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); + ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } @@ -263,7 +286,7 @@ public void testClassificationProbability() { featureMap = zipObjMap(featureNames, featureVector); expected = Arrays.asList(0.768524783, 0.231475216); probabilities = - ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); + ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } @@ -275,7 +298,7 @@ public void testClassificationProbability() { }}; expected = Arrays.asList(0.6899744811, 0.3100255188); probabilities = - ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); + ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } @@ -328,24 +351,24 @@ public void testClassificationInference() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); assertThat(1.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); assertThat(1.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); featureVector = Arrays.asList(0.0, 1.0); featureMap = zipObjMap(featureNames, featureVector); assertThat(1.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; assertThat(0.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(0))).value(), 0.00001)); } public void testRegressionInference() { @@ -385,12 +408,12 @@ public void testRegressionInference() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); assertThat(0.9, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); assertThat(0.5, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); // Test with NO aggregator supplied, verifies default behavior of non-weighted sum ensemble = Ensemble.builder() @@ -402,19 +425,19 @@ public void testRegressionInference() { featureVector = Arrays.asList(0.4, 0.0); featureMap = zipObjMap(featureNames, featureVector); assertThat(1.8, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); assertThat(1.0, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; assertThat(1.8, - closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); } private static Map zipObjMap(List keys, List values) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java index 7849d6d071ef1..683115e63879e 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedModeTests.java @@ -8,6 +8,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; import java.util.Arrays; @@ -15,6 +16,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; public class WeightedModeTests extends WeightedAggregatorTests { @@ -55,4 +57,10 @@ public void testAggregate() { weightedMode = new WeightedMode(); assertThat(weightedMode.aggregate(weightedMode.processValues(values)), equalTo(2.0)); } + + public void testCompatibleWith() { + WeightedMode weightedMode = createTestInstance(); + assertThat(weightedMode.compatibleWith(TargetType.CLASSIFICATION), is(true)); + assertThat(weightedMode.compatibleWith(TargetType.REGRESSION), is(true)); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java index 89222365c83d8..fa372f043a410 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSumTests.java @@ -8,6 +8,7 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; import java.util.Arrays; @@ -15,6 +16,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; public class WeightedSumTests extends WeightedAggregatorTests { @@ -55,4 +57,10 @@ public void testAggregate() { weightedSum = new WeightedSum(); assertThat(weightedSum.aggregate(weightedSum.processValues(values)), equalTo(13.0)); } + + public void testCompatibleWith() { + WeightedSum weightedSum = createTestInstance(); + assertThat(weightedSum.compatibleWith(TargetType.CLASSIFICATION), is(false)); + assertThat(weightedSum.compatibleWith(TargetType.REGRESSION), is(true)); + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index c362c17fd579d..075bfbe912270 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -12,7 +12,8 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.junit.Before; @@ -124,21 +125,21 @@ public void testInfer() { List featureVector = Arrays.asList(0.6, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); assertThat(0.3, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); // This should hit the left child of the left child of the root node // i.e. it takes the path left, left featureVector = Arrays.asList(0.3, 0.7); featureMap = zipObjMap(featureNames, featureVector); assertThat(0.1, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); // This should hit the right child of the left child of the root node // i.e. it takes the path left, right featureVector = Arrays.asList(0.3, 0.9); featureMap = zipObjMap(featureNames, featureVector); assertThat(0.2, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); // This should handle missing values and take the default_left path featureMap = new HashMap<>(2) {{ @@ -146,7 +147,7 @@ public void testInfer() { put("bar", null); }}; assertThat(0.1, - closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, new RegressionConfig())).value(), 0.00001)); } public void testTreeClassificationProbability() { @@ -169,7 +170,7 @@ public void testTreeClassificationProbability() { List expectedFields = Arrays.asList("dog", "cat"); Map featureMap = zipObjMap(featureNames, featureVector); List probabilities = - ((ClassificationInferenceResults)tree.infer(featureMap, new InferenceParams(2))).getTopClasses(); + ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expectedProbs.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); @@ -180,7 +181,7 @@ public void testTreeClassificationProbability() { featureVector = Arrays.asList(0.3, 0.7); featureMap = zipObjMap(featureNames, featureVector); probabilities = - ((ClassificationInferenceResults)tree.infer(featureMap, new InferenceParams(2))).getTopClasses(); + ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expectedProbs.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); @@ -192,7 +193,7 @@ public void testTreeClassificationProbability() { put("bar", null); }}; probabilities = - ((ClassificationInferenceResults)tree.infer(featureMap, new InferenceParams(2))).getTopClasses(); + ((ClassificationInferenceResults)tree.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expectedProbs.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index e5253b3d5b173..9bbf42915410d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -7,7 +7,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; @@ -40,9 +40,9 @@ public String getResultsType() { } @Override - public void infer(Map fields, InferenceParams params, ActionListener listener) { + public void infer(Map fields, InferenceConfig config, ActionListener listener) { try { - listener.onResponse(trainedModelDefinition.infer(fields, params)); + listener.onResponse(trainedModelDefinition.infer(fields, config)); } catch (Exception e) { listener.onFailure(e); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java index 27924a47aa153..c66a23d78f98e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java @@ -7,7 +7,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import java.util.Map; @@ -15,6 +15,6 @@ public interface Model { String getResultsType(); - void infer(Map fields, InferenceParams inferenceParams, ActionListener listener); + void infer(Map fields, InferenceConfig inferenceConfig, ActionListener listener); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index e66a5790d85e5..48aa70dec74f2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -10,7 +10,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; @@ -49,12 +51,12 @@ public void testClassificationInfer() throws Exception { put("categorical", "dog"); }}; - SingleValueInferenceResults result = getSingleValue(model, fields, InferenceParams.EMPTY_PARAMS); + SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfig(0)); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), is("0.0")); ClassificationInferenceResults classificationResult = - (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(1)); + (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0")); @@ -65,18 +67,18 @@ public void testClassificationInfer() throws Exception { .setTrainedModel(buildClassification(true)) .build(); model = new LocalModel(modelId, definition); - result = getSingleValue(model, fields, InferenceParams.EMPTY_PARAMS); + result = getSingleValue(model, fields, new ClassificationConfig(0)); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), equalTo("not_to_be")); - classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(1)); + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); - classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(2)); + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(2)); assertThat(classificationResult.getTopClasses(), hasSize(2)); - classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(-1)); + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(-1)); assertThat(classificationResult.getTopClasses(), hasSize(2)); } @@ -94,21 +96,21 @@ public void testRegression() throws Exception { put("categorical", "dog"); }}; - SingleValueInferenceResults results = getSingleValue(model, fields, InferenceParams.EMPTY_PARAMS); + SingleValueInferenceResults results = getSingleValue(model, fields, new RegressionConfig()); assertThat(results.value(), equalTo(1.3)); PlainActionFuture failedFuture = new PlainActionFuture<>(); - model.infer(fields, new InferenceParams(2), failedFuture); + model.infer(fields, new ClassificationConfig(2), failedFuture); ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get); assertThat(ex.getCause().getMessage(), - equalTo("Cannot return top classes for target_type [regression] and aggregate_output [weighted_sum]")); + equalTo("Cannot infer using configuration for [classification] when model target_type is [regression]")); } private static SingleValueInferenceResults getSingleValue(Model model, Map fields, - InferenceParams params) throws Exception { + InferenceConfig config) throws Exception { PlainActionFuture future = new PlainActionFuture<>(); - model.infer(fields, params, future); + model.infer(fields, config, future); return (SingleValueInferenceResults)future.get(); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 032f84d16b52d..db44fa0725482 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -16,7 +16,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; @@ -127,7 +127,7 @@ public void testInferModels() throws Exception { contains("not_to_be", "to_be")); // Get top classes - request = new InferModelAction.Request(modelId2, 0, toInfer, new InferenceParams(2)); + request = new InferModelAction.Request(modelId2, 0, toInfer, new ClassificationConfig(2)); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); ClassificationInferenceResults classificationInferenceResults = @@ -146,7 +146,7 @@ public void testInferModels() throws Exception { greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); // Test that top classes restrict the number returned - request = new InferModelAction.Request(modelId2, 0, toInfer2, new InferenceParams(1)); + request = new InferModelAction.Request(modelId2, 0, toInfer2, new ClassificationConfig(1)); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0); From 213cf6ddeffebd15259c02448502f99ff5eb4f31 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 9 Oct 2019 12:42:53 -0400 Subject: [PATCH 2/3] fixing method reference --- .../xpack/ml/action/TransportInferModelAction.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java index a2f79fc9d0437..b5d0b7c4e330a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -48,7 +48,7 @@ protected void doExecute(Task task, InferModelAction.Request request, ActionList ex -> true); request.getObjectsToInfer().forEach(stringObjectMap -> typedChainTaskExecutor.add(chainedTask -> - model.infer(stringObjectMap, request.getParams(), chainedTask))); + model.infer(stringObjectMap, request.getConfig(), chainedTask))); typedChainTaskExecutor.execute(ActionListener.wrap( inferenceResultsInterfaces -> From d4cb6b3e1c570a24c7aa34472fe66135747b15d9 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 9 Oct 2019 15:33:00 -0400 Subject: [PATCH 3/3] fixing tests --- .../ml/integration/ModelInferenceActionIT.java | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index db44fa0725482..36f5817e40f3a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; @@ -58,13 +59,13 @@ public void testInferModels() throws Exception { Map oneHotEncoding = new HashMap<>(); oneHotEncoding.put("cat", "animal_cat"); oneHotEncoding.put("dog", "animal_dog"); - TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2, 0) + TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2) .setDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) .setTrainedModel(buildClassification(true))) .build(Version.CURRENT); - TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1, 0) + TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1) .setDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) @@ -106,19 +107,19 @@ public void testInferModels() throws Exception { }}); // Test regression - InferModelAction.Request request = new InferModelAction.Request(modelId1, 0, toInfer, null); + InferModelAction.Request request = new InferModelAction.Request(modelId1, 0, toInfer, new RegressionConfig()); InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), contains(1.3, 1.25)); - request = new InferModelAction.Request(modelId1, 0, toInfer2, null); + request = new InferModelAction.Request(modelId1, 0, toInfer2, new RegressionConfig()); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), contains(1.65, 1.55)); // Test classification - request = new InferModelAction.Request(modelId2, 0, toInfer, null); + request = new InferModelAction.Request(modelId2, 0, toInfer, new ClassificationConfig(0)); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat(response.getInferenceResults() .stream() @@ -156,7 +157,7 @@ public void testInferModels() throws Exception { public void testInferMissingModel() { String model = "test-infer-missing-model"; - InferModelAction.Request request = new InferModelAction.Request(model, 0, Collections.emptyList(), null); + InferModelAction.Request request = new InferModelAction.Request(model, 0, Collections.emptyList(), new RegressionConfig()); try { client().execute(InferModelAction.INSTANCE, request).actionGet(); } catch (ElasticsearchException ex) { @@ -164,14 +165,13 @@ public void testInferMissingModel() { } } - private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId, long modelVersion) { + private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) { return TrainedModelConfig.builder() .setCreatedBy("ml_test") .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) .setDescription("trained model config for test") .setModelId(modelId) - .setModelType("binary_decision_tree") - .setModelVersion(modelVersion); + .setModelType("binary_decision_tree"); } @Override