diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java new file mode 100644 index 0000000000000..884d66032b564 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.ingest.IngestDocument; + +import java.io.IOException; +import java.util.Objects; + +public class RawInferenceResults extends SingleValueInferenceResults { + + public static final String NAME = "raw"; + + public RawInferenceResults(double value) { + super(value); + } + + public RawInferenceResults(StreamInput in) throws IOException { + super(in.readDouble()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + XContentBuilder innerToXContent(XContentBuilder builder, Params params) { + return builder; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + RawInferenceResults that = (RawInferenceResults) object; + return Objects.equals(value(), that.value()); + } + + @Override + public int hashCode() { + return Objects.hash(value()); + } + + @Override + public void writeResult(IngestDocument document, String resultField) { + throw new UnsupportedOperationException("[raw] does not support writing inference results"); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java similarity index 80% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java index 42757d889818e..b7c4a71b3e79e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/NullInferenceConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NullInferenceConfig.java @@ -3,20 +3,18 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble; +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import java.io.IOException; /** * Used by ensemble to pass into sub-models. */ -class NullInferenceConfig implements InferenceConfig { +public class NullInferenceConfig implements InferenceConfig { public static final NullInferenceConfig INSTANCE = new NullInferenceConfig(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index ff03a621d99fa..3bea5ad80ba0c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -14,12 +14,14 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; @@ -135,6 +137,10 @@ public TargetType targetType() { } private InferenceResults buildResults(List processedInferences, InferenceConfig config) { + // Indicates that the config is useless and the caller just wants the raw value + if (config instanceof NullInferenceConfig) { + return new RawInferenceResults(outputAggregator.aggregate(processedInferences)); + } switch(targetType) { case REGRESSION: return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences)); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index a48cca3873117..7427a7cc70037 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -15,11 +15,13 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -134,6 +136,10 @@ private InferenceResults infer(List features, InferenceConfig config) { } private InferenceResults buildResult(Double value, InferenceConfig config) { + // Indicates that the config is useless and the caller just wants the raw value + if (config instanceof NullInferenceConfig) { + return new RawInferenceResults(value); + } switch (targetType) { case CLASSIFICATION: ClassificationConfig classificationConfig = (ClassificationConfig) config; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java new file mode 100644 index 0000000000000..d9d4e9933b24d --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResultsTests.java @@ -0,0 +1,26 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.results; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +public class RawInferenceResultsTests extends AbstractWireSerializingTestCase { + + public static RawInferenceResults createRandomResults() { + return new RawInferenceResults(randomDouble()); + } + + @Override + protected RawInferenceResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return RawInferenceResults::new; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 52f317c2595c3..a81c210e33067 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -322,7 +322,9 @@ public void testClassificationInference() { .setLeftChild(3) .setRightChild(4)) .addNode(TreeNode.builder(3).setLeafValue(0.0)) - .addNode(TreeNode.builder(4).setLeafValue(1.0)).build(); + .addNode(TreeNode.builder(4).setLeafValue(1.0)) + .setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION)) + .build(); Tree tree2 = Tree.builder() .setFeatureNames(featureNames) .setRoot(TreeNode.builder(0) @@ -332,6 +334,7 @@ public void testClassificationInference() { .setThreshold(0.5)) .addNode(TreeNode.builder(1).setLeafValue(0.0)) .addNode(TreeNode.builder(2).setLeafValue(1.0)) + .setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION)) .build(); Tree tree3 = Tree.builder() .setFeatureNames(featureNames) @@ -342,6 +345,7 @@ public void testClassificationInference() { .setThreshold(1.0)) .addNode(TreeNode.builder(1).setLeafValue(1.0)) .addNode(TreeNode.builder(2).setLeafValue(0.0)) + .setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION)) .build(); Ensemble ensemble = Ensemble.builder() .setTargetType(TargetType.CLASSIFICATION)