Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
Expand Down Expand Up @@ -471,6 +474,9 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
new NamedWriteableRegistry.Entry(InferenceResults.class,
RegressionInferenceResults.NAME,
RegressionInferenceResults::new),
// ML - Inference Configuration
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new),
new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new),

// monitoring
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
Expand All @@ -37,32 +38,32 @@ public static class Request extends ActionRequest {
private final String modelId;
private final long modelVersion;
private final List<Map<String, Object>> objectsToInfer;
private final InferenceParams params;
private final InferenceConfig config;

public Request(String modelId, long modelVersion) {
this(modelId, modelVersion, Collections.emptyList(), InferenceParams.EMPTY_PARAMS);
this(modelId, modelVersion, Collections.emptyList(), new RegressionConfig());
}

public Request(String modelId, long modelVersion, List<Map<String, Object>> objectsToInfer, InferenceParams inferenceParams) {
public Request(String modelId, long modelVersion, List<Map<String, Object>> objectsToInfer, InferenceConfig inferenceConfig) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID);
this.modelVersion = modelVersion;
this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer"));
this.params = inferenceParams == null ? InferenceParams.EMPTY_PARAMS : inferenceParams;
this.config = ExceptionsHelper.requireNonNull(inferenceConfig, "inference_config");
}

public Request(String modelId, long modelVersion, Map<String, Object> objectToInfer, InferenceParams params) {
public Request(String modelId, long modelVersion, Map<String, Object> objectToInfer, InferenceConfig config) {
this(modelId,
modelVersion,
Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")),
params);
config);
}

public Request(StreamInput in) throws IOException {
super(in);
this.modelId = in.readString();
this.modelVersion = in.readVLong();
this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap));
this.params = new InferenceParams(in);
this.config = in.readNamedWriteable(InferenceConfig.class);
}

public String getModelId() {
Expand All @@ -77,8 +78,8 @@ public List<Map<String, Object>> getObjectsToInfer() {
return objectsToInfer;
}

public InferenceParams getParams() {
return params;
public InferenceConfig getConfig() {
return config;
}

@Override
Expand All @@ -92,7 +93,7 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
out.writeVLong(modelVersion);
out.writeCollection(objectsToInfer, StreamOutput::writeMap);
params.writeTo(out);
out.writeNamedWriteable(config);
}

@Override
Expand All @@ -102,13 +103,13 @@ public boolean equals(Object o) {
InferModelAction.Request that = (InferModelAction.Request) o;
return Objects.equals(modelId, that.modelId)
&& Objects.equals(modelVersion, that.modelVersion)
&& Objects.equals(params, that.params)
&& Objects.equals(config, that.config)
&& Objects.equals(objectsToInfer, that.objectsToInfer);
}

@Override
public int hashCode() {
return Objects.hash(modelId, modelVersion, objectsToInfer, params);
return Objects.hash(modelId, modelVersion, objectsToInfer, config);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
Expand Down Expand Up @@ -111,6 +114,10 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
RegressionInferenceResults.NAME,
RegressionInferenceResults::new));

// Inference Configs
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new));
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new));

return namedWriteables;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
Expand Down Expand Up @@ -125,9 +125,9 @@ private void preProcess(Map<String, Object> fields) {
preProcessors.forEach(preProcessor -> preProcessor.process(fields));
}

