Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.ingest.IngestDocument;

import java.io.IOException;
import java.util.Objects;

public class RawInferenceResults extends SingleValueInferenceResults {

public static final String NAME = "raw";

public RawInferenceResults(double value) {
super(value);
}

public RawInferenceResults(StreamInput in) throws IOException {
super(in.readDouble());
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
}

@Override
XContentBuilder innerToXContent(XContentBuilder builder, Params params) {
return builder;
}

@Override
public boolean equals(Object object) {
if (object == this) { return true; }
if (object == null || getClass() != object.getClass()) { return false; }
RawInferenceResults that = (RawInferenceResults) object;
return Objects.equals(value(), that.value());
}

@Override
public int hashCode() {
return Objects.hash(value());
}

@Override
public void writeResult(IngestDocument document, String resultField) {
throw new UnsupportedOperationException("[raw] does not support writing inference results");
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public String getName() {
return NAME;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,18 @@
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;

import java.io.IOException;

/**
* Used by ensemble to pass into sub-models.
*/
class NullInferenceConfig implements InferenceConfig {
public class NullInferenceConfig implements InferenceConfig {

public static final NullInferenceConfig INSTANCE = new NullInferenceConfig();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
Expand Down Expand Up @@ -135,6 +137,10 @@ public TargetType targetType() {
}

private InferenceResults buildResults(List<Double> processedInferences, InferenceConfig config) {
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(outputAggregator.aggregate(processedInferences));
}
switch(targetType) {
case REGRESSION:
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RawInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NullInferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
Expand Down Expand Up @@ -134,6 +136,10 @@ private InferenceResults infer(List<Double> features, InferenceConfig config) {
}

private InferenceResults buildResult(Double value, InferenceConfig config) {
// Indicates that the config is useless and the caller just wants the raw value
if (config instanceof NullInferenceConfig) {
return new RawInferenceResults(value);
}
switch (targetType) {
case CLASSIFICATION:
ClassificationConfig classificationConfig = (ClassificationConfig) config;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.results;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

public class RawInferenceResultsTests extends AbstractWireSerializingTestCase<RawInferenceResults> {

public static RawInferenceResults createRandomResults() {
return new RawInferenceResults(randomDouble());
}

@Override
protected RawInferenceResults createTestInstance() {
return createRandomResults();
}

@Override
protected Writeable.Reader<RawInferenceResults> instanceReader() {
return RawInferenceResults::new;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,9 @@ public void testClassificationInference() {
.setLeftChild(3)
.setRightChild(4))
.addNode(TreeNode.builder(3).setLeafValue(0.0))
.addNode(TreeNode.builder(4).setLeafValue(1.0)).build();
.addNode(TreeNode.builder(4).setLeafValue(1.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Tree tree2 = Tree.builder()
.setFeatureNames(featureNames)
.setRoot(TreeNode.builder(0)
Expand All @@ -332,6 +334,7 @@ public void testClassificationInference() {
.setThreshold(0.5))
.addNode(TreeNode.builder(1).setLeafValue(0.0))
.addNode(TreeNode.builder(2).setLeafValue(1.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Tree tree3 = Tree.builder()
.setFeatureNames(featureNames)
Expand All @@ -342,6 +345,7 @@ public void testClassificationInference() {
.setThreshold(1.0))
.addNode(TreeNode.builder(1).setLeafValue(1.0))
.addNode(TreeNode.builder(2).setLeafValue(0.0))
.setTargetType(randomFrom(TargetType.CLASSIFICATION, TargetType.REGRESSION))
.build();
Ensemble ensemble = Ensemble.builder()
.setTargetType(TargetType.CLASSIFICATION)
Expand Down