diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java index d16d758769c2b..60ffde6e35f59 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -29,6 +29,7 @@ import org.elasticsearch.common.xcontent.XContentParser; import java.io.IOException; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Objects; @@ -41,6 +42,7 @@ public class Ensemble implements TrainedModel { public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output"); public static final ParseField TARGET_TYPE = new ParseField("target_type"); public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels"); + public static final ParseField CLASSIFICATION_WEIGHTS = new ParseField("classification_weights"); private static final ObjectParser PARSER = new ObjectParser<>( NAME, @@ -60,6 +62,7 @@ public class Ensemble implements TrainedModel { AGGREGATE_OUTPUT); PARSER.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE); PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS); + PARSER.declareDoubleArray(Ensemble.Builder::setClassificationWeights, CLASSIFICATION_WEIGHTS); } public static Ensemble fromXContent(XContentParser parser) { @@ -71,17 +74,20 @@ public static Ensemble fromXContent(XContentParser parser) { private final OutputAggregator outputAggregator; private final TargetType targetType; private final List classificationLabels; + private final double[] classificationWeights; Ensemble(List featureNames, List models, @Nullable OutputAggregator outputAggregator, TargetType targetType, - @Nullable List classificationLabels) { + @Nullable List classificationLabels, + @Nullable double[] classificationWeights) { this.featureNames = featureNames; this.models = models; this.outputAggregator = outputAggregator; this.targetType = targetType; this.classificationLabels = classificationLabels; + this.classificationWeights = classificationWeights; } @Override @@ -116,6 +122,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par if (classificationLabels != null) { builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); } + if (classificationWeights != null) { + builder.field(CLASSIFICATION_WEIGHTS.getPreferredName(), classificationWeights); + } builder.endObject(); return builder; } @@ -129,12 +138,18 @@ public boolean equals(Object o) { && Objects.equals(models, that.models) && Objects.equals(targetType, that.targetType) && Objects.equals(classificationLabels, that.classificationLabels) + && Arrays.equals(classificationWeights, that.classificationWeights) && Objects.equals(outputAggregator, that.outputAggregator); } @Override public int hashCode() { - return Objects.hash(featureNames, models, outputAggregator, classificationLabels, targetType); + return Objects.hash(featureNames, + models, + outputAggregator, + classificationLabels, + targetType, + Arrays.hashCode(classificationWeights)); } public static Builder builder() { @@ -147,6 +162,7 @@ public static class Builder { private OutputAggregator outputAggregator; private TargetType targetType; private List classificationLabels; + private double[] classificationWeights; public Builder setFeatureNames(List featureNames) { this.featureNames = featureNames; @@ -173,6 +189,11 @@ public Builder setClassificationLabels(List classificationLabels) { return this; } + public Builder setClassificationWeights(List classificationWeights) { + this.classificationWeights = classificationWeights.stream().mapToDouble(Double::doubleValue).toArray(); + return this; + } + private void setOutputAggregatorFromParser(List outputAggregators) { this.setOutputAggregator(outputAggregators.get(0)); } @@ -182,7 +203,7 @@ private void setTargetType(String targetType) { } public Ensemble build() { - return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels); + return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels, classificationWeights); } } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java index a26adf5f09c71..9f15a3f8a31b3 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -80,11 +80,19 @@ public static Ensemble createRandom(TargetType targetType) { if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) { categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); } + double[] thresholds = randomBoolean() && targetType == TargetType.CLASSIFICATION ? + Stream.generate(ESTestCase::randomDouble) + .limit(categoryLabels == null ? randomIntBetween(1, 10) : categoryLabels.size()) + .mapToDouble(Double::valueOf) + .toArray() : + null; + return new Ensemble(featureNames, models, outputAggregator, targetType, - categoryLabels); + categoryLabels, + thresholds); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java index 226a9c1104346..39ae4057fd9ca 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java @@ -112,18 +112,26 @@ public static class TopClassEntry implements Writeable { public final ParseField CLASS_NAME = new ParseField("class_name"); public final ParseField CLASS_PROBABILITY = new ParseField("class_probability"); + public final ParseField CLASS_SCORE = new ParseField("class_score"); private final String classification; private final double probability; + private final double score; - public TopClassEntry(String classification, Double probability) { + public TopClassEntry(String classification, double probability) { + this(classification, probability, probability); + } + + public TopClassEntry(String classification, double probability, double score) { this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME); - this.probability = ExceptionsHelper.requireNonNull(probability, CLASS_PROBABILITY); + this.probability = probability; + this.score = score; } public TopClassEntry(StreamInput in) throws IOException { this.classification = in.readString(); this.probability = in.readDouble(); + this.score = in.readDouble(); } public String getClassification() { @@ -134,10 +142,15 @@ public double getProbability() { return probability; } + public double getScore() { + return score; + } + public Map asValueMap() { - Map map = new HashMap<>(2); + Map map = new HashMap<>(3, 1.0f); map.put(CLASS_NAME.getPreferredName(), classification); map.put(CLASS_PROBABILITY.getPreferredName(), probability); + map.put(CLASS_SCORE.getPreferredName(), score); return map; } @@ -145,6 +158,7 @@ public Map asValueMap() { public void writeTo(StreamOutput out) throws IOException { out.writeString(classification); out.writeDouble(probability); + out.writeDouble(score); } @Override @@ -152,13 +166,12 @@ public boolean equals(Object object) { if (object == this) { return true; } if (object == null || getClass() != object.getClass()) { return false; } TopClassEntry that = (TopClassEntry) object; - return Objects.equals(classification, that.classification) && - Objects.equals(probability, that.probability); + return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score; } @Override public int hashCode() { - return Objects.hash(classification, probability); + return Objects.hash(classification, probability, score); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java index 86bf076cd6bf1..ae5a4062a69dc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -20,17 +21,13 @@ public final class InferenceHelpers { private InferenceHelpers() { } - public static List topClasses(List probabilities, - List classificationLabels, - int numToInclude) { - if (numToInclude == 0) { - return Collections.emptyList(); - } - int[] sortedIndices = IntStream.range(0, probabilities.size()) - .boxed() - .sorted(Comparator.comparing(probabilities::get).reversed()) - .mapToInt(i -> i) - .toArray(); + /** + * @return Tuple of the highest scored index and the top classes + */ + public static Tuple> topClasses(List probabilities, + List classificationLabels, + @Nullable double[] classificationWeights, + int numToInclude) { if (classificationLabels != null && probabilities.size() != classificationLabels.size()) { throw ExceptionsHelper @@ -38,7 +35,24 @@ public static List topClasses(List "model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]", null, probabilities.size(), - classificationLabels); + classificationLabels.size()); + } + + List scores = classificationWeights == null ? + probabilities : + IntStream.range(0, probabilities.size()) + .mapToDouble(i -> probabilities.get(i) * classificationWeights[i]) + .boxed() + .collect(Collectors.toList()); + + int[] sortedIndices = IntStream.range(0, probabilities.size()) + .boxed() + .sorted(Comparator.comparing(scores::get).reversed()) + .mapToInt(i -> i) + .toArray(); + + if (numToInclude == 0) { + return Tuple.tuple(sortedIndices[0], Collections.emptyList()); } List labels = classificationLabels == null ? @@ -50,26 +64,24 @@ public static List topClasses(List List topClassEntries = new ArrayList<>(count); for(int i = 0; i < count; i++) { int idx = sortedIndices[i]; - topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx))); + topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx), scores.get(idx))); } - return topClassEntries; + return Tuple.tuple(sortedIndices[0], topClassEntries); } - public static String classificationLabel(double inferenceValue, @Nullable List classificationLabels) { - assert inferenceValue == Math.rint(inferenceValue); + public static String classificationLabel(Integer inferenceValue, @Nullable List classificationLabels) { if (classificationLabels == null) { return String.valueOf(inferenceValue); } - int label = Double.valueOf(inferenceValue).intValue(); - if (label < 0 || label >= classificationLabels.size()) { + if (inferenceValue < 0 || inferenceValue >= classificationLabels.size()) { throw ExceptionsHelper.serverError( "model returned classification value of [{}] which is not a valid index in classification labels [{}]", null, - label, + inferenceValue, classificationLabels); } - return classificationLabels.get(label); + return classificationLabels.get(inferenceValue); } public static Double toDouble(Object value) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java index e206a70918096..4bbca5ed0b1d5 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java @@ -6,21 +6,14 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.apache.lucene.util.Accountable; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteable; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; -import java.util.List; import java.util.Map; public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accountable { - /** - * @return List of featureNames expected by the model. In the order that they are expected - */ - List getFeatureNames(); - /** * Infer against the provided fields * @@ -36,12 +29,6 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou */ TargetType targetType(); - /** - * @return Ordinal encoded list of classification labels. - */ - @Nullable - List classificationLabels(); - /** * Runs validations against the model. * diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index c9bde54c460cd..a455730ae8208 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ObjectParser; @@ -33,6 +34,7 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -53,6 +55,7 @@ public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrai public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output"); public static final ParseField TARGET_TYPE = new ParseField("target_type"); public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels"); + public static final ParseField CLASSIFICATION_WEIGHTS = new ParseField("classification_weights"); private static final ObjectParser LENIENT_PARSER = createParser(true); private static final ObjectParser STRICT_PARSER = createParser(false); @@ -77,6 +80,7 @@ private static ObjectParser createParser(boolean lenient AGGREGATE_OUTPUT); parser.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE); parser.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS); + parser.declareDoubleArray(Ensemble.Builder::setClassificationWeights, CLASSIFICATION_WEIGHTS); return parser; } @@ -93,17 +97,22 @@ public static Ensemble fromXContentLenient(XContentParser parser) { private final OutputAggregator outputAggregator; private final TargetType targetType; private final List classificationLabels; + private final double[] classificationWeights; Ensemble(List featureNames, List models, OutputAggregator outputAggregator, TargetType targetType, - @Nullable List classificationLabels) { + @Nullable List classificationLabels, + @Nullable double[] classificationWeights) { this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES)); this.models = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(models, TRAINED_MODELS)); this.outputAggregator = ExceptionsHelper.requireNonNull(outputAggregator, AGGREGATE_OUTPUT); this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE); this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels); + this.classificationWeights = classificationWeights == null ? + null : + Arrays.copyOf(classificationWeights, classificationWeights.length); } public Ensemble(StreamInput in) throws IOException { @@ -116,11 +125,11 @@ public Ensemble(StreamInput in) throws IOException { } else { this.classificationLabels = null; } - } - - @Override - public List getFeatureNames() { - return featureNames; + if (in.readBoolean()) { + this.classificationWeights = in.readDoubleArray(); + } else { + this.classificationWeights = null; + } } @Override @@ -153,25 +162,22 @@ private InferenceResults buildResults(List processedInferences, Inferenc return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences), config); case CLASSIFICATION: ClassificationConfig classificationConfig = (ClassificationConfig) config; - List topClasses = InferenceHelpers.topClasses( + assert classificationWeights == null || processedInferences.size() == classificationWeights.length; + // Adjust the probabilities according to the thresholds + Tuple> topClasses = InferenceHelpers.topClasses( processedInferences, classificationLabels, + classificationWeights, classificationConfig.getNumTopClasses()); - double value = outputAggregator.aggregate(processedInferences); - return new ClassificationInferenceResults(outputAggregator.aggregate(processedInferences), - classificationLabel(value, classificationLabels), - topClasses, + return new ClassificationInferenceResults((double)topClasses.v1(), + classificationLabel(topClasses.v1(), classificationLabels), + topClasses.v2(), config); default: throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model"); } } - @Override - public List classificationLabels() { - return classificationLabels; - } - @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -187,6 +193,10 @@ public void writeTo(StreamOutput out) throws IOException { if (classificationLabels != null) { out.writeStringCollection(classificationLabels); } + out.writeBoolean(classificationWeights != null); + if (classificationWeights != null) { + out.writeDoubleArray(classificationWeights); + } } @Override @@ -208,6 +218,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (classificationLabels != null) { builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels); } + if (classificationWeights != null) { + builder.field(CLASSIFICATION_WEIGHTS.getPreferredName(), classificationWeights); + } builder.endObject(); return builder; } @@ -221,12 +234,18 @@ public boolean equals(Object o) { && Objects.equals(models, that.models) && Objects.equals(targetType, that.targetType) && Objects.equals(classificationLabels, that.classificationLabels) - && Objects.equals(outputAggregator, that.outputAggregator); + && Objects.equals(outputAggregator, that.outputAggregator) + && Arrays.equals(classificationWeights, that.classificationWeights); } @Override public int hashCode() { - return Objects.hash(featureNames, models, outputAggregator, targetType, classificationLabels); + return Objects.hash(featureNames, + models, + outputAggregator, + targetType, + classificationLabels, + Arrays.hashCode(classificationWeights)); } @Override @@ -246,9 +265,16 @@ public void validate() { outputAggregator.expectedValueSize(), models.size()); } - if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) { + if ((this.classificationLabels != null || this.classificationWeights != null) && (this.targetType != TargetType.CLASSIFICATION)) { throw ExceptionsHelper.badRequestException( - "[target_type] should be [classification] if [classification_labels] is provided, and vice versa"); + "[target_type] should be [classification] if [classification_labels] or [classification_weights] are provided"); + } + if (classificationWeights != null && + classificationLabels != null && + classificationWeights.length != classificationLabels.size()) { + throw ExceptionsHelper.badRequestException( + "[classification_weights] and [classification_labels] should be the same length if both are provided" + ); } this.models.forEach(TrainedModel::validate); } @@ -271,6 +297,9 @@ public long ramBytesUsed() { size += RamUsageEstimator.sizeOfCollection(featureNames); size += RamUsageEstimator.sizeOfCollection(classificationLabels); size += RamUsageEstimator.sizeOfCollection(models); + if (classificationWeights != null) { + size += RamUsageEstimator.sizeOf(classificationWeights); + } size += outputAggregator.ramBytesUsed(); return size; } @@ -291,6 +320,7 @@ public static class Builder { private OutputAggregator outputAggregator = new WeightedSum(); private TargetType targetType = TargetType.REGRESSION; private List classificationLabels; + private double[] classificationWeights; private boolean modelsAreOrdered; private Builder (boolean modelsAreOrdered) { @@ -330,6 +360,11 @@ public Builder setClassificationLabels(List classificationLabels) { return this; } + public Builder setClassificationWeights(List classificationWeights) { + this.classificationWeights = classificationWeights.stream().mapToDouble(Double::doubleValue).toArray(); + return this; + } + private void setOutputAggregatorFromParser(List outputAggregators) { if (outputAggregators.size() != 1) { throw ExceptionsHelper.badRequestException("[{}] must have exactly one aggregator defined.", @@ -352,7 +387,7 @@ public Ensemble build() { if (modelsAreOrdered == false && trainedModels != null && trainedModels.size() > 1) { throw ExceptionsHelper.badRequestException("[trained_models] needs to be an array of objects"); } - return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels); + return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels, classificationWeights); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java index b9edeb7885504..7de2c8f060500 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.java @@ -8,6 +8,7 @@ import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -25,13 +26,10 @@ import java.io.IOException; import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; -import java.util.stream.IntStream; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.xpack.core.ml.inference.utils.Statistics.softMax; @@ -105,11 +103,6 @@ public LangIdentNeuralNetwork(StreamInput in) throws IOException { this.softmaxLayer = new LangNetLayer(in); } - @Override - public List getFeatureNames() { - return Collections.singletonList(embeddedVectorFeatureName); - } - @Override public InferenceResults infer(Map fields, InferenceConfig config) { if (config instanceof ClassificationConfig == false) { @@ -134,20 +127,17 @@ public InferenceResults infer(Map fields, InferenceConfig config List probabilities = softMax(Arrays.stream(scores).boxed().collect(Collectors.toList())); - int maxIndex = IntStream.range(0, probabilities.size()) - .boxed() - .max(Comparator.comparing(probabilities::get)) - .orElseThrow(() -> ExceptionsHelper.serverError("Unexpected null value while searching for max probability")); - - assert maxIndex >= 0 && maxIndex < LANGUAGE_NAMES.size() : "Invalid language predicted. Predicted language index " + maxIndex; ClassificationConfig classificationConfig = (ClassificationConfig) config; - List topClasses = InferenceHelpers.topClasses( + Tuple> topClasses = InferenceHelpers.topClasses( probabilities, LANGUAGE_NAMES, + null, classificationConfig.getNumTopClasses()); - return new ClassificationInferenceResults(maxIndex, - LANGUAGE_NAMES.get(maxIndex), - topClasses, + assert topClasses.v1() >= 0 && topClasses.v1() < LANGUAGE_NAMES.size() : + "Invalid language predicted. Predicted language index " + topClasses.v1(); + return new ClassificationInferenceResults(topClasses.v1(), + LANGUAGE_NAMES.get(topClasses.v1()), + topClasses.v2(), classificationConfig); } @@ -156,11 +146,6 @@ public TargetType targetType() { return TargetType.CLASSIFICATION; } - @Override - public List classificationLabels() { - return LANGUAGE_NAMES; - } - @Override public void validate() { } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index 831838e0f7df2..527307597a597 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.util.CachedSupplier; @@ -114,11 +115,6 @@ public String getName() { return NAME.getPreferredName(); } - @Override - public List getFeatureNames() { - return featureNames; - } - public List getNodes() { return nodes; } @@ -152,11 +148,15 @@ private InferenceResults buildResult(Double value, InferenceConfig config) { switch (targetType) { case CLASSIFICATION: ClassificationConfig classificationConfig = (ClassificationConfig) config; - List topClasses = InferenceHelpers.topClasses( + Tuple> topClasses = InferenceHelpers.topClasses( classificationProbability(value), classificationLabels, + null, classificationConfig.getNumTopClasses()); - return new ClassificationInferenceResults(value, classificationLabel(value, classificationLabels), topClasses, config); + return new ClassificationInferenceResults(value, + classificationLabel(topClasses.v1(), classificationLabels), + topClasses.v2(), + config); case REGRESSION: return new RegressionInferenceResults(value, config); default: @@ -197,11 +197,6 @@ private List classificationProbability(double inferenceValue) { return list; } - @Override - public List classificationLabels() { - return classificationLabels; - } - @Override public String getWriteableName() { return NAME.getPreferredName(); @@ -270,9 +265,9 @@ public long estimatedNumOperations() { } private void checkTargetType() { - if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) { + if (this.classificationLabels != null && this.targetType != TargetType.CLASSIFICATION) { throw ExceptionsHelper.badRequestException( - "[target_type] should be [classification] if [classification_labels] is provided, and vice versa"); + "[target_type] should be [classification] if [classification_labels] are provided"); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 46373dae834c9..ba2926e5050b3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -8,10 +8,8 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; @@ -28,7 +26,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -77,16 +74,24 @@ public static Ensemble createRandom() { OutputAggregator outputAggregator = randomFrom(new WeightedMode(weights), new WeightedSum(weights), new LogisticRegression(weights)); + TargetType targetType = randomFrom(TargetType.values()); List categoryLabels = null; - if (randomBoolean()) { + if (randomBoolean() && targetType == TargetType.CLASSIFICATION) { categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); } + double[] thresholds = randomBoolean() && targetType == TargetType.CLASSIFICATION ? + Stream.generate(ESTestCase::randomDouble) + .limit(categoryLabels == null ? randomIntBetween(1, 10) : categoryLabels.size()) + .mapToDouble(Double::valueOf) + .toArray() : + null; return new Ensemble(featureNames, models, outputAggregator, - randomFrom(TargetType.values()), - categoryLabels); + targetType, + categoryLabels, + thresholds); } @Override @@ -101,17 +106,12 @@ protected Writeable.Reader instanceReader() { @Override protected NamedXContentRegistry xContentRegistry() { - List namedXContent = new ArrayList<>(); - namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); - namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); - return new NamedXContentRegistry(namedXContent); + return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); } @Override protected NamedWriteableRegistry getNamedWriteableRegistry() { - List entries = new ArrayList<>(); - entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); - return new NamedWriteableRegistry(entries); + return new NamedWriteableRegistry(new MlInferenceNamedXContentProvider().getNamedWriteables()); } public void testEnsembleWithAggregatedOutputDifferingFromTrainedModels() { @@ -184,16 +184,15 @@ public void testEnsembleWithAggregatorOutputNotSupportingTargetType() { public void testEnsembleWithTargetTypeAndLabelsMismatch() { List featureNames = Arrays.asList("foo", "bar"); - String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa"; + String msg = "[target_type] should be [classification] if " + + "[classification_labels] or [classification_weights] are provided"; ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { Ensemble.builder() .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList( Tree.builder() .setNodes(TreeNode.builder(0) - .setLeftChild(1) - .setSplitFeature(1) - .setThreshold(randomDouble())) + .setLeafValue(randomDouble())) .setFeatureNames(featureNames) .build())) .setClassificationLabels(Arrays.asList("label1", "label2")) @@ -201,23 +200,6 @@ public void testEnsembleWithTargetTypeAndLabelsMismatch() { .validate(); }); assertThat(ex.getMessage(), equalTo(msg)); - ex = expectThrows(ElasticsearchException.class, () -> { - Ensemble.builder() - .setFeatureNames(featureNames) - .setTrainedModels(Arrays.asList( - Tree.builder() - .setNodes(TreeNode.builder(0) - .setLeftChild(1) - .setSplitFeature(1) - .setThreshold(randomDouble())) - .setFeatureNames(featureNames) - .build())) - .setTargetType(TargetType.CLASSIFICATION) - .setOutputAggregator(new WeightedMode()) - .build() - .validate(); - }); - assertThat(ex.getMessage(), equalTo(msg)); } public void testClassificationProbability() { @@ -262,34 +244,41 @@ public void testClassificationProbability() { .setFeatureNames(featureNames) .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) .setOutputAggregator(new WeightedMode(new double[]{0.7, 0.5, 1.0})) + .setClassificationWeights(Arrays.asList(0.7, 0.3)) .build(); List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); List expected = Arrays.asList(0.768524783, 0.231475216); + List scores = Arrays.asList(0.230557435, 0.162032651); double eps = 0.000001; List probabilities = ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps)); } featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - expected = Arrays.asList(0.689974481, 0.3100255188); + expected = Arrays.asList(0.310025518, 0.6899744811); + scores = Arrays.asList(0.217017863, 0.2069923443); probabilities = ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps)); } featureVector = Arrays.asList(0.0, 1.0); featureMap = zipObjMap(featureNames, featureVector); expected = Arrays.asList(0.768524783, 0.231475216); + scores = Arrays.asList(0.230557435, 0.162032651); probabilities = ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps)); } // This should handle missing values and take the default_left path @@ -298,10 +287,12 @@ public void testClassificationProbability() { put("bar", null); }}; expected = Arrays.asList(0.6899744811, 0.3100255188); + scores = Arrays.asList(0.482982136, 0.0930076556); probabilities = ((ClassificationInferenceResults)ensemble.infer(featureMap, new ClassificationConfig(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getScore(), closeTo(scores.get(i), eps)); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index 8c05c8d7b9d3a..123a298b1d3a9 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -93,14 +93,13 @@ public static Tree buildRandomTree(List featureNames, int depth) { } childNodes = nextNodes; } + TargetType targetType = randomFrom(TargetType.values()); List categoryLabels = null; - if (randomBoolean()) { + if (randomBoolean() && targetType == TargetType.CLASSIFICATION) { categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false)); } - return builder.setTargetType(randomFrom(TargetType.values())) - .setClassificationLabels(categoryLabels) - .build(); + return builder.setTargetType(targetType).setClassificationLabels(categoryLabels).build(); } @Override @@ -325,7 +324,7 @@ public void testTreeWithCycle() { public void testTreeWithTargetTypeAndLabelsMismatch() { List featureNames = Arrays.asList("foo", "bar"); - String msg = "[target_type] should be [classification] if [classification_labels] is provided, and vice versa"; + String msg = "[target_type] should be [classification] if [classification_labels] are provided"; ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> { Tree.builder() .setRoot(TreeNode.builder(0) @@ -338,18 +337,6 @@ public void testTreeWithTargetTypeAndLabelsMismatch() { .validate(); }); assertThat(ex.getMessage(), equalTo(msg)); - ex = expectThrows(ElasticsearchException.class, () -> { - Tree.builder() - .setRoot(TreeNode.builder(0) - .setLeftChild(1) - .setSplitFeature(1) - .setThreshold(randomDouble())) - .setFeatureNames(featureNames) - .setTargetType(TargetType.CLASSIFICATION) - .build() - .validate(); - }); - assertThat(ex.getMessage(), equalTo(msg)); } public void testOperationsEstimations() { diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index 216eac723115f..9320d393a4f66 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -470,7 +470,7 @@ private Map generateSourceDoc() { " },\n" + " {\n" + " \"node_index\": 2,\n" + - " \"leaf_value\": 2\n" + + " \"leaf_value\": 0\n" + " }\n" + " ],\n" + " \"target_type\": \"regression\"\n" + @@ -500,7 +500,7 @@ private Map generateSourceDoc() { " },\n" + " {\n" + " \"node_index\": 2,\n" + - " \"leaf_value\": 2\n" + + " \"leaf_value\": 0\n" + " }\n" + " ],\n" + " \"target_type\": \"regression\"\n" + diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index faa1717cc4fde..512baa7624a68 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -56,7 +56,7 @@ public void testClassificationInfer() throws Exception { SingleValueInferenceResults result = getSingleValue(model, fields, new ClassificationConfig(0)); assertThat(result.value(), equalTo(0.0)); - assertThat(result.valueAsString(), is("0.0")); + assertThat(result.valueAsString(), is("0")); ClassificationInferenceResults classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new ClassificationConfig(1));