public InferenceResults infer(Map<String, Object> fields, InferenceParams params) {
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
preProcess(fields);
return trainedModel.infer(fields, params);
return trainedModel.infer(fields, config);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,26 @@
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

public class InferenceParams implements ToXContentObject, Writeable {
public class ClassificationConfig implements InferenceConfig {

public static ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
public static final String NAME = "classification";

public static InferenceParams EMPTY_PARAMS = new InferenceParams(0);
public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");

public static ClassificationConfig EMPTY_PARAMS = new ClassificationConfig(0);

private final int numTopClasses;

public InferenceParams(Integer numTopClasses) {
public ClassificationConfig(Integer numTopClasses) {
this.numTopClasses = numTopClasses == null ? 0 : numTopClasses;
}

public InferenceParams(StreamInput in) throws IOException {
public ClassificationConfig(StreamInput in) throws IOException {
this.numTopClasses = in.readInt();
}

Expand All @@ -44,7 +44,7 @@ public void writeTo(StreamOutput out) throws IOException {
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InferenceParams that = (InferenceParams) o;
ClassificationConfig that = (ClassificationConfig) o;
return Objects.equals(numTopClasses, that.numTopClasses);
}

Expand All @@ -62,4 +62,20 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.endObject();
return builder;
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public String getName() {
return NAME;
}

@Override
public boolean isTargetTypeSupported(TargetType targetType) {
return TargetType.CLASSIFICATION.equals(targetType);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.io.stream.NamedWriteable;
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;


public interface InferenceConfig extends NamedXContentObject, NamedWriteable {

boolean isTargetTypeSupported(TargetType targetType);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

public class RegressionConfig implements InferenceConfig {

public static final String NAME = "regression";

public RegressionConfig() {
}

public RegressionConfig(StreamInput in) {
}

@Override
public String getWriteableName() {
return NAME;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
}

@Override
public String getName() {
return NAME;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

return true;
}

@Override
public int hashCode() {
return Objects.hash(NAME);
}

@Override
public boolean isTargetTypeSupported(TargetType targetType) {
return TargetType.REGRESSION.equals(targetType);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable {
* Infer against the provided fields
*
* @param fields The fields and their values to infer against
* @param config The configuration options for inference
* @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0).
* For regression this is continuous.
*/
InferenceResults infer(Map<String, Object> fields, InferenceParams params);
InferenceResults infer(Map<String, Object> fields, InferenceConfig config);

/**
* @return {@link TargetType} for the model.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
Expand Down Expand Up @@ -114,35 +115,35 @@ public List<String> getFeatureNames() {
}

@Override
public InferenceResults infer(Map<String, Object> fields, InferenceParams params) {
if (params.getNumTopClasses() != 0 &&
(targetType != TargetType.CLASSIFICATION || outputAggregator.providesProbabilities() == false)) {
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
if (config.isTargetTypeSupported(targetType) == false) {
throw ExceptionsHelper.badRequestException(
"Cannot return top classes for target_type [{}] and aggregate_output [{}]",
targetType,
outputAggregator.getName());
"Cannot infer using configuration for [{}] when model target_type is [{}]", config.getName(), targetType.toString());
}
List<Double> inferenceResults = this.models.stream().map(model -> {
InferenceResults results = model.infer(fields, InferenceParams.EMPTY_PARAMS);
InferenceResults results = model.infer(fields, NullInferenceConfig.INSTANCE);
assert results instanceof SingleValueInferenceResults;
return ((SingleValueInferenceResults)results).value();
}).collect(Collectors.toList());
List<Double> processed = outputAggregator.processValues(inferenceResults);
return buildResults(processed, params);
return buildResults(processed, config);
}

@Override
public TargetType targetType() {
return targetType;
}

private InferenceResults buildResults(List<Double> processedInferences, InferenceParams params) {
private InferenceResults buildResults(List<Double> processedInferences, InferenceConfig config) {
switch(targetType) {
case REGRESSION:
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences));
case CLASSIFICATION:
List<ClassificationInferenceResults.TopClassEntry> topClasses =
InferenceHelpers.topClasses(processedInferences, classificationLabels, params.getNumTopClasses());
ClassificationConfig classificationConfig = (ClassificationConfig) config;
List<ClassificationInferenceResults.TopClassEntry> topClasses = InferenceHelpers.topClasses(
processedInferences,
classificationLabels,
classificationConfig.getNumTopClasses());
double value = outputAggregator.aggregate(processedInferences);
return new ClassificationInferenceResults(outputAggregator.aggregate(processedInferences),
classificationLabel(value, classificationLabels),
Expand Down Expand Up @@ -216,6 +217,13 @@ public int hashCode() {

@Override
public void validate() {
if (outputAggregator.compatibleWith(targetType) == false) {
throw ExceptionsHelper.badRequestException(
"aggregate_output [{}] is not compatible with target_type [{}]",
this.targetType,
outputAggregator.getName()
);
}
if (outputAggregator.expectedValueSize() != null &&
outputAggregator.expectedValueSize() != models.size()) {
throw ExceptionsHelper.badRequestException(
Expand Down
Loading