From dadfec15083f8e410f95a606eb7546ecfad5eb3a Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 30 Sep 2019 17:02:42 -0400 Subject: [PATCH 01/13] [ML][Inference] adds lazy model loader and inference --- .../trainedmodel/ensemble/Ensemble.java | 4 +- .../xpack/ml/MachineLearning.java | 11 +- .../ml/action/TransportInferModelAction.java | 71 +++++ .../ml/inference/action/InferModelAction.java | 168 ++++++++++ .../inference/loadingservice/LocalModel.java | 92 ++++++ .../ml/inference/loadingservice/Model.java | 17 + .../loadingservice/ModelLoadingService.java | 115 +++++++ .../action/InferModelActionRequestTests.java | 33 ++ .../action/InferModelActionResponseTests.java | 35 +++ .../loadingservice/LocalModelTests.java | 207 +++++++++++++ .../ModelLoadingServiceTests.java | 104 +++++++ .../integration/ModelInferenceActionIT.java | 292 ++++++++++++++++++ 12 files changed, 1146 insertions(+), 3 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionRequestTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionResponseTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 7f2a7cc9a02ce..add3a8d7c151c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -107,7 +107,7 @@ public List getFeatureNames() { @Override public double infer(Map fields) { - List features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()); + List features = featureNames.stream().map(f -> ((Number)fields.get(f)).doubleValue()).collect(Collectors.toList()); return infer(features); } @@ -128,7 +128,7 @@ public List classificationProbability(Map fields) { throw new UnsupportedOperationException( "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); } - List features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()); + List features = featureNames.stream().map(f -> ((Number)fields.get(f)).doubleValue()).collect(Collectors.toList()); return classificationProbability(features); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 9d1fdc762ea80..4296e46e232a7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -161,6 +161,7 @@ import org.elasticsearch.xpack.ml.action.TransportGetModelSnapshotsAction; import org.elasticsearch.xpack.ml.action.TransportGetOverallBucketsAction; import org.elasticsearch.xpack.ml.action.TransportGetRecordsAction; +import org.elasticsearch.xpack.ml.action.TransportInferModelAction; import org.elasticsearch.xpack.ml.action.TransportIsolateDatafeedAction; import org.elasticsearch.xpack.ml.action.TransportKillProcessAction; import org.elasticsearch.xpack.ml.action.TransportMlInfoAction; @@ -200,7 +201,10 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; +import org.elasticsearch.xpack.ml.inference.action.InferModelAction; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.job.JobManager; import org.elasticsearch.xpack.ml.job.JobManagerHolder; import org.elasticsearch.xpack.ml.job.UpdateJobProcessNotifier; @@ -495,6 +499,8 @@ public Collection createComponents(Client client, ClusterService cluster notifier, xContentRegistry); + final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry); + final ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider); // special holder for @link(MachineLearningFeatureSetUsage) which needs access to job manager if ML is enabled JobManagerHolder jobManagerHolder = new JobManagerHolder(jobManager); @@ -607,7 +613,9 @@ public Collection createComponents(Client client, ClusterService cluster analyticsProcessManager, memoryEstimationProcessManager, dataFrameAnalyticsConfigProvider, - nativeStorageProvider + nativeStorageProvider, + modelLoadingService, + trainedModelProvider ); } @@ -762,6 +770,7 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(StopDataFrameAnalyticsAction.INSTANCE, TransportStopDataFrameAnalyticsAction.class), new ActionHandler<>(EvaluateDataFrameAction.INSTANCE, TransportEvaluateDataFrameAction.class), new ActionHandler<>(EstimateMemoryUsageAction.INSTANCE, TransportEstimateMemoryUsageAction.class), + new ActionHandler<>(InferModelAction.INSTANCE, TransportInferModelAction.class), usageAction, infoAction); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java new file mode 100644 index 0000000000000..b5d33758ce260 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.ml.inference.action.InferModelAction; +import org.elasticsearch.xpack.ml.inference.loadingservice.Model; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; +import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor; + +import java.util.List; + +public class TransportInferModelAction extends HandledTransportAction { + + private final ModelLoadingService modelLoadingService; + private final Client client; + + @Inject + public TransportInferModelAction(String actionName, + TransportService transportService, + ActionFilters actionFilters, + ModelLoadingService modelLoadingService, + Client client) { + super(actionName, transportService, actionFilters, InferModelAction.Request::new); + this.modelLoadingService = modelLoadingService; + this.client = client; + } + + @Override + protected void doExecute(Task task, InferModelAction.Request request, ActionListener listener) { + + ActionListener> inferenceCompleteListener = ActionListener.wrap( + inferenceResponse -> listener.onResponse(new InferModelAction.Response(inferenceResponse)), + listener::onFailure + ); + + ActionListener getModelListener = ActionListener.wrap( + model -> { + TypedChainTaskExecutor typedChainTaskExecutor = + new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME), + // run through all tasks + r -> true, + // Always fail immediately and return an error + ex -> true); + if (request.getTopClasses() != null) { + request.getObjectsToInfer().forEach(stringObjectMap -> + typedChainTaskExecutor.add(chainedTask -> model.confidence(stringObjectMap, request.getTopClasses(), chainedTask)) + ); + } else { + request.getObjectsToInfer().forEach(stringObjectMap -> + typedChainTaskExecutor.add(chainedTask -> model.infer(stringObjectMap, chainedTask)) + ); + } + typedChainTaskExecutor.execute(inferenceCompleteListener); + }, + listener::onFailure + ); + + this.modelLoadingService.getModelAndCache(request.getModelId(), request.getModelVersion(), getModelListener); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java new file mode 100644 index 0000000000000..134243343f32e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java @@ -0,0 +1,168 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.action; + +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionRequestBuilder; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.client.ElasticsearchClient; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class InferModelAction extends ActionType { + + public static final InferModelAction INSTANCE = new InferModelAction(); + public static final String NAME = "cluster:admin/xpack/ml/infer"; + + private InferModelAction() { + super(NAME, Response::new); + } + + public static class Request extends ActionRequest { + + private final String modelId; + private final long modelVersion; + private final List> objectsToInfer; + private final boolean cacheModel; + private final Integer topClasses; + + public Request(String modelId, long modelVersion) { + this(modelId, modelVersion, Collections.emptyList(), null); + } + + public Request(String modelId, long modelVersion, List> objectsToInfer, Integer topClasses) { + this.modelId = modelId; + this.modelVersion = modelVersion; + this.objectsToInfer = objectsToInfer == null ? Collections.emptyList() : + Collections.unmodifiableList(new ArrayList<>(objectsToInfer)); + this.cacheModel = true; + this.topClasses = topClasses; + } + + public Request(String modelId, long modelVersion, Map objectToInfer, Integer topClasses) { + this(modelId, + modelVersion, + objectToInfer == null ? Collections.emptyList() : Collections.singletonList(objectToInfer), + topClasses); + } + + public Request(StreamInput in) throws IOException { + super(in); + this.modelId = in.readString(); + this.modelVersion = in.readVLong(); + this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap)); + this.topClasses = in.readOptionalInt(); + this.cacheModel = in.readBoolean(); + } + + public String getModelId() { + return modelId; + } + + public long getModelVersion() { + return modelVersion; + } + + public List> getObjectsToInfer() { + return objectsToInfer; + } + + public boolean isCacheModel() { + return cacheModel; + } + + public Integer getTopClasses() { + return topClasses; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelId); + out.writeVLong(modelVersion); + out.writeCollection(objectsToInfer, StreamOutput::writeMap); + out.writeOptionalInt(topClasses); + out.writeBoolean(cacheModel); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferModelAction.Request that = (InferModelAction.Request) o; + return Objects.equals(modelId, that.modelId) + && Objects.equals(modelVersion, that.modelVersion) + && Objects.equals(topClasses, that.topClasses) + && Objects.equals(cacheModel, that.cacheModel) + && Objects.equals(objectsToInfer, that.objectsToInfer); + } + + @Override + public int hashCode() { + return Objects.hash(modelId, modelVersion, objectsToInfer, topClasses, cacheModel); + } + + } + + public static class RequestBuilder extends ActionRequestBuilder { + public RequestBuilder(ElasticsearchClient client, Request request) { + super(client, INSTANCE, request); + } + } + + public static class Response extends ActionResponse { + + // TODO come up with a better union type object + private final List inferenceResponse; + + public Response(List inferenceResponse) { + super(); + this.inferenceResponse = Collections.unmodifiableList(inferenceResponse); + } + + public Response(StreamInput in) throws IOException { + super(in); + this.inferenceResponse = Collections.unmodifiableList(in.readList(StreamInput::readGenericValue)); + } + + public List getInferenceResponse() { + return inferenceResponse; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(inferenceResponse, StreamOutput::writeGenericValue); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferModelAction.Response that = (InferModelAction.Response) o; + return Objects.equals(inferenceResponse, that.inferenceResponse); + } + + @Override + public int hashCode() { + return Objects.hash(inferenceResponse); + } + + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java new file mode 100644 index 0000000000000..7ef8ba121dab1 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -0,0 +1,92 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class LocalModel implements Model { + + private final TrainedModelDefinition trainedModelDefinition; + public LocalModel(TrainedModelDefinition trainedModelDefinition) { + this.trainedModelDefinition = trainedModelDefinition; + } + + @Override + public void infer(Map fields, ActionListener listener) { + trainedModelDefinition.getPreProcessors().forEach(preProcessor -> preProcessor.process(fields)); + double value = trainedModelDefinition.getTrainedModel().infer(fields); + if (trainedModelDefinition.getTrainedModel().targetType() == TargetType.CLASSIFICATION && + trainedModelDefinition.getTrainedModel().classificationLabels() != null) { + assert value == Math.rint(value); + int classIndex = Double.valueOf(value).intValue(); + if (classIndex < 0 || classIndex >= trainedModelDefinition.getTrainedModel().classificationLabels().size()) { + listener.onFailure(new ElasticsearchStatusException("model returned classification [{}] which is invalid given labels {}", + RestStatus.INTERNAL_SERVER_ERROR, + classIndex, + trainedModelDefinition.getTrainedModel().classificationLabels())); + return; + } + listener.onResponse(trainedModelDefinition.getTrainedModel().classificationLabels().get(classIndex)); + return; + } + listener.onResponse(Double.valueOf(value)); + } + + @Override + public void confidence(Map fields, int topN, ActionListener listener) { + if (topN == 0) { + listener.onResponse(Collections.emptyMap()); + return; + } + if (trainedModelDefinition.getTrainedModel().targetType() != TargetType.CLASSIFICATION) { + listener.onFailure(ExceptionsHelper + .badRequestException("top result probabilities is only available for classification models")); + return; + } + trainedModelDefinition.getPreProcessors().forEach(preProcessor -> preProcessor.process(fields)); + List probabilities = trainedModelDefinition.getTrainedModel().classificationProbability(fields); + int[] sortedIndices = IntStream.range(0, probabilities.size()) + .boxed() + .sorted(Comparator.comparing(probabilities::get).reversed()) + .mapToInt(i -> i) + .toArray(); + if (trainedModelDefinition.getTrainedModel().classificationLabels() != null) { + if (probabilities.size() != trainedModelDefinition.getTrainedModel().classificationLabels().size()) { + listener.onFailure(ExceptionsHelper + .badRequestException( + "model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]", + probabilities.size(), + trainedModelDefinition.getTrainedModel().classificationLabels())); + return; + } + } + List labels = trainedModelDefinition.getTrainedModel().classificationLabels() == null ? + // If we don't have the labels we should return the top classification values anyways, they will just be numeric + IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) : + trainedModelDefinition.getTrainedModel().classificationLabels(); + + int count = topN < 0 ? probabilities.size() : topN; + Map probabilityMap = new HashMap<>(count); + for(int i = 0; i < count; i++) { + int idx = sortedIndices[i]; + probabilityMap.put(labels.get(idx), probabilities.get(idx)); + } + listener.onResponse(probabilityMap); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java new file mode 100644 index 0000000000000..ea6a8022eacd4 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java @@ -0,0 +1,17 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.elasticsearch.action.ActionListener; + +import java.util.Map; + +public interface Model { + + void infer(Map fields, ActionListener listener); + + void confidence(Map fields, int topN, ActionListener listener); +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java new file mode 100644 index 0000000000000..aae17c90e21fb --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -0,0 +1,115 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; + +public class ModelLoadingService implements ClusterStateListener { + + private static final Logger logger = LogManager.getLogger(ModelLoadingService.class); + private final ConcurrentHashMap loadedModels = new ConcurrentHashMap<>(); + private final ConcurrentHashMap>> loadingListeners = new ConcurrentHashMap<>(); + private final TrainedModelProvider provider; + + public ModelLoadingService(TrainedModelProvider trainedModelProvider) { + this.provider = trainedModelProvider; + } + + public void getModelAndCache(String modelId, long modelVersion, ActionListener modelActionListener) { + String key = modelKey(modelId, modelVersion); + Model cachedModel = loadedModels.get(key); + if (cachedModel != null) { + modelActionListener.onResponse(cachedModel); + return; + } + SetOnce newLoad = new SetOnce<>(); + synchronized (loadingListeners) { + cachedModel = loadedModels.get(key); + if (cachedModel != null) { + modelActionListener.onResponse(cachedModel); + return; + } + loadingListeners.compute(key, (modelKey, listeners) -> { + if (listeners == null) { + newLoad.set(true); + Deque> listenerDeque = new ArrayDeque<>(); + listenerDeque.addLast(modelActionListener); + return listenerDeque; + } + newLoad.set(false); + listeners.add(modelActionListener); + return listeners; + }); + } + if (newLoad.get()) { + // TODO support loading other types of models? + loadModel(key, modelId, modelVersion); + } + } + + private void loadModel(String modelKey, String modelId, long modelVersion) { + provider.getTrainedModel(modelId, modelVersion, ActionListener.wrap( + trainedModelConfig -> { + logger.debug("[{}] successfully loaded model", modelKey); + handleLoadSuccess(modelKey, trainedModelConfig); + }, + failure -> { + logger.warn(new ParameterizedMessage("[{}] failed to load model", modelKey), failure); + handleLoadFailure(modelKey, failure); + } + )); + } + + private void handleLoadSuccess(String modelKey, TrainedModelConfig trainedModelConfig) { + Queue> listeners; + Model loadedModel = new LocalModel(trainedModelConfig.getDefinition()); + synchronized (loadingListeners) { + loadedModels.put(modelKey, loadedModel); + listeners = loadingListeners.remove(modelKey); + } + if (listeners != null) { + for(ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + listener.onResponse(loadedModel); + } + } + } + + private void handleLoadFailure(String modelKey, Exception failure) { + Queue> listeners; + synchronized (loadingListeners) { + // TODO do we want to cache the failure? + listeners = loadingListeners.remove(modelKey); + } + if (listeners != null) { + for(ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + listener.onFailure(failure); + } + } + } + + private String modelKey(String modelId, long modelVersion) { + return modelId + "_" + modelVersion; + } + + @Override + public void clusterChanged(ClusterChangedEvent event) { + // TODO + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionRequestTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionRequestTests.java new file mode 100644 index 0000000000000..e29cc63e30bce --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionRequestTests.java @@ -0,0 +1,33 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.ml.inference.action.InferModelAction.Request; + +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class InferModelActionRequestTests extends AbstractWireSerializingTestCase { + + @Override + protected Request createTestInstance() { + return new Request(randomAlphaOfLength(10), + randomLongBetween(1, 100), + randomBoolean() ? null : Stream.generate(()-> randomAlphaOfLength(10)) + .limit(randomInt(10)) + .collect(Collectors.toMap(Function.identity(), (v) -> randomAlphaOfLength(10))), + randomBoolean() ? null : randomIntBetween(-1, 100)); + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionResponseTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionResponseTests.java new file mode 100644 index 0000000000000..06bf1ab983b2d --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionResponseTests.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.ml.inference.action.InferModelAction.Response; + +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class InferModelActionResponseTests extends AbstractWireSerializingTestCase { + + @Override + @SuppressWarnings("unchecked") + protected Response createTestInstance() { + Supplier resultSupplier = randomFrom(() -> randomAlphaOfLength(10), + ESTestCase::randomDouble, + () -> Stream.generate(() -> randomAlphaOfLength(10)) + .limit(randomIntBetween(1, 10)) + .collect(Collectors.toMap(Function.identity(), v -> randomDouble()))); + return new Response(Stream.generate(resultSupplier).limit(randomIntBetween(0, 10)).collect(Collectors.toList())); + } + + @Override + protected Writeable.Reader instanceReader() { + return Response::new; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java new file mode 100644 index 0000000000000..d9c91afa34b9d --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -0,0 +1,207 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; + +public class LocalModelTests extends ESTestCase { + + @SuppressWarnings("unchecked") + public void testClassificationInfer() throws Exception { + TrainedModelDefinition definition = new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) + .setTrainedModel(buildClassification(false)) + .build(); + + Model model = new LocalModel(definition); + Map fields = new HashMap<>() {{ + put("foo", 1.0); + put("bar", 0.5); + put("categorical", "dog"); + }}; + + PlainActionFuture future = new PlainActionFuture<>(); + model.infer(fields, future); + assertThat(future.get(), equalTo(0.0)); + + future = new PlainActionFuture<>(); + model.confidence(fields, 0, future); + assertThat(future.get(), equalTo(Collections.emptyMap())); + + future = new PlainActionFuture<>(); + model.confidence(fields, 1, future); + assertThat(((Map)future.get()).get("0"), closeTo(0.5498339973124778, 0.0000001)); + + // Test with labels + definition = new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) + .setTrainedModel(buildClassification(true)) + .build(); + model = new LocalModel(definition); + future = new PlainActionFuture<>(); + model.infer(fields, future); + assertThat(future.get(), equalTo("not_to_be")); + + future = new PlainActionFuture<>(); + model.confidence(fields, 0, future); + assertThat(future.get(), equalTo(Collections.emptyMap())); + + future = new PlainActionFuture<>(); + model.confidence(fields, 1, future); + assertThat(((Map)future.get()).get("not_to_be"), closeTo(0.5498339973124778, 0.0000001)); + + future = new PlainActionFuture<>(); + model.confidence(fields, 2, future); + assertThat((Map)future.get(), aMapWithSize(2)); + + future = new PlainActionFuture<>(); + model.confidence(fields, -1, future); + assertThat((Map)future.get(), aMapWithSize(2)); + } + + public void testRegression() throws Exception { + TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) + .setTrainedModel(buildRegression()) + .build(); + Model model = new LocalModel(trainedModelDefinition); + + Map fields = new HashMap<>() {{ + put("foo", 1.0); + put("bar", 0.5); + put("categorical", "dog"); + }}; + + PlainActionFuture future = new PlainActionFuture<>(); + model.infer(fields, future); + assertThat(future.get(), equalTo(1.3)); + + PlainActionFuture failedFuture = new PlainActionFuture<>(); + model.confidence(fields, -1, failedFuture); + ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get); + assertThat(ex.getCause().getMessage(), equalTo("top result probabilities is only available for classification models")); + } + + private static Map oneHotMap() { + Map oneHotEncoding = new HashMap<>(); + oneHotEncoding.put("cat", "animal_cat"); + oneHotEncoding.put("dog", "animal_dog"); + return oneHotEncoding; + } + + private static TrainedModel buildClassification(boolean includeLabels) { + List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2) + .setThreshold(0.8) + .setSplitFeature(1) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.0)) + .addNode(TreeNode.builder(4).setLeafValue(1.0)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(3) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(0.0)) + .addNode(TreeNode.builder(2).setLeafValue(1.0)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2).setLeafValue(0.0)) + .build(); + return Ensemble.builder() + .setClassificationLabels(includeLabels ? Arrays.asList("not_to_be", "to_be") : null) + .setTargetType(TargetType.CLASSIFICATION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0))) + .build(); + } + + private static TrainedModel buildRegression() { + List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(0.3)) + .addNode(TreeNode.builder(2) + .setThreshold(0.0) + .setSplitFeature(3) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.1)) + .addNode(TreeNode.builder(4).setLeafValue(0.2)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(2) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(1.5)) + .addNode(TreeNode.builder(2).setLeafValue(0.9)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(1) + .setThreshold(0.2)) + .addNode(TreeNode.builder(1).setLeafValue(1.5)) + .addNode(TreeNode.builder(2).setLeafValue(0.9)) + .build(); + return Ensemble.builder() + .setTargetType(TargetType.REGRESSION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5, 0.5))) + .build(); + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java new file mode 100644 index 0000000000000..c9779abff8b2b --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -0,0 +1,104 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.loadingservice; + +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.junit.Before; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class ModelLoadingServiceTests extends ESTestCase { + + TrainedModelProvider trainedModelProvider; + + @Before + public void setUpComponents() { + trainedModelProvider = mock(TrainedModelProvider.class); + } + + public void testGetModelAndCache() throws Exception { + String model1 = "test-load-model-1"; + String model2 = "test-load-model-2"; + String model3 = "test-load-model-3"; + withTrainedModel(model1, 0); + withTrainedModel(model2, 0); + withTrainedModel(model3, 0); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider); + + String[] modelIds = new String[]{model1, model2, model3}; + for(int i = 0; i < 10; i++) { + String model = modelIds[i%3]; + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModelAndCache(model, 0, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model1), eq(0L), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model2), eq(0L), any()); + verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(0L), any()); + } + + public void testGetMissingModelAndCache() { + String model = "test-load-missing-model"; + withMissingModel(model, 0); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider); + + modelLoadingService.getModelAndCache(model, 0, ActionListener.wrap( + m -> fail("Should not have succeeded"), + f -> assertThat(f.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model, 0))) + )); + } + + @SuppressWarnings("unchecked") + private void withTrainedModel(String modelId, long modelVersion) { + TrainedModelConfig trainedModelConfig = buildTrainedModelConfigBuilder(modelId, modelVersion).build(Version.CURRENT); + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(trainedModelConfig); + return null; + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(modelVersion), any()); + } + + private void withMissingModel(String modelId, long modelVersion) { + doAnswer(invocationOnMock -> { + @SuppressWarnings("rawtypes") + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId, modelVersion))); + return null; + }).when(trainedModelProvider).getTrainedModel(eq(modelId), eq(modelVersion), any()); + } + + private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId, long modelVersion) { + return TrainedModelConfig.builder() + .setCreatedBy("ml_test") + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .setDescription("trained model config for test") + .setModelId(modelId) + .setModelType("binary_decision_tree") + .setModelVersion(modelVersion); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java new file mode 100644 index 0000000000000..3c1c3915a8e11 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -0,0 +1,292 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.search.SearchModule; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; +import org.elasticsearch.xpack.ml.inference.action.InferModelAction; +import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.junit.Before; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.nullValue; + +public class ModelInferenceActionIT extends MlSingleNodeTestCase { + + private TrainedModelProvider trainedModelProvider; + + @Before + public void createComponents() throws Exception { + trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry()); + waitForMlTemplates(); + } + + @SuppressWarnings("unchecked") + public void testInferModels() throws Exception { + String modelId1 = "test-load-models-regression"; + String modelId2 = "test-load-models-classification"; + Map oneHotEncoding = new HashMap<>(); + oneHotEncoding.put("cat", "animal_cat"); + oneHotEncoding.put("dog", "animal_dog"); + TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2, 0) + .setDefinition(new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) + .setTrainedModel(buildClassification())) + .build(Version.CURRENT); + TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1, 0) + .setDefinition(new TrainedModelDefinition.Builder() + .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) + .setTrainedModel(buildRegression())) + .build(Version.CURRENT); + AtomicReference putConfigHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config1, listener), putConfigHolder, exceptionHolder); + assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config2, listener), putConfigHolder, exceptionHolder); + assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + + + List> toInfer = new ArrayList<>(); + toInfer.add(new HashMap<>() {{ + put("foo", 1.0); + put("bar", 0.5); + put("categorical", "dog"); + }}); + toInfer.add(new HashMap<>() {{ + put("foo", 0.9); + put("bar", 1.5); + put("categorical", "cat"); + }}); + + List> toInfer2 = new ArrayList<>(); + toInfer2.add(new HashMap<>() {{ + put("foo", 0.0); + put("bar", 0.01); + put("categorical", "dog"); + }}); + toInfer2.add(new HashMap<>() {{ + put("foo", 1.0); + put("bar", 0.0); + put("categorical", "cat"); + }}); + + // Test regression + InferModelAction.Request request = new InferModelAction.Request(modelId1, 0, toInfer, null); + InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + assertThat(response.getInferenceResponse(), contains(1.3, 1.25)); + + request = new InferModelAction.Request(modelId1, 0, toInfer2, null); + response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + assertThat(response.getInferenceResponse(), contains(1.65, 1.55)); + + + // Test classification + request = new InferModelAction.Request(modelId2, 0, toInfer, null); + response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + assertThat(response.getInferenceResponse(), contains("not_to_be", "to_be")); + + // Get top classes + request = new InferModelAction.Request(modelId2, 0, toInfer, 2); + response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + + Map probabilities = (Map) response.getInferenceResponse().get(0); + assertThat(probabilities.get("not_to_be"), greaterThan(probabilities.get("to_be"))); + + probabilities = (Map) response.getInferenceResponse().get(1); + assertThat(probabilities.get("to_be"), greaterThan(probabilities.get("not_to_be"))); + + + request = new InferModelAction.Request(modelId2, 0, toInfer2, null); + response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + assertThat(response.getInferenceResponse(), contains("to_be", "not_to_be")); + + request = new InferModelAction.Request(modelId2, 0, toInfer2, 2); + response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + + probabilities = (Map) response.getInferenceResponse().get(0); + assertThat(probabilities.get("to_be"), greaterThan(probabilities.get("not_to_be"))); + + probabilities = (Map) response.getInferenceResponse().get(1); + assertThat(probabilities.get("not_to_be"), greaterThan(probabilities.get("to_be"))); + + // Test that top classes restrict the number returned + request = new InferModelAction.Request(modelId2, 0, toInfer2, 1); + response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + probabilities = (Map) response.getInferenceResponse().get(0); + assertThat(probabilities.size(), equalTo(1)); + + probabilities = (Map) response.getInferenceResponse().get(1); + assertThat(probabilities.size(), equalTo(1)); + + // test -1 gets all top classes + request = new InferModelAction.Request(modelId2, 0, toInfer2, -1); + response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + probabilities = (Map) response.getInferenceResponse().get(0); + assertThat(probabilities.size(), equalTo(2)); + } + + public void testInferMissingModel() { + String model = "test-infer-missing-model"; + InferModelAction.Request request = new InferModelAction.Request(model, 0, Collections.emptyList(), null); + try { + client().execute(InferModelAction.INSTANCE, request).actionGet(); + } catch (ElasticsearchException ex) { + assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model, 0))); + } + } + + private static TrainedModel buildClassification() { + List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2) + .setThreshold(0.8) + .setSplitFeature(1) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.0)) + .addNode(TreeNode.builder(4).setLeafValue(1.0)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(3) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(0.0)) + .addNode(TreeNode.builder(2).setLeafValue(1.0)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(1.0)) + .addNode(TreeNode.builder(2).setLeafValue(0.0)) + .build(); + return Ensemble.builder() + .setClassificationLabels(Arrays.asList("not_to_be", "to_be")) + .setTargetType(TargetType.CLASSIFICATION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0))) + .build(); + } + + private static TrainedModel buildRegression() { + List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(0) + .setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(0.3)) + .addNode(TreeNode.builder(2) + .setThreshold(0.0) + .setSplitFeature(3) + .setLeftChild(3) + .setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.1)) + .addNode(TreeNode.builder(4).setLeafValue(0.2)).build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(2) + .setThreshold(1.0)) + .addNode(TreeNode.builder(1).setLeafValue(1.5)) + .addNode(TreeNode.builder(2).setLeafValue(0.9)) + .build(); + Tree tree3 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0) + .setLeftChild(1) + .setRightChild(2) + .setSplitFeature(1) + .setThreshold(0.2)) + .addNode(TreeNode.builder(1).setLeafValue(1.5)) + .addNode(TreeNode.builder(2).setLeafValue(0.9)) + .build(); + return Ensemble.builder() + .setTargetType(TargetType.REGRESSION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5, 0.5))) + .build(); + } + + + public void testLoadMissingModels() throws Exception { + + } + + private static TrainedModelConfig buildTrainedModelConfig(String modelId, long modelVersion) { + return buildTrainedModelConfigBuilder(modelId, modelVersion).build(Version.CURRENT); + } + + private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId, long modelVersion) { + return TrainedModelConfig.builder() + .setCreatedBy("ml_test") + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .setDescription("trained model config for test") + .setModelId(modelId) + .setModelType("binary_decision_tree") + .setModelVersion(modelVersion); + } + + @Override + public NamedXContentRegistry xContentRegistry() { + List namedXContent = new ArrayList<>(); + namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); + namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); + return new NamedXContentRegistry(namedXContent); + + } + +} From 2d00f52f80e16c720e80f8732187157810b65299 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 2 Oct 2019 08:18:59 -0400 Subject: [PATCH 02/13] Adding inference results object and using that in response object --- .../ml/action/TransportInferModelAction.java | 5 +- .../ml/inference/action/InferModelAction.java | 11 +- .../ml/inference/action/InferenceResults.java | 161 ++++++++++++++++++ .../inference/loadingservice/LocalModel.java | 26 +-- .../ml/inference/loadingservice/Model.java | 5 +- .../loadingservice/ModelLoadingService.java | 1 - .../action/InferModelActionResponseTests.java | 13 +- .../action/InferenceResultsTests.java | 36 ++++ .../loadingservice/LocalModelTests.java | 49 ++++-- .../integration/ModelInferenceActionIT.java | 56 +++--- 10 files changed, 280 insertions(+), 83 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferenceResults.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferenceResultsTests.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java index b5d33758ce260..98790a24e6f21 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -14,6 +14,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.ml.inference.action.InferModelAction; +import org.elasticsearch.xpack.ml.inference.action.InferenceResults; import org.elasticsearch.xpack.ml.inference.loadingservice.Model; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor; @@ -39,14 +40,14 @@ public TransportInferModelAction(String actionName, @Override protected void doExecute(Task task, InferModelAction.Request request, ActionListener listener) { - ActionListener> inferenceCompleteListener = ActionListener.wrap( + ActionListener> inferenceCompleteListener = ActionListener.wrap( inferenceResponse -> listener.onResponse(new InferModelAction.Response(inferenceResponse)), listener::onFailure ); ActionListener getModelListener = ActionListener.wrap( model -> { - TypedChainTaskExecutor typedChainTaskExecutor = + TypedChainTaskExecutor typedChainTaskExecutor = new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME), // run through all tasks r -> true, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java index 134243343f32e..7446f01fe2944 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java @@ -129,26 +129,25 @@ public RequestBuilder(ElasticsearchClient client, Request request) { public static class Response extends ActionResponse { - // TODO come up with a better union type object - private final List inferenceResponse; + private final List inferenceResponse; - public Response(List inferenceResponse) { + public Response(List inferenceResponse) { super(); this.inferenceResponse = Collections.unmodifiableList(inferenceResponse); } public Response(StreamInput in) throws IOException { super(in); - this.inferenceResponse = Collections.unmodifiableList(in.readList(StreamInput::readGenericValue)); + this.inferenceResponse = Collections.unmodifiableList(in.readList(InferenceResults::new)); } - public List getInferenceResponse() { + public List getInferenceResponse() { return inferenceResponse; } @Override public void writeTo(StreamOutput out) throws IOException { - out.writeCollection(inferenceResponse, StreamOutput::writeGenericValue); + out.writeCollection(inferenceResponse); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferenceResults.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferenceResults.java new file mode 100644 index 0000000000000..e9ba0721c9ed7 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferenceResults.java @@ -0,0 +1,161 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.action; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +public class InferenceResults implements ToXContentObject, Writeable { + + public final ParseField TOP_CLASSES = new ParseField("top_classes"); + public final ParseField NUMERIC_VALUE = new ParseField("numeric_value"); + public final ParseField CLASSIFICATION_LABEL = new ParseField("classification_label"); + + private final double numericValue; + private final String classificationLabel; + private final List topClasses; + + public static InferenceResults valueOnly(double value) { + return new InferenceResults(value, null, null); + } + + public static InferenceResults valueAndLabel(double value, String classificationLabel) { + return new InferenceResults(value, classificationLabel, null); + } + + public InferenceResults(Double numericValue, String classificationLabel, List topClasses) { + this.numericValue = ExceptionsHelper.requireNonNull(numericValue, NUMERIC_VALUE); + this.classificationLabel = classificationLabel; + this.topClasses = topClasses == null ? null : Collections.unmodifiableList(topClasses); + } + + public InferenceResults(StreamInput in) throws IOException { + this.numericValue = in.readDouble(); + this.classificationLabel = in.readOptionalString(); + if (in.readBoolean()) { + this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new)); + } else { + this.topClasses = null; + } + } + + public double getNumericValue() { + return numericValue; + } + + public String getClassificationLabel() { + return classificationLabel; + } + + public List getTopClasses() { + return topClasses; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(numericValue); + out.writeOptionalString(classificationLabel); + out.writeBoolean(topClasses != null); + if (topClasses != null) { + out.writeList(topClasses); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(NUMERIC_VALUE.getPreferredName(), numericValue); + if (classificationLabel != null) { + builder.field(CLASSIFICATION_LABEL.getPreferredName(), classificationLabel); + } + if (topClasses != null) { + builder.field(TOP_CLASSES.getPreferredName(), topClasses); + } + builder.endObject(); + return null; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + InferenceResults that = (InferenceResults) object; + return Objects.equals(numericValue, that.numericValue) && + Objects.equals(classificationLabel, that.classificationLabel) && + Objects.equals(topClasses, that.topClasses); + } + + @Override + public int hashCode() { + return Objects.hash(numericValue, classificationLabel, topClasses); + } + + public static class TopClassEntry implements ToXContentObject, Writeable { + + public final ParseField LABEL = new ParseField("label"); + public final ParseField PROBABILITY = new ParseField("probability"); + + private final String label; + private final double probability; + + public TopClassEntry(String label, Double probability) { + this.label = ExceptionsHelper.requireNonNull(label, LABEL); + this.probability = ExceptionsHelper.requireNonNull(probability, PROBABILITY); + } + + public TopClassEntry(StreamInput in) throws IOException { + this.label = in.readString(); + this.probability = in.readDouble(); + } + + public String getLabel() { + return label; + } + + public double getProbability() { + return probability; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(label); + out.writeDouble(probability); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(LABEL.getPreferredName(), label); + builder.field(PROBABILITY.getPreferredName(), probability); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + TopClassEntry that = (TopClassEntry) object; + return Objects.equals(label, that.label) && + Objects.equals(probability, that.probability); + } + + @Override + public int hashCode() { + return Objects.hash(label, probability); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index 7ef8ba121dab1..85b5dd0c68e1d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -11,10 +11,10 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.action.InferenceResults; -import java.util.Collections; +import java.util.ArrayList; import java.util.Comparator; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -28,7 +28,7 @@ public LocalModel(TrainedModelDefinition trainedModelDefinition) { } @Override - public void infer(Map fields, ActionListener listener) { + public void infer(Map fields, ActionListener listener) { trainedModelDefinition.getPreProcessors().forEach(preProcessor -> preProcessor.process(fields)); double value = trainedModelDefinition.getTrainedModel().infer(fields); if (trainedModelDefinition.getTrainedModel().targetType() == TargetType.CLASSIFICATION && @@ -42,16 +42,17 @@ public void infer(Map fields, ActionListener listener) { trainedModelDefinition.getTrainedModel().classificationLabels())); return; } - listener.onResponse(trainedModelDefinition.getTrainedModel().classificationLabels().get(classIndex)); + listener.onResponse(InferenceResults.valueAndLabel(value, + trainedModelDefinition.getTrainedModel().classificationLabels().get(classIndex))); return; } - listener.onResponse(Double.valueOf(value)); + listener.onResponse(InferenceResults.valueOnly(value)); } @Override - public void confidence(Map fields, int topN, ActionListener listener) { + public void confidence(Map fields, int topN, ActionListener listener) { if (topN == 0) { - listener.onResponse(Collections.emptyMap()); + infer(fields, listener); return; } if (trainedModelDefinition.getTrainedModel().targetType() != TargetType.CLASSIFICATION) { @@ -82,11 +83,16 @@ public void confidence(Map fields, int topN, ActionListener probabilityMap = new HashMap<>(count); + List topClassEntries = new ArrayList<>(count); for(int i = 0; i < count; i++) { int idx = sortedIndices[i]; - probabilityMap.put(labels.get(idx), probabilities.get(idx)); + topClassEntries.add(new InferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx))); } - listener.onResponse(probabilityMap); + + listener.onResponse(new InferenceResults(((Number)sortedIndices[0]).doubleValue(), + trainedModelDefinition.getTrainedModel().classificationLabels() == null ? + null : + trainedModelDefinition.getTrainedModel().classificationLabels().get(sortedIndices[0]), + topClassEntries)); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java index ea6a8022eacd4..e6b1214058827 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java @@ -6,12 +6,13 @@ package org.elasticsearch.xpack.ml.inference.loadingservice; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.xpack.ml.inference.action.InferenceResults; import java.util.Map; public interface Model { - void infer(Map fields, ActionListener listener); + void infer(Map fields, ActionListener listener); - void confidence(Map fields, int topN, ActionListener listener); + void confidence(Map fields, int topN, ActionListener listener); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index aae17c90e21fb..65de20026e6db 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -12,7 +12,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterStateListener; -import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionResponseTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionResponseTests.java index 06bf1ab983b2d..dbf9ce0708c41 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionResponseTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionResponseTests.java @@ -7,25 +7,18 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.ml.inference.action.InferModelAction.Response; -import java.util.function.Function; -import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; public class InferModelActionResponseTests extends AbstractWireSerializingTestCase { @Override - @SuppressWarnings("unchecked") protected Response createTestInstance() { - Supplier resultSupplier = randomFrom(() -> randomAlphaOfLength(10), - ESTestCase::randomDouble, - () -> Stream.generate(() -> randomAlphaOfLength(10)) - .limit(randomIntBetween(1, 10)) - .collect(Collectors.toMap(Function.identity(), v -> randomDouble()))); - return new Response(Stream.generate(resultSupplier).limit(randomIntBetween(0, 10)).collect(Collectors.toList())); + return new Response(Stream.generate(InferenceResultsTests::createRandomResults) + .limit(randomIntBetween(0, 10)) + .collect(Collectors.toList())); } @Override diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferenceResultsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferenceResultsTests.java new file mode 100644 index 0000000000000..becf96ee6bc55 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferenceResultsTests.java @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class InferenceResultsTests extends AbstractWireSerializingTestCase { + + public static InferenceResults createRandomResults() { + return new InferenceResults(randomDouble(), + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : + Stream.generate(InferenceResultsTests::createRandomClassEntry).limit(randomIntBetween(0, 10)).collect(Collectors.toList())); + } + + private static InferenceResults.TopClassEntry createRandomClassEntry() { + return new InferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble()); + } + + @Override + protected InferenceResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return InferenceResults::new; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index d9c91afa34b9d..d54456dba1c96 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -16,21 +16,22 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; +import org.elasticsearch.xpack.ml.inference.action.InferenceResults; import java.util.Arrays; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; -import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.nullValue; public class LocalModelTests extends ESTestCase { - @SuppressWarnings("unchecked") public void testClassificationInfer() throws Exception { TrainedModelDefinition definition = new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) @@ -44,17 +45,25 @@ public void testClassificationInfer() throws Exception { put("categorical", "dog"); }}; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); model.infer(fields, future); - assertThat(future.get(), equalTo(0.0)); + InferenceResults result = future.get(); + assertThat(result.getNumericValue(), equalTo(0.0)); + assertThat(result.getClassificationLabel(), is(nullValue())); + assertThat(result.getTopClasses(), is(nullValue())); future = new PlainActionFuture<>(); model.confidence(fields, 0, future); - assertThat(future.get(), equalTo(Collections.emptyMap())); + result = future.get(); + assertThat(result.getNumericValue(), equalTo(0.0)); + assertThat(result.getClassificationLabel(), is(nullValue())); + assertThat(result.getTopClasses(), is(nullValue())); future = new PlainActionFuture<>(); model.confidence(fields, 1, future); - assertThat(((Map)future.get()).get("0"), closeTo(0.5498339973124778, 0.0000001)); + result = future.get(); + assertThat(result.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); + assertThat(result.getTopClasses().get(0).getLabel(), equalTo("0")); // Test with labels definition = new TrainedModelDefinition.Builder() @@ -64,23 +73,32 @@ public void testClassificationInfer() throws Exception { model = new LocalModel(definition); future = new PlainActionFuture<>(); model.infer(fields, future); - assertThat(future.get(), equalTo("not_to_be")); + result = future.get(); + assertThat(result.getNumericValue(), equalTo(0.0)); + assertThat(result.getClassificationLabel(), equalTo("not_to_be")); future = new PlainActionFuture<>(); model.confidence(fields, 0, future); - assertThat(future.get(), equalTo(Collections.emptyMap())); + result = future.get(); + assertThat(result.getNumericValue(), equalTo(0.0)); + assertThat(result.getClassificationLabel(), equalTo("not_to_be")); + assertThat(result.getTopClasses(), is(nullValue())); future = new PlainActionFuture<>(); model.confidence(fields, 1, future); - assertThat(((Map)future.get()).get("not_to_be"), closeTo(0.5498339973124778, 0.0000001)); + result = future.get(); + assertThat(result.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); + assertThat(result.getTopClasses().get(0).getLabel(), equalTo("not_to_be")); future = new PlainActionFuture<>(); model.confidence(fields, 2, future); - assertThat((Map)future.get(), aMapWithSize(2)); + result = future.get(); + assertThat(result.getTopClasses(), hasSize(2)); future = new PlainActionFuture<>(); model.confidence(fields, -1, future); - assertThat((Map)future.get(), aMapWithSize(2)); + result = future.get(); + assertThat(result.getTopClasses(), hasSize(2)); } public void testRegression() throws Exception { @@ -96,11 +114,12 @@ public void testRegression() throws Exception { put("categorical", "dog"); }}; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture future = new PlainActionFuture<>(); model.infer(fields, future); - assertThat(future.get(), equalTo(1.3)); + InferenceResults results = future.get(); + assertThat(results.getNumericValue(), equalTo(1.3)); - PlainActionFuture failedFuture = new PlainActionFuture<>(); + PlainActionFuture failedFuture = new PlainActionFuture<>(); model.confidence(fields, -1, failedFuture); ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get); assertThat(ex.getCause().getMessage(), equalTo("top result probabilities is only available for classification models")); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 3c1c3915a8e11..328f4126f8fff 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -25,7 +25,7 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.ml.inference.action.InferModelAction; -import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; +import org.elasticsearch.xpack.ml.inference.action.InferenceResults; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.junit.Before; @@ -36,11 +36,13 @@ import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.nullValue; public class ModelInferenceActionIT extends MlSingleNodeTestCase { @@ -108,56 +110,40 @@ public void testInferModels() throws Exception { // Test regression InferModelAction.Request request = new InferModelAction.Request(modelId1, 0, toInfer, null); InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getInferenceResponse(), contains(1.3, 1.25)); + assertThat(response.getInferenceResponse().stream().map(InferenceResults::getNumericValue).collect(Collectors.toList()), + contains(1.3, 1.25)); request = new InferModelAction.Request(modelId1, 0, toInfer2, null); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getInferenceResponse(), contains(1.65, 1.55)); + assertThat(response.getInferenceResponse().stream().map(InferenceResults::getNumericValue).collect(Collectors.toList()), + contains(1.65, 1.55)); // Test classification request = new InferModelAction.Request(modelId2, 0, toInfer, null); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getInferenceResponse(), contains("not_to_be", "to_be")); + assertThat(response.getInferenceResponse().stream().map(InferenceResults::getClassificationLabel).collect(Collectors.toList()), + contains("not_to_be", "to_be")); // Get top classes request = new InferModelAction.Request(modelId2, 0, toInfer, 2); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - Map probabilities = (Map) response.getInferenceResponse().get(0); - assertThat(probabilities.get("not_to_be"), greaterThan(probabilities.get("to_be"))); + assertThat(response.getInferenceResponse().get(0).getTopClasses().get(0).getLabel(), equalTo("not_to_be")); + assertThat(response.getInferenceResponse().get(0).getTopClasses().get(1).getLabel(), equalTo("to_be")); + assertThat(response.getInferenceResponse().get(0).getTopClasses().get(0).getProbability(), + greaterThan(response.getInferenceResponse().get(0).getTopClasses().get(1).getProbability())); - probabilities = (Map) response.getInferenceResponse().get(1); - assertThat(probabilities.get("to_be"), greaterThan(probabilities.get("not_to_be"))); - - - request = new InferModelAction.Request(modelId2, 0, toInfer2, null); - response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getInferenceResponse(), contains("to_be", "not_to_be")); - - request = new InferModelAction.Request(modelId2, 0, toInfer2, 2); - response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - - probabilities = (Map) response.getInferenceResponse().get(0); - assertThat(probabilities.get("to_be"), greaterThan(probabilities.get("not_to_be"))); - - probabilities = (Map) response.getInferenceResponse().get(1); - assertThat(probabilities.get("not_to_be"), greaterThan(probabilities.get("to_be"))); + assertThat(response.getInferenceResponse().get(1).getTopClasses().get(0).getLabel(), equalTo("to_be")); + assertThat(response.getInferenceResponse().get(1).getTopClasses().get(1).getLabel(), equalTo("not_to_be")); + assertThat(response.getInferenceResponse().get(1).getTopClasses().get(0).getProbability(), + greaterThan(response.getInferenceResponse().get(1).getTopClasses().get(1).getProbability())); // Test that top classes restrict the number returned request = new InferModelAction.Request(modelId2, 0, toInfer2, 1); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - probabilities = (Map) response.getInferenceResponse().get(0); - assertThat(probabilities.size(), equalTo(1)); - - probabilities = (Map) response.getInferenceResponse().get(1); - assertThat(probabilities.size(), equalTo(1)); - - // test -1 gets all top classes - request = new InferModelAction.Request(modelId2, 0, toInfer2, -1); - response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - probabilities = (Map) response.getInferenceResponse().get(0); - assertThat(probabilities.size(), equalTo(2)); + assertThat(response.getInferenceResponse().get(0).getTopClasses(), hasSize(1)); + assertThat(response.getInferenceResponse().get(0).getTopClasses().get(0).getLabel(), equalTo("to_be")); } public void testInferMissingModel() { @@ -262,10 +248,6 @@ private static TrainedModel buildRegression() { } - public void testLoadMissingModels() throws Exception { - - } - private static TrainedModelConfig buildTrainedModelConfig(String modelId, long modelVersion) { return buildTrainedModelConfigBuilder(modelId, modelVersion).build(Version.CURRENT); } From a4b764328158181cde022505a6c1119c7d464242 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 2 Oct 2019 20:11:44 -0400 Subject: [PATCH 03/13] fixing model caching and loading given cluster state updates --- .../xpack/ml/MachineLearning.java | 2 +- .../ml/action/TransportInferModelAction.java | 5 +- .../inference/ingest/InferenceProcessor.java | 40 ++++ .../inference/loadingservice/LocalModel.java | 2 +- .../ml/inference/loadingservice/Model.java | 2 +- .../loadingservice/ModelLoadingService.java | 200 +++++++++++++++--- .../loadingservice/LocalModelTests.java | 14 +- .../ModelLoadingServiceTests.java | 128 ++++++++++- .../integration/ModelInferenceActionIT.java | 6 - 9 files changed, 338 insertions(+), 61 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 4296e46e232a7..f2b617a1d0a19 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -500,7 +500,7 @@ public Collection createComponents(Client client, ClusterService cluster xContentRegistry); final TrainedModelProvider trainedModelProvider = new TrainedModelProvider(client, xContentRegistry); - final ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider); + final ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); // special holder for @link(MachineLearningFeatureSetUsage) which needs access to job manager if ML is enabled JobManagerHolder jobManagerHolder = new JobManagerHolder(jobManager); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java index 98790a24e6f21..a6668da4ee786 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -55,7 +55,8 @@ protected void doExecute(Task task, InferModelAction.Request request, ActionList ex -> true); if (request.getTopClasses() != null) { request.getObjectsToInfer().forEach(stringObjectMap -> - typedChainTaskExecutor.add(chainedTask -> model.confidence(stringObjectMap, request.getTopClasses(), chainedTask)) + typedChainTaskExecutor.add(chainedTask -> + model.classificationProbability(stringObjectMap, request.getTopClasses(), chainedTask)) ); } else { request.getObjectsToInfer().forEach(stringObjectMap -> @@ -67,6 +68,6 @@ protected void doExecute(Task task, InferModelAction.Request request, ActionList listener::onFailure ); - this.modelLoadingService.getModelAndCache(request.getModelId(), request.getModelVersion(), getModelListener); + this.modelLoadingService.getModel(request.getModelId(), request.getModelVersion(), getModelListener); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java new file mode 100644 index 0000000000000..5ea4c77a4c0ce --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.ingest; + +import org.elasticsearch.client.Client; +import org.elasticsearch.ingest.AbstractProcessor; +import org.elasticsearch.ingest.IngestDocument; + +import java.util.function.BiConsumer; + +public class InferenceProcessor extends AbstractProcessor { + + public static final String TYPE = "inference"; + public static final String MODEL_ID = "model_id"; + + private final Client client; + public InferenceProcessor(Client client, String tag) { + super(tag); + this.client = client; + } + + @Override + public void execute(IngestDocument ingestDocument, BiConsumer handler) { + //TODO actually work + handler.accept(ingestDocument, null); + } + + @Override + public IngestDocument execute(IngestDocument ingestDocument) { + throw new UnsupportedOperationException("should never be called"); + } + + @Override + public String getType() { + return TYPE; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index 85b5dd0c68e1d..9701a31dba17c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -50,7 +50,7 @@ public void infer(Map fields, ActionListener l } @Override - public void confidence(Map fields, int topN, ActionListener listener) { + public void classificationProbability(Map fields, int topN, ActionListener listener) { if (topN == 0) { infer(fields, listener); return; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java index e6b1214058827..6e8e8f41fb14e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java @@ -14,5 +14,5 @@ public interface Model { void infer(Map fields, ActionListener listener); - void confidence(Map fields, int topN, ActionListener listener); + void classificationProbability(Map fields, int topN, ActionListener listener); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 65de20026e6db..687fcce28f4e1 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -8,58 +8,95 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.apache.lucene.util.SetOnce; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import java.util.ArrayDeque; -import java.util.Deque; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.Queue; +import java.util.Set; import java.util.concurrent.ConcurrentHashMap; public class ModelLoadingService implements ClusterStateListener { private static final Logger logger = LogManager.getLogger(ModelLoadingService.class); - private final ConcurrentHashMap loadedModels = new ConcurrentHashMap<>(); + // TODO should these be ConcurrentHashMaps if all interactions are synchronized? + private final ConcurrentHashMap> loadedModels = new ConcurrentHashMap<>(); private final ConcurrentHashMap>> loadingListeners = new ConcurrentHashMap<>(); private final TrainedModelProvider provider; + private final ThreadPool threadPool; - public ModelLoadingService(TrainedModelProvider trainedModelProvider) { + public ModelLoadingService(TrainedModelProvider trainedModelProvider, ThreadPool threadPool, ClusterService clusterService) { this.provider = trainedModelProvider; + this.threadPool = threadPool; + clusterService.addListener(this); + // TODO should we load state here? Or will it get the applied state because it is a listener? } - public void getModelAndCache(String modelId, long modelVersion, ActionListener modelActionListener) { + public void getModel(String modelId, long modelVersion, ActionListener modelActionListener) { String key = modelKey(modelId, modelVersion); - Model cachedModel = loadedModels.get(key); + Optional cachedModel = loadedModels.get(key); if (cachedModel != null) { - modelActionListener.onResponse(cachedModel); - return; + if (cachedModel.isPresent()) { + modelActionListener.onResponse(cachedModel.get()); + return; + } } - SetOnce newLoad = new SetOnce<>(); + if (loadModelIfNecessary(key, modelId, modelVersion, modelActionListener) == false) { + // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called + // by a simulated pipeline + logger.debug("[{}] version [{}] not actively loading, eager loading without cache", modelId, modelVersion); + provider.getTrainedModel(modelId, modelVersion, ActionListener.wrap( + trainedModelConfig -> modelActionListener.onResponse(new LocalModel(trainedModelConfig.getDefinition())), + modelActionListener::onFailure + )); + } else { + logger.debug("[{}] version [{}] is currently loading, added new listener to queue", modelId, modelVersion); + } + } + + /** + * Returns true if the model is loaded and the listener has been given the cached model + * Returns true if the model is CURRENTLY being loaded and the listener was added to be notified when it is loaded + * Returns false if the model is not loaded or actively being loaded + */ + private boolean loadModelIfNecessary(String key, String modelId, long modelVersion, ActionListener modelActionListener) { synchronized (loadingListeners) { - cachedModel = loadedModels.get(key); + Optional cachedModel = loadedModels.get(key); if (cachedModel != null) { - modelActionListener.onResponse(cachedModel); - return; - } - loadingListeners.compute(key, (modelKey, listeners) -> { - if (listeners == null) { - newLoad.set(true); - Deque> listenerDeque = new ArrayDeque<>(); - listenerDeque.addLast(modelActionListener); - return listenerDeque; + if (cachedModel.isPresent()) { + modelActionListener.onResponse(cachedModel.get()); + return true; } - newLoad.set(false); - listeners.add(modelActionListener); - return listeners; - }); - } - if (newLoad.get()) { - // TODO support loading other types of models? - loadModel(key, modelId, modelVersion); + // If the loaded model entry is there but is not present, that means the previous load attempt ran into an issue + // Attempt to load and cache the model if necessary + if (loadingListeners.computeIfPresent( + key, + (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) == null) { + logger.debug("[{}] version [{}] attempting to load and cache", modelId, modelVersion); + loadingListeners.put(key, addFluently(new ArrayDeque<>(), modelActionListener)); + loadModel(key, modelId, modelVersion); + } + return true; + } + // if the cachedModel entry is null, but there are listeners present, that means it is being loaded + return loadingListeners.computeIfPresent(key, + (storedModelKey, listenerQueue) -> addFluently(listenerQueue, modelActionListener)) != null; } } @@ -80,8 +117,12 @@ private void handleLoadSuccess(String modelKey, TrainedModelConfig trainedModelC Queue> listeners; Model loadedModel = new LocalModel(trainedModelConfig.getDefinition()); synchronized (loadingListeners) { - loadedModels.put(modelKey, loadedModel); listeners = loadingListeners.remove(modelKey); + // If there is no loadingListener that means the loading was canceled and the listener was already notified as such + // Consequently, we should not store the retrieved model + if (listeners != null) { + loadedModels.put(modelKey, Optional.of(loadedModel)); + } } if (listeners != null) { for(ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { @@ -93,8 +134,12 @@ private void handleLoadSuccess(String modelKey, TrainedModelConfig trainedModelC private void handleLoadFailure(String modelKey, Exception failure) { Queue> listeners; synchronized (loadingListeners) { - // TODO do we want to cache the failure? listeners = loadingListeners.remove(modelKey); + if (listeners != null) { + // If we failed to load and there were listeners present, that means that this model is referenced by a processor + // Add an empty entry here so that we can attempt to load and cache the model again when it is accessed again. + loadedModels.computeIfAbsent(modelKey, (key) -> Optional.empty()); + } } if (listeners != null) { for(ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { @@ -103,12 +148,101 @@ private void handleLoadFailure(String modelKey, Exception failure) { } } - private String modelKey(String modelId, long modelVersion) { + @Override + public void clusterChanged(ClusterChangedEvent event) { + if (event.changedCustomMetaDataSet().contains(IngestMetadata.TYPE)) { + ClusterState state = event.state(); + IngestMetadata currentIngestMetadata = state.metaData().custom(IngestMetadata.TYPE); + Set allReferencedModelKeys = getReferencedModelKeys(currentIngestMetadata); + // The listeners still waiting for a model and we are canceling the load? + Queue> drainWithFailure = new ArrayDeque<>(); + synchronized (loadingListeners) { + // If we had models still loading here but are no longer referenced + // we should remove them from loadingListeners and alert the listeners + Iterator keyIterator = loadingListeners.keys().asIterator(); + while(keyIterator.hasNext()) { + String modelKey = keyIterator.next(); + if (allReferencedModelKeys.contains(modelKey) == false) { + drainWithFailure.addAll(loadingListeners.remove(modelKey)); + } + } + + // Remove all cached models that are not referenced by any processors + loadedModels.keySet().retainAll(allReferencedModelKeys); + + // After removing the unreferenced models, now we need to know what referenced models should be loaded + + // Remove all that are currently being loaded + allReferencedModelKeys.removeAll(loadingListeners.keySet()); + + // Remove all that are fully loaded, will attempt empty model loading again + loadedModels.forEach((id, optionalModel) -> { + if(optionalModel.isPresent()) { + allReferencedModelKeys.remove(id); + } + }); + // Populate loadingListeners key so we know that we are currently loading the model + for(String modelId : allReferencedModelKeys) { + loadingListeners.put(modelId, new ArrayDeque<>()); + } + } + for(ActionListener listener = drainWithFailure.poll(); listener != null; listener = drainWithFailure.poll()) { + listener.onFailure( + new ElasticsearchException("Cancelling model load and inference as it is no longer referenced by a pipeline")); + } + loadModels(allReferencedModelKeys); + } + } + + private void loadModels(Set modelKeys) { + if (modelKeys.isEmpty()) { + return; + } + // Execute this on a utility thread as when the callbacks occur we don't want them tying up the cluster listener thread pool + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> { + for(String modelKey : modelKeys) { + Tuple modelIdAndVersion = splitModelKey(modelKey); + this.loadModel(modelKey, modelIdAndVersion.v1(), modelIdAndVersion.v2()); + } + }); + } + + private static Queue addFluently(Queue queue, T object) { + queue.add(object); + return queue; + } + + private static String modelKey(String modelId, long modelVersion) { return modelId + "_" + modelVersion; } - @Override - public void clusterChanged(ClusterChangedEvent event) { - // TODO + private static Tuple splitModelKey(String modelKey) { + int delim = modelKey.lastIndexOf('_'); + String modelId = modelKey.substring(0, delim); + Long version = Long.valueOf(modelKey.substring(delim + 1)); + return Tuple.tuple(modelId, version); + } + + private static Set getReferencedModelKeys(IngestMetadata ingestMetadata) { + Set allReferencedModelKeys = new HashSet<>(); + if (ingestMetadata != null) { + ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> { + Object processors = pipelineConfiguration.getConfigAsMap().get("processors"); + if (processors instanceof List) { + for(Object processor : (List)processors) { + if (processor instanceof Map) { + Object processorConfig = ((Map)processor).get(InferenceProcessor.TYPE); + if (processorConfig instanceof Map) { + String modelId = ((Map)processorConfig).get(InferenceProcessor.MODEL_ID).toString(); + // TODO also read model version + allReferencedModelKeys.add(modelKey(modelId, 0)); + } + } + } + } + }); + } + return allReferencedModelKeys; } + } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index d54456dba1c96..4da00f9a7ba5d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -53,14 +53,14 @@ public void testClassificationInfer() throws Exception { assertThat(result.getTopClasses(), is(nullValue())); future = new PlainActionFuture<>(); - model.confidence(fields, 0, future); + model.classificationProbability(fields, 0, future); result = future.get(); assertThat(result.getNumericValue(), equalTo(0.0)); assertThat(result.getClassificationLabel(), is(nullValue())); assertThat(result.getTopClasses(), is(nullValue())); future = new PlainActionFuture<>(); - model.confidence(fields, 1, future); + model.classificationProbability(fields, 1, future); result = future.get(); assertThat(result.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(result.getTopClasses().get(0).getLabel(), equalTo("0")); @@ -78,25 +78,25 @@ public void testClassificationInfer() throws Exception { assertThat(result.getClassificationLabel(), equalTo("not_to_be")); future = new PlainActionFuture<>(); - model.confidence(fields, 0, future); + model.classificationProbability(fields, 0, future); result = future.get(); assertThat(result.getNumericValue(), equalTo(0.0)); assertThat(result.getClassificationLabel(), equalTo("not_to_be")); assertThat(result.getTopClasses(), is(nullValue())); future = new PlainActionFuture<>(); - model.confidence(fields, 1, future); + model.classificationProbability(fields, 1, future); result = future.get(); assertThat(result.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(result.getTopClasses().get(0).getLabel(), equalTo("not_to_be")); future = new PlainActionFuture<>(); - model.confidence(fields, 2, future); + model.classificationProbability(fields, 2, future); result = future.get(); assertThat(result.getTopClasses(), hasSize(2)); future = new PlainActionFuture<>(); - model.confidence(fields, -1, future); + model.classificationProbability(fields, -1, future); result = future.get(); assertThat(result.getTopClasses(), hasSize(2)); } @@ -120,7 +120,7 @@ public void testRegression() throws Exception { assertThat(results.getNumericValue(), equalTo(1.3)); PlainActionFuture failedFuture = new PlainActionFuture<>(); - model.confidence(fields, -1, failedFuture); + model.classificationProbability(fields, -1, failedFuture); ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get); assertThat(ex.getCause().getMessage(), equalTo("top result probabilities is only available for classification models")); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index c9779abff8b2b..4fbf973aae7f0 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -9,34 +9,72 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.ingest.IngestMetadata; +import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ScalingExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.junit.After; import org.junit.Before; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.atMost; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class ModelLoadingServiceTests extends ESTestCase { - TrainedModelProvider trainedModelProvider; + private TrainedModelProvider trainedModelProvider; + private ThreadPool threadPool; + private ClusterService clusterService; @Before public void setUpComponents() { + threadPool = new TestThreadPool("ModelLoadingServiceTests", new ScalingExecutorBuilder(UTILITY_THREAD_POOL_NAME, + 1, 4, TimeValue.timeValueMinutes(10), "xpack.ml.utility_thread_pool")); trainedModelProvider = mock(TrainedModelProvider.class); + clusterService = mock(ClusterService.class); + doAnswer((invocationOnMock) -> null).when(clusterService).addListener(any(ClusterStateListener.class)); + when(clusterService.state()).thenReturn(ClusterState.builder(new ClusterName("_name")).build()); + } + + @After + public void terminateThreadPool() { + terminate(threadPool); } - public void testGetModelAndCache() throws Exception { + public void testGetCachedModels() throws Exception { String model1 = "test-load-model-1"; String model2 = "test-load-model-2"; String model3 = "test-load-model-3"; @@ -44,13 +82,15 @@ public void testGetModelAndCache() throws Exception { withTrainedModel(model2, 0); withTrainedModel(model3, 0); - ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider); + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + + modelLoadingService.clusterChanged(ingestChangedEvent(model1, model2, model3)); String[] modelIds = new String[]{model1, model2, model3}; for(int i = 0; i < 10; i++) { String model = modelIds[i%3]; PlainActionFuture future = new PlainActionFuture<>(); - modelLoadingService.getModelAndCache(model, 0, future); + modelLoadingService.getModel(model, 0, future); assertThat(future.get(), is(not(nullValue()))); } @@ -59,16 +99,55 @@ public void testGetModelAndCache() throws Exception { verify(trainedModelProvider, times(1)).getTrainedModel(eq(model3), eq(0L), any()); } - public void testGetMissingModelAndCache() { + public void testGetCachedMissingModel() throws Exception { + String model = "test-load-cached-missing-model"; + withMissingModel(model, 0); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + modelLoadingService.clusterChanged(ingestChangedEvent(model)); + + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, 0, future); + + try { + future.get(); + fail("Should not have succeeded in loaded model"); + } catch (Exception ex) { + assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model, 0))); + } + + verify(trainedModelProvider, atMost(2)).getTrainedModel(eq(model), eq(0L), any()); + } + + public void testGetMissingModel() { String model = "test-load-missing-model"; withMissingModel(model, 0); - ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider); + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, 0, future); + try { + future.get(); + fail("Should not have succeeded"); + } catch (Exception ex) { + assertThat(ex.getCause().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model, 0))); + } + } + + public void testGetModelEagerly() throws Exception { + String model = "test-get-model-eagerly"; + withTrainedModel(model, 0); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, threadPool, clusterService); + + for(int i = 0; i < 3; i++) { + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModel(model, 0, future); + assertThat(future.get(), is(not(nullValue()))); + } - modelLoadingService.getModelAndCache(model, 0, ActionListener.wrap( - m -> fail("Should not have succeeded"), - f -> assertThat(f.getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, model, 0))) - )); + verify(trainedModelProvider, times(3)).getTrainedModel(eq(model), eq(0L), any()); } @SuppressWarnings("unchecked") @@ -101,4 +180,33 @@ private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String .setModelType("binary_decision_tree") .setModelVersion(modelVersion); } + + private static ClusterChangedEvent ingestChangedEvent(String... modelId) throws IOException { + ClusterChangedEvent event = mock(ClusterChangedEvent.class); + when(event.changedCustomMetaDataSet()).thenReturn(Collections.singleton(IngestMetadata.TYPE)); + when(event.state()).thenReturn(buildClusterStateWithModelReferences(modelId)); + return event; + } + + private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException { + Map configurations = new HashMap<>(modelId.length); + for (String id : modelId) { + configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id)); + } + IngestMetadata ingestMetadata = new IngestMetadata(configurations); + + return ClusterState.builder(new ClusterName("_name")) + .metaData(MetaData.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) + .build(); + } + + private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException { + try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(Collections.singletonMap("processors", + Collections.singletonList( + Collections.singletonMap(InferenceProcessor.TYPE, + Collections.singletonMap(InferenceProcessor.MODEL_ID, + modelId)))))) { + return new PipelineConfiguration("pipeline_with_model_" + modelId, BytesReference.bytes(xContentBuilder), XContentType.JSON); + } + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 328f4126f8fff..d167f6360b4c2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -55,7 +55,6 @@ public void createComponents() throws Exception { waitForMlTemplates(); } - @SuppressWarnings("unchecked") public void testInferModels() throws Exception { String modelId1 = "test-load-models-regression"; String modelId2 = "test-load-models-classification"; @@ -247,11 +246,6 @@ private static TrainedModel buildRegression() { .build(); } - - private static TrainedModelConfig buildTrainedModelConfig(String modelId, long modelVersion) { - return buildTrainedModelConfigBuilder(modelId, modelVersion).build(Version.CURRENT); - } - private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId, long modelVersion) { return TrainedModelConfig.builder() .setCreatedBy("ml_test") From e2a98cd2a2001cf4b1be09920bca5ead41e41ed9 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 3 Oct 2019 09:24:01 -0400 Subject: [PATCH 04/13] Making inference results more general for response object --- .../ml/action/TransportInferModelAction.java | 14 +- .../ClassificationInferenceResults.java | 155 ++++++++++++++++++ .../ml/inference/action/InferModelAction.java | 40 +++-- .../ml/inference/action/InferenceResults.java | 150 +---------------- .../action/RegressionInferenceResults.java | 54 ++++++ .../action/SingleValueInferenceResults.java | 53 ++++++ .../inference/loadingservice/LocalModel.java | 63 ++++--- .../ml/inference/loadingservice/Model.java | 6 +- .../loadingservice/ModelLoadingService.java | 15 +- .../ClassificationInferenceResultsTests.java | 38 +++++ .../action/InferModelActionResponseTests.java | 18 +- .../action/InferenceResultsTests.java | 36 ---- .../RegressionInferenceResultsTests.java | 27 +++ .../loadingservice/LocalModelTests.java | 57 +++---- .../integration/ModelInferenceActionIT.java | 35 ++-- 15 files changed, 489 insertions(+), 272 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/ClassificationInferenceResults.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/RegressionInferenceResults.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/SingleValueInferenceResults.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/ClassificationInferenceResultsTests.java delete mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferenceResultsTests.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/RegressionInferenceResultsTests.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java index a6668da4ee786..b5e81c52aa316 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -19,7 +19,6 @@ import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor; -import java.util.List; public class TransportInferModelAction extends HandledTransportAction { @@ -40,14 +39,9 @@ public TransportInferModelAction(String actionName, @Override protected void doExecute(Task task, InferModelAction.Request request, ActionListener listener) { - ActionListener> inferenceCompleteListener = ActionListener.wrap( - inferenceResponse -> listener.onResponse(new InferModelAction.Response(inferenceResponse)), - listener::onFailure - ); - ActionListener getModelListener = ActionListener.wrap( model -> { - TypedChainTaskExecutor typedChainTaskExecutor = + TypedChainTaskExecutor> typedChainTaskExecutor = new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME), // run through all tasks r -> true, @@ -63,7 +57,11 @@ protected void doExecute(Task task, InferModelAction.Request request, ActionList typedChainTaskExecutor.add(chainedTask -> model.infer(stringObjectMap, chainedTask)) ); } - typedChainTaskExecutor.execute(inferenceCompleteListener); + typedChainTaskExecutor.execute(ActionListener.wrap( + inferenceResultsInterfaces -> + listener.onResponse(new InferModelAction.Response(inferenceResultsInterfaces, model.getResultsType())), + listener::onFailure + )); }, listener::onFailure ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/ClassificationInferenceResults.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/ClassificationInferenceResults.java new file mode 100644 index 0000000000000..591151186a3c8 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/ClassificationInferenceResults.java @@ -0,0 +1,155 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.action; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +public class ClassificationInferenceResults extends SingleValueInferenceResults { + + public static final String RESULT_TYPE = "classification"; + public static final ParseField CLASSIFICATION_LABEL = new ParseField("classification_label"); + public static final ParseField TOP_CLASSES = new ParseField("top_classes"); + + private final String classificationLabel; + private final List topClasses; + + public ClassificationInferenceResults(double value, String classificationLabel, List topClasses) { + super(value); + this.classificationLabel = classificationLabel; + this.topClasses = topClasses == null ? null : Collections.unmodifiableList(topClasses); + } + + public ClassificationInferenceResults(StreamInput in) throws IOException { + super(in); + this.classificationLabel = in.readOptionalString(); + if (in.readBoolean()) { + this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new)); + } else { + this.topClasses = null; + } + } + + public String getClassificationLabel() { + return classificationLabel; + } + + public List getTopClasses() { + return topClasses; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalString(classificationLabel); + out.writeBoolean(topClasses != null); + if (topClasses != null) { + out.writeCollection(topClasses); + } + } + + @Override + XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException { + if (classificationLabel != null) { + builder.field(CLASSIFICATION_LABEL.getPreferredName(), classificationLabel); + } + if (topClasses != null) { + builder.field(TOP_CLASSES.getPreferredName(), topClasses); + } + return builder; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + ClassificationInferenceResults that = (ClassificationInferenceResults) object; + return Objects.equals(value(), that.value()) && + Objects.equals(classificationLabel, that.classificationLabel) && + Objects.equals(topClasses, that.topClasses); + } + + @Override + public int hashCode() { + return Objects.hash(value(), classificationLabel, topClasses); + } + + @Override + public String resultType() { + return RESULT_TYPE; + } + + @Override + public String valueAsString() { + return classificationLabel == null ? super.valueAsString() : classificationLabel; + } + + public static class TopClassEntry implements ToXContentObject, Writeable { + + public final ParseField LABEL = new ParseField("label"); + public final ParseField PROBABILITY = new ParseField("probability"); + + private final String label; + private final double probability; + + public TopClassEntry(String label, Double probability) { + this.label = ExceptionsHelper.requireNonNull(label, LABEL); + this.probability = ExceptionsHelper.requireNonNull(probability, PROBABILITY); + } + + public TopClassEntry(StreamInput in) throws IOException { + this.label = in.readString(); + this.probability = in.readDouble(); + } + + public String getLabel() { + return label; + } + + public double getProbability() { + return probability; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(label); + out.writeDouble(probability); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(LABEL.getPreferredName(), label); + builder.field(PROBABILITY.getPreferredName(), probability); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + TopClassEntry that = (TopClassEntry) object; + return Objects.equals(label, that.label) && + Objects.equals(probability, that.probability); + } + + @Override + public int hashCode() { + return Objects.hash(label, probability); + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java index 7446f01fe2944..756800798e6ba 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java @@ -13,9 +13,10 @@ import org.elasticsearch.client.ElasticsearchClient; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -45,8 +46,9 @@ public Request(String modelId, long modelVersion) { public Request(String modelId, long modelVersion, List> objectsToInfer, Integer topClasses) { this.modelId = modelId; this.modelVersion = modelVersion; - this.objectsToInfer = objectsToInfer == null ? Collections.emptyList() : - Collections.unmodifiableList(new ArrayList<>(objectsToInfer)); + this.objectsToInfer = objectsToInfer == null ? + Collections.emptyList() : + Collections.unmodifiableList(objectsToInfer); this.cacheModel = true; this.topClasses = topClasses; } @@ -54,7 +56,7 @@ public Request(String modelId, long modelVersion, List> obje public Request(String modelId, long modelVersion, Map objectToInfer, Integer topClasses) { this(modelId, modelVersion, - objectToInfer == null ? Collections.emptyList() : Collections.singletonList(objectToInfer), + objectToInfer == null ? null : Arrays.asList(objectToInfer), topClasses); } @@ -129,24 +131,40 @@ public RequestBuilder(ElasticsearchClient client, Request request) { public static class Response extends ActionResponse { - private final List inferenceResponse; + private final List> inferenceResponse; + private final String resultsType; - public Response(List inferenceResponse) { + public Response(List> inferenceResponse, String resultsType) { super(); - this.inferenceResponse = Collections.unmodifiableList(inferenceResponse); + this.resultsType = ExceptionsHelper.requireNonNull(resultsType, "resultsType"); + this.inferenceResponse = inferenceResponse == null ? + Collections.emptyList() : + Collections.unmodifiableList(inferenceResponse); } public Response(StreamInput in) throws IOException { super(in); - this.inferenceResponse = Collections.unmodifiableList(in.readList(InferenceResults::new)); + this.resultsType = in.readString(); + if(resultsType.equals(ClassificationInferenceResults.RESULT_TYPE)) { + this.inferenceResponse = Collections.unmodifiableList(in.readList(ClassificationInferenceResults::new)); + } else if (this.resultsType.equals(RegressionInferenceResults.RESULT_TYPE)) { + this.inferenceResponse = Collections.unmodifiableList(in.readList(RegressionInferenceResults::new)); + } else { + throw new IOException("Unrecognized result type [" + resultsType + "]"); + } } - public List getInferenceResponse() { + public List> getInferenceResponse() { return inferenceResponse; } + public String getResultsType() { + return resultsType; + } + @Override public void writeTo(StreamOutput out) throws IOException { + out.writeString(resultsType); out.writeCollection(inferenceResponse); } @@ -155,12 +173,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; InferModelAction.Response that = (InferModelAction.Response) o; - return Objects.equals(inferenceResponse, that.inferenceResponse); + return Objects.equals(resultsType, that.resultsType) && Objects.equals(inferenceResponse, that.inferenceResponse); } @Override public int hashCode() { - return Objects.hash(inferenceResponse); + return Objects.hash(resultsType, inferenceResponse); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferenceResults.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferenceResults.java index e9ba0721c9ed7..98fb317f8c158 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferenceResults.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferenceResults.java @@ -5,157 +5,15 @@ */ package org.elasticsearch.xpack.ml.inference.action; -import org.elasticsearch.common.ParseField; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ToXContentObject; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import java.io.IOException; -import java.util.Collections; -import java.util.List; -import java.util.Objects; +public interface InferenceResults extends ToXContentObject, Writeable { -public class InferenceResults implements ToXContentObject, Writeable { + String resultType(); - public final ParseField TOP_CLASSES = new ParseField("top_classes"); - public final ParseField NUMERIC_VALUE = new ParseField("numeric_value"); - public final ParseField CLASSIFICATION_LABEL = new ParseField("classification_label"); + T value(); - private final double numericValue; - private final String classificationLabel; - private final List topClasses; + String valueAsString(); - public static InferenceResults valueOnly(double value) { - return new InferenceResults(value, null, null); - } - - public static InferenceResults valueAndLabel(double value, String classificationLabel) { - return new InferenceResults(value, classificationLabel, null); - } - - public InferenceResults(Double numericValue, String classificationLabel, List topClasses) { - this.numericValue = ExceptionsHelper.requireNonNull(numericValue, NUMERIC_VALUE); - this.classificationLabel = classificationLabel; - this.topClasses = topClasses == null ? null : Collections.unmodifiableList(topClasses); - } - - public InferenceResults(StreamInput in) throws IOException { - this.numericValue = in.readDouble(); - this.classificationLabel = in.readOptionalString(); - if (in.readBoolean()) { - this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new)); - } else { - this.topClasses = null; - } - } - - public double getNumericValue() { - return numericValue; - } - - public String getClassificationLabel() { - return classificationLabel; - } - - public List getTopClasses() { - return topClasses; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeDouble(numericValue); - out.writeOptionalString(classificationLabel); - out.writeBoolean(topClasses != null); - if (topClasses != null) { - out.writeList(topClasses); - } - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(NUMERIC_VALUE.getPreferredName(), numericValue); - if (classificationLabel != null) { - builder.field(CLASSIFICATION_LABEL.getPreferredName(), classificationLabel); - } - if (topClasses != null) { - builder.field(TOP_CLASSES.getPreferredName(), topClasses); - } - builder.endObject(); - return null; - } - - @Override - public boolean equals(Object object) { - if (object == this) { return true; } - if (object == null || getClass() != object.getClass()) { return false; } - InferenceResults that = (InferenceResults) object; - return Objects.equals(numericValue, that.numericValue) && - Objects.equals(classificationLabel, that.classificationLabel) && - Objects.equals(topClasses, that.topClasses); - } - - @Override - public int hashCode() { - return Objects.hash(numericValue, classificationLabel, topClasses); - } - - public static class TopClassEntry implements ToXContentObject, Writeable { - - public final ParseField LABEL = new ParseField("label"); - public final ParseField PROBABILITY = new ParseField("probability"); - - private final String label; - private final double probability; - - public TopClassEntry(String label, Double probability) { - this.label = ExceptionsHelper.requireNonNull(label, LABEL); - this.probability = ExceptionsHelper.requireNonNull(probability, PROBABILITY); - } - - public TopClassEntry(StreamInput in) throws IOException { - this.label = in.readString(); - this.probability = in.readDouble(); - } - - public String getLabel() { - return label; - } - - public double getProbability() { - return probability; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(label); - out.writeDouble(probability); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(LABEL.getPreferredName(), label); - builder.field(PROBABILITY.getPreferredName(), probability); - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object object) { - if (object == this) { return true; } - if (object == null || getClass() != object.getClass()) { return false; } - TopClassEntry that = (TopClassEntry) object; - return Objects.equals(label, that.label) && - Objects.equals(probability, that.probability); - } - - @Override - public int hashCode() { - return Objects.hash(label, probability); - } - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/RegressionInferenceResults.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/RegressionInferenceResults.java new file mode 100644 index 0000000000000..38841f8e41d8e --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/RegressionInferenceResults.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.action; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class RegressionInferenceResults extends SingleValueInferenceResults { + + public static final String RESULT_TYPE = "regression"; + + public RegressionInferenceResults(double value) { + super(value); + } + + public RegressionInferenceResults(StreamInput in) throws IOException { + super(in.readDouble()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } + + @Override + XContentBuilder innerToXContent(XContentBuilder builder, Params params) { + return builder; + } + + @Override + public String resultType() { + return RESULT_TYPE; + } + + @Override + public boolean equals(Object object) { + if (object == this) { return true; } + if (object == null || getClass() != object.getClass()) { return false; } + RegressionInferenceResults that = (RegressionInferenceResults) object; + return Objects.equals(value(), that.value()); + } + + @Override + public int hashCode() { + return Objects.hash(value()); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/SingleValueInferenceResults.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/SingleValueInferenceResults.java new file mode 100644 index 0000000000000..d3dae1563d7be --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/SingleValueInferenceResults.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.action; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; + +public abstract class SingleValueInferenceResults implements InferenceResults { + + public final ParseField VALUE = new ParseField("value"); + + private final double value; + + SingleValueInferenceResults(StreamInput in) throws IOException { + value = in.readDouble(); + } + + SingleValueInferenceResults(double value) { + this.value = value; + } + + @Override + public Double value() { + return value; + } + + @Override + public String valueAsString() { + return String.valueOf(value); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(value); + } + + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(VALUE.getPreferredName(), value); + innerToXContent(builder, params); + builder.endObject(); + return builder; + } + + abstract XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws IOException; +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index 9701a31dba17c..accefdc94c358 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -11,7 +11,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.ml.inference.action.ClassificationInferenceResults; import org.elasticsearch.xpack.ml.inference.action.InferenceResults; +import org.elasticsearch.xpack.ml.inference.action.RegressionInferenceResults; import java.util.ArrayList; import java.util.Comparator; @@ -23,34 +25,55 @@ public class LocalModel implements Model { private final TrainedModelDefinition trainedModelDefinition; - public LocalModel(TrainedModelDefinition trainedModelDefinition) { + private final String modelId; + public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition) { this.trainedModelDefinition = trainedModelDefinition; + this.modelId = modelId; } @Override - public void infer(Map fields, ActionListener listener) { + public String getResultsType() { + switch (trainedModelDefinition.getTrainedModel().targetType()) { + case CLASSIFICATION: + return ClassificationInferenceResults.RESULT_TYPE; + case REGRESSION: + return RegressionInferenceResults.RESULT_TYPE; + default: + throw ExceptionsHelper.badRequestException("Model [{}] has unsupported target type [{}]", + modelId, + trainedModelDefinition.getTrainedModel().targetType()); + } + } + + @Override + public void infer(Map fields, ActionListener> listener) { trainedModelDefinition.getPreProcessors().forEach(preProcessor -> preProcessor.process(fields)); double value = trainedModelDefinition.getTrainedModel().infer(fields); - if (trainedModelDefinition.getTrainedModel().targetType() == TargetType.CLASSIFICATION && - trainedModelDefinition.getTrainedModel().classificationLabels() != null) { - assert value == Math.rint(value); - int classIndex = Double.valueOf(value).intValue(); - if (classIndex < 0 || classIndex >= trainedModelDefinition.getTrainedModel().classificationLabels().size()) { - listener.onFailure(new ElasticsearchStatusException("model returned classification [{}] which is invalid given labels {}", - RestStatus.INTERNAL_SERVER_ERROR, - classIndex, - trainedModelDefinition.getTrainedModel().classificationLabels())); - return; + InferenceResults inferenceResults; + if (trainedModelDefinition.getTrainedModel().targetType() == TargetType.CLASSIFICATION) { + String classificationLabel = null; + if (trainedModelDefinition.getTrainedModel().classificationLabels() != null) { + assert value == Math.rint(value); + int classIndex = Double.valueOf(value).intValue(); + if (classIndex < 0 || classIndex >= trainedModelDefinition.getTrainedModel().classificationLabels().size()) { + listener.onFailure(new ElasticsearchStatusException( + "model returned classification [{}] which is invalid given labels {}", + RestStatus.INTERNAL_SERVER_ERROR, + classIndex, + trainedModelDefinition.getTrainedModel().classificationLabels())); + return; + } + classificationLabel = trainedModelDefinition.getTrainedModel().classificationLabels().get(classIndex); } - listener.onResponse(InferenceResults.valueAndLabel(value, - trainedModelDefinition.getTrainedModel().classificationLabels().get(classIndex))); - return; + inferenceResults = new ClassificationInferenceResults(value, classificationLabel, null); + } else { + inferenceResults = new RegressionInferenceResults(value); } - listener.onResponse(InferenceResults.valueOnly(value)); + listener.onResponse(inferenceResults); } @Override - public void classificationProbability(Map fields, int topN, ActionListener listener) { + public void classificationProbability(Map fields, int topN, ActionListener> listener) { if (topN == 0) { infer(fields, listener); return; @@ -83,13 +106,13 @@ public void classificationProbability(Map fields, int topN, Acti trainedModelDefinition.getTrainedModel().classificationLabels(); int count = topN < 0 ? probabilities.size() : topN; - List topClassEntries = new ArrayList<>(count); + List topClassEntries = new ArrayList<>(count); for(int i = 0; i < count; i++) { int idx = sortedIndices[i]; - topClassEntries.add(new InferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx))); + topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx))); } - listener.onResponse(new InferenceResults(((Number)sortedIndices[0]).doubleValue(), + listener.onResponse(new ClassificationInferenceResults(((Number)sortedIndices[0]).doubleValue(), trainedModelDefinition.getTrainedModel().classificationLabels() == null ? null : trainedModelDefinition.getTrainedModel().classificationLabels().get(sortedIndices[0]), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java index 6e8e8f41fb14e..5cabd3773f28f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java @@ -12,7 +12,9 @@ public interface Model { - void infer(Map fields, ActionListener listener); + String getResultsType(); - void classificationProbability(Map fields, int topN, ActionListener listener); + void infer(Map fields, ActionListener> listener); + + void classificationProbability(Map fields, int topN, ActionListener> listener); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 687fcce28f4e1..989d87a7e79f2 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -45,7 +45,6 @@ public ModelLoadingService(TrainedModelProvider trainedModelProvider, ThreadPool this.provider = trainedModelProvider; this.threadPool = threadPool; clusterService.addListener(this); - // TODO should we load state here? Or will it get the applied state because it is a listener? } public void getModel(String modelId, long modelVersion, ActionListener modelActionListener) { @@ -62,7 +61,8 @@ public void getModel(String modelId, long modelVersion, ActionListener mo // by a simulated pipeline logger.debug("[{}] version [{}] not actively loading, eager loading without cache", modelId, modelVersion); provider.getTrainedModel(modelId, modelVersion, ActionListener.wrap( - trainedModelConfig -> modelActionListener.onResponse(new LocalModel(trainedModelConfig.getDefinition())), + trainedModelConfig -> + modelActionListener.onResponse(new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition())), modelActionListener::onFailure )); } else { @@ -115,7 +115,7 @@ private void loadModel(String modelKey, String modelId, long modelVersion) { private void handleLoadSuccess(String modelKey, TrainedModelConfig trainedModelConfig) { Queue> listeners; - Model loadedModel = new LocalModel(trainedModelConfig.getDefinition()); + Model loadedModel = new LocalModel(trainedModelConfig.getModelId(), trainedModelConfig.getDefinition()); synchronized (loadingListeners) { listeners = loadingListeners.remove(modelKey); // If there is no loadingListener that means the loading was canceled and the listener was already notified as such @@ -233,9 +233,12 @@ private static Set getReferencedModelKeys(IngestMetadata ingestMetadata) if (processor instanceof Map) { Object processorConfig = ((Map)processor).get(InferenceProcessor.TYPE); if (processorConfig instanceof Map) { - String modelId = ((Map)processorConfig).get(InferenceProcessor.MODEL_ID).toString(); - // TODO also read model version - allReferencedModelKeys.add(modelKey(modelId, 0)); + Object modelId = ((Map)processorConfig).get(InferenceProcessor.MODEL_ID); + if (modelId != null) { + assert modelId instanceof String; + // TODO also read model version + allReferencedModelKeys.add(modelKey(modelId.toString(), 0)); + } } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/ClassificationInferenceResultsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/ClassificationInferenceResultsTests.java new file mode 100644 index 0000000000000..2706ca24f4049 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/ClassificationInferenceResultsTests.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class ClassificationInferenceResultsTests extends AbstractWireSerializingTestCase { + + public static ClassificationInferenceResults createRandomResults() { + return new ClassificationInferenceResults(randomDouble(), + randomBoolean() ? null : randomAlphaOfLength(10), + randomBoolean() ? null : + Stream.generate(ClassificationInferenceResultsTests::createRandomClassEntry) + .limit(randomIntBetween(0, 10)) + .collect(Collectors.toList())); + } + + private static ClassificationInferenceResults.TopClassEntry createRandomClassEntry() { + return new ClassificationInferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble()); + } + + @Override + protected ClassificationInferenceResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return ClassificationInferenceResults::new; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionResponseTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionResponseTests.java index dbf9ce0708c41..43542097fe9fc 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionResponseTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionResponseTests.java @@ -16,9 +16,23 @@ public class InferModelActionResponseTests extends AbstractWireSerializingTestCa @Override protected Response createTestInstance() { - return new Response(Stream.generate(InferenceResultsTests::createRandomResults) + String resultType = randomFrom(ClassificationInferenceResults.RESULT_TYPE, RegressionInferenceResults.RESULT_TYPE); + return new Response( + Stream.generate(() -> randomInferenceResult(resultType)) .limit(randomIntBetween(0, 10)) - .collect(Collectors.toList())); + .collect(Collectors.toList()), + resultType); + } + + private static InferenceResults randomInferenceResult(String resultType) { + if (resultType.equals(ClassificationInferenceResults.RESULT_TYPE)) { + return ClassificationInferenceResultsTests.createRandomResults(); + } else if (resultType.equals(RegressionInferenceResults.RESULT_TYPE)) { + return RegressionInferenceResultsTests.createRandomResults(); + } else { + fail("unexpected result type [" + resultType + "]"); + return null; + } } @Override diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferenceResultsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferenceResultsTests.java deleted file mode 100644 index becf96ee6bc55..0000000000000 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferenceResultsTests.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.xpack.ml.inference.action; - -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; - -import java.util.stream.Collectors; -import java.util.stream.Stream; - -public class InferenceResultsTests extends AbstractWireSerializingTestCase { - - public static InferenceResults createRandomResults() { - return new InferenceResults(randomDouble(), - randomBoolean() ? null : randomAlphaOfLength(10), - randomBoolean() ? null : - Stream.generate(InferenceResultsTests::createRandomClassEntry).limit(randomIntBetween(0, 10)).collect(Collectors.toList())); - } - - private static InferenceResults.TopClassEntry createRandomClassEntry() { - return new InferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble()); - } - - @Override - protected InferenceResults createTestInstance() { - return createRandomResults(); - } - - @Override - protected Writeable.Reader instanceReader() { - return InferenceResults::new; - } -} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/RegressionInferenceResultsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/RegressionInferenceResultsTests.java new file mode 100644 index 0000000000000..2654f86d87cc0 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/RegressionInferenceResultsTests.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.inference.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + + +public class RegressionInferenceResultsTests extends AbstractWireSerializingTestCase { + + public static RegressionInferenceResults createRandomResults() { + return new RegressionInferenceResults(randomDouble()); + } + + @Override + protected RegressionInferenceResults createTestInstance() { + return createRandomResults(); + } + + @Override + protected Writeable.Reader instanceReader() { + return RegressionInferenceResults::new; + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index 4da00f9a7ba5d..6d0dd5e10d357 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; +import org.elasticsearch.xpack.ml.inference.action.ClassificationInferenceResults; import org.elasticsearch.xpack.ml.inference.action.InferenceResults; import java.util.Arrays; @@ -28,77 +29,77 @@ import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.nullValue; public class LocalModelTests extends ESTestCase { public void testClassificationInfer() throws Exception { + String modelId = "classification_model"; TrainedModelDefinition definition = new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildClassification(false)) .build(); - Model model = new LocalModel(definition); + Model model = new LocalModel(modelId, definition); Map fields = new HashMap<>() {{ put("foo", 1.0); put("bar", 0.5); put("categorical", "dog"); }}; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); model.infer(fields, future); - InferenceResults result = future.get(); - assertThat(result.getNumericValue(), equalTo(0.0)); - assertThat(result.getClassificationLabel(), is(nullValue())); - assertThat(result.getTopClasses(), is(nullValue())); + InferenceResults result = future.get(); + assertThat(result.value(), equalTo(0.0)); + assertThat(result.valueAsString(), is("0.0")); future = new PlainActionFuture<>(); model.classificationProbability(fields, 0, future); result = future.get(); - assertThat(result.getNumericValue(), equalTo(0.0)); - assertThat(result.getClassificationLabel(), is(nullValue())); - assertThat(result.getTopClasses(), is(nullValue())); + assertThat(result.value(), equalTo(0.0)); + assertThat(result.valueAsString(), is("0.0")); future = new PlainActionFuture<>(); model.classificationProbability(fields, 1, future); - result = future.get(); - assertThat(result.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); - assertThat(result.getTopClasses().get(0).getLabel(), equalTo("0")); + ClassificationInferenceResults classificationResult = (ClassificationInferenceResults)future.get(); + assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); + assertThat(classificationResult.getTopClasses().get(0).getLabel(), equalTo("0")); // Test with labels definition = new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildClassification(true)) .build(); - model = new LocalModel(definition); + model = new LocalModel(modelId, definition); future = new PlainActionFuture<>(); model.infer(fields, future); result = future.get(); - assertThat(result.getNumericValue(), equalTo(0.0)); - assertThat(result.getClassificationLabel(), equalTo("not_to_be")); + assertThat(result.value(), equalTo(0.0)); + assertThat(result.valueAsString(), equalTo("not_to_be")); future = new PlainActionFuture<>(); model.classificationProbability(fields, 0, future); result = future.get(); - assertThat(result.getNumericValue(), equalTo(0.0)); - assertThat(result.getClassificationLabel(), equalTo("not_to_be")); - assertThat(result.getTopClasses(), is(nullValue())); + assertThat(result.value(), equalTo(0.0)); + assertThat(result.valueAsString(), equalTo("not_to_be")); future = new PlainActionFuture<>(); model.classificationProbability(fields, 1, future); result = future.get(); - assertThat(result.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); - assertThat(result.getTopClasses().get(0).getLabel(), equalTo("not_to_be")); + classificationResult = (ClassificationInferenceResults)result; + assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); + assertThat(classificationResult.getTopClasses().get(0).getLabel(), equalTo("not_to_be")); future = new PlainActionFuture<>(); model.classificationProbability(fields, 2, future); result = future.get(); - assertThat(result.getTopClasses(), hasSize(2)); + classificationResult = (ClassificationInferenceResults)result; + assertThat(classificationResult.getTopClasses(), hasSize(2)); future = new PlainActionFuture<>(); model.classificationProbability(fields, -1, future); result = future.get(); - assertThat(result.getTopClasses(), hasSize(2)); + classificationResult = (ClassificationInferenceResults)result; + assertThat(classificationResult.getTopClasses(), hasSize(2)); } public void testRegression() throws Exception { @@ -106,7 +107,7 @@ public void testRegression() throws Exception { .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildRegression()) .build(); - Model model = new LocalModel(trainedModelDefinition); + Model model = new LocalModel("regression_model", trainedModelDefinition); Map fields = new HashMap<>() {{ put("foo", 1.0); @@ -114,12 +115,12 @@ public void testRegression() throws Exception { put("categorical", "dog"); }}; - PlainActionFuture future = new PlainActionFuture<>(); + PlainActionFuture> future = new PlainActionFuture<>(); model.infer(fields, future); - InferenceResults results = future.get(); - assertThat(results.getNumericValue(), equalTo(1.3)); + InferenceResults results = future.get(); + assertThat(results.value(), equalTo(1.3)); - PlainActionFuture failedFuture = new PlainActionFuture<>(); + PlainActionFuture> failedFuture = new PlainActionFuture<>(); model.classificationProbability(fields, -1, failedFuture); ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get); assertThat(ex.getCause().getMessage(), equalTo("top result probabilities is only available for classification models")); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index d167f6360b4c2..4aa1025ffe099 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -24,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; +import org.elasticsearch.xpack.ml.inference.action.ClassificationInferenceResults; import org.elasticsearch.xpack.ml.inference.action.InferModelAction; import org.elasticsearch.xpack.ml.inference.action.InferenceResults; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -109,40 +110,48 @@ public void testInferModels() throws Exception { // Test regression InferModelAction.Request request = new InferModelAction.Request(modelId1, 0, toInfer, null); InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getInferenceResponse().stream().map(InferenceResults::getNumericValue).collect(Collectors.toList()), + assertThat(response.getInferenceResponse().stream().map(InferenceResults::value).collect(Collectors.toList()), contains(1.3, 1.25)); request = new InferModelAction.Request(modelId1, 0, toInfer2, null); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getInferenceResponse().stream().map(InferenceResults::getNumericValue).collect(Collectors.toList()), + assertThat(response.getInferenceResponse().stream().map(InferenceResults::value).collect(Collectors.toList()), contains(1.65, 1.55)); // Test classification request = new InferModelAction.Request(modelId2, 0, toInfer, null); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getInferenceResponse().stream().map(InferenceResults::getClassificationLabel).collect(Collectors.toList()), + assertThat(response.getInferenceResponse().stream().map(InferenceResults::valueAsString).collect(Collectors.toList()), contains("not_to_be", "to_be")); // Get top classes request = new InferModelAction.Request(modelId2, 0, toInfer, 2); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); + assertThat(response.getResultsType(), equalTo(ClassificationInferenceResults.RESULT_TYPE)); - assertThat(response.getInferenceResponse().get(0).getTopClasses().get(0).getLabel(), equalTo("not_to_be")); - assertThat(response.getInferenceResponse().get(0).getTopClasses().get(1).getLabel(), equalTo("to_be")); - assertThat(response.getInferenceResponse().get(0).getTopClasses().get(0).getProbability(), - greaterThan(response.getInferenceResponse().get(0).getTopClasses().get(1).getProbability())); + ClassificationInferenceResults classificationInferenceResults = + (ClassificationInferenceResults)response.getInferenceResponse().get(0); - assertThat(response.getInferenceResponse().get(1).getTopClasses().get(0).getLabel(), equalTo("to_be")); - assertThat(response.getInferenceResponse().get(1).getTopClasses().get(1).getLabel(), equalTo("not_to_be")); - assertThat(response.getInferenceResponse().get(1).getTopClasses().get(0).getProbability(), - greaterThan(response.getInferenceResponse().get(1).getTopClasses().get(1).getProbability())); + assertThat(classificationInferenceResults.getTopClasses().get(0).getLabel(), equalTo("not_to_be")); + assertThat(classificationInferenceResults.getTopClasses().get(1).getLabel(), equalTo("to_be")); + assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(), + greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); + + classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResponse().get(1); + assertThat(classificationInferenceResults.getTopClasses().get(0).getLabel(), equalTo("to_be")); + assertThat(classificationInferenceResults.getTopClasses().get(1).getLabel(), equalTo("not_to_be")); + assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(), + greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); // Test that top classes restrict the number returned request = new InferModelAction.Request(modelId2, 0, toInfer2, 1); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getInferenceResponse().get(0).getTopClasses(), hasSize(1)); - assertThat(response.getInferenceResponse().get(0).getTopClasses().get(0).getLabel(), equalTo("to_be")); + assertThat(response.getResultsType(), equalTo(ClassificationInferenceResults.RESULT_TYPE)); + + classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResponse().get(0); + assertThat(classificationInferenceResults.getTopClasses(), hasSize(1)); + assertThat(classificationInferenceResults.getTopClasses().get(0).getLabel(), equalTo("to_be")); } public void testInferMissingModel() { From 8140446f6c8ceb13e5355407db108466bb8a99f4 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 3 Oct 2019 16:19:59 -0400 Subject: [PATCH 05/13] moving things to core --- .../org/elasticsearch/xpack/core/XPackClientPlugin.java | 2 ++ .../xpack/core/ml}/action/InferModelAction.java | 5 ++++- .../results}/ClassificationInferenceResults.java | 2 +- .../core/ml/inference/results}/InferenceResults.java | 2 +- .../inference/results}/RegressionInferenceResults.java | 2 +- .../inference/results}/SingleValueInferenceResults.java | 2 +- .../core/ml}/action/InferModelActionRequestTests.java | 4 ++-- .../core/ml}/action/InferModelActionResponseTests.java | 9 +++++++-- .../results}/ClassificationInferenceResultsTests.java | 2 +- .../results}/RegressionInferenceResultsTests.java | 3 ++- .../java/org/elasticsearch/xpack/ml/MachineLearning.java | 2 +- .../xpack/ml/action/TransportInferModelAction.java | 4 ++-- .../xpack/ml/inference/loadingservice/LocalModel.java | 6 +++--- .../xpack/ml/inference/loadingservice/Model.java | 2 +- .../ml/inference/loadingservice/LocalModelTests.java | 4 ++-- .../xpack/ml/integration/ModelInferenceActionIT.java | 6 +++--- 16 files changed, 34 insertions(+), 23 deletions(-) rename x-pack/plugin/{ml/src/main/java/org/elasticsearch/xpack/ml/inference => core/src/main/java/org/elasticsearch/xpack/core/ml}/action/InferModelAction.java (95%) rename x-pack/plugin/{ml/src/main/java/org/elasticsearch/xpack/ml/inference/action => core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results}/ClassificationInferenceResults.java (98%) rename x-pack/plugin/{ml/src/main/java/org/elasticsearch/xpack/ml/inference/action => core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results}/InferenceResults.java (89%) rename x-pack/plugin/{ml/src/main/java/org/elasticsearch/xpack/ml/inference/action => core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results}/RegressionInferenceResults.java (96%) rename x-pack/plugin/{ml/src/main/java/org/elasticsearch/xpack/ml/inference/action => core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results}/SingleValueInferenceResults.java (96%) rename x-pack/plugin/{ml/src/test/java/org/elasticsearch/xpack/ml/inference => core/src/test/java/org/elasticsearch/xpack/core/ml}/action/InferModelActionRequestTests.java (89%) rename x-pack/plugin/{ml/src/test/java/org/elasticsearch/xpack/ml/inference => core/src/test/java/org/elasticsearch/xpack/core/ml}/action/InferModelActionResponseTests.java (73%) rename x-pack/plugin/{ml/src/test/java/org/elasticsearch/xpack/ml/inference/action => core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results}/ClassificationInferenceResultsTests.java (96%) rename x-pack/plugin/{ml/src/test/java/org/elasticsearch/xpack/ml/inference/action => core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results}/RegressionInferenceResultsTests.java (86%) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 19451e5833e94..1d5b7716135eb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -98,6 +98,7 @@ import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction; import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetRecordsAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction; import org.elasticsearch.xpack.core.ml.action.KillProcessAction; import org.elasticsearch.xpack.core.ml.action.MlInfoAction; @@ -322,6 +323,7 @@ public List> getClientActions() { StartDataFrameAnalyticsAction.INSTANCE, EvaluateDataFrameAction.INSTANCE, EstimateMemoryUsageAction.INSTANCE, + InferModelAction.INSTANCE, // security ClearRealmCacheAction.INSTANCE, ClearRolesCacheAction.INSTANCE, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java similarity index 95% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java index 756800798e6ba..7b5d55285f92d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.inference.action; +package org.elasticsearch.xpack.core.ml.action; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestBuilder; @@ -13,6 +13,9 @@ import org.elasticsearch.client.ElasticsearchClient; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/ClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java similarity index 98% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/ClassificationInferenceResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java index 591151186a3c8..a1e77c6bc122d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/ClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.inference.action; +package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java similarity index 89% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferenceResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java index 98fb317f8c158..2c349945e1ced 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/InferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.inference.action; +package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ToXContentObject; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/RegressionInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java similarity index 96% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/RegressionInferenceResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java index 38841f8e41d8e..e4bb0b0136a25 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/RegressionInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.inference.action; +package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/SingleValueInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java similarity index 96% rename from x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/SingleValueInferenceResults.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java index d3dae1563d7be..f3d006a4caa6a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/action/SingleValueInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.inference.action; +package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java similarity index 89% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionRequestTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java index e29cc63e30bce..25142bc7acab9 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java @@ -3,11 +3,11 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.inference.action; +package org.elasticsearch.xpack.core.ml.action; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.ml.inference.action.InferModelAction.Request; +import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request; import java.util.function.Function; import java.util.stream.Collectors; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java similarity index 73% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionResponseTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java index 43542097fe9fc..4261b6b05aff5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/InferModelActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java @@ -3,11 +3,16 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.inference.action; +package org.elasticsearch.xpack.core.ml.action; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.ml.inference.action.InferModelAction.Response; +import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests; import java.util.stream.Collectors; import java.util.stream.Stream; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/ClassificationInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java similarity index 96% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/ClassificationInferenceResultsTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java index 2706ca24f4049..c446c4d8225bc 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/ClassificationInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java @@ -3,7 +3,7 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.inference.action; +package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/RegressionInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java similarity index 86% rename from x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/RegressionInferenceResultsTests.java rename to x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java index 2654f86d87cc0..479aca9d08c1b 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/action/RegressionInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java @@ -3,10 +3,11 @@ * or more contributor license agreements. Licensed under the Elastic License; * you may not use this file except in compliance with the Elastic License. */ -package org.elasticsearch.xpack.ml.inference.action; +package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; public class RegressionInferenceResultsTests extends AbstractWireSerializingTestCase { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index f2b617a1d0a19..fa51b6ee79e45 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -95,6 +95,7 @@ import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction; import org.elasticsearch.xpack.core.ml.action.GetRecordsAction; import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; import org.elasticsearch.xpack.core.ml.action.KillProcessAction; import org.elasticsearch.xpack.core.ml.action.MlInfoAction; import org.elasticsearch.xpack.core.ml.action.OpenJobAction; @@ -201,7 +202,6 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; -import org.elasticsearch.xpack.ml.inference.action.InferModelAction; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.inference.persistence.InferenceInternalIndex; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java index b5e81c52aa316..cedd4bbfc6219 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -13,8 +13,8 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.ml.inference.action.InferModelAction; -import org.elasticsearch.xpack.ml.inference.action.InferenceResults; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.ml.inference.loadingservice.Model; import org.elasticsearch.xpack.ml.inference.loadingservice.ModelLoadingService; import org.elasticsearch.xpack.ml.utils.TypedChainTaskExecutor; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index accefdc94c358..a19283d49c9fc 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -11,9 +11,9 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; -import org.elasticsearch.xpack.ml.inference.action.ClassificationInferenceResults; -import org.elasticsearch.xpack.ml.inference.action.InferenceResults; -import org.elasticsearch.xpack.ml.inference.action.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import java.util.ArrayList; import java.util.Comparator; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java index 5cabd3773f28f..398382d7d19e7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java @@ -6,7 +6,7 @@ package org.elasticsearch.xpack.ml.inference.loadingservice; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.xpack.ml.inference.action.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import java.util.Map; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index 6d0dd5e10d357..4cfc782a4658b 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -16,8 +16,8 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; -import org.elasticsearch.xpack.ml.inference.action.ClassificationInferenceResults; -import org.elasticsearch.xpack.ml.inference.action.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import java.util.Arrays; import java.util.HashMap; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 4aa1025ffe099..34ac5732dc362 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -24,9 +24,9 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; -import org.elasticsearch.xpack.ml.inference.action.ClassificationInferenceResults; -import org.elasticsearch.xpack.ml.inference.action.InferModelAction; -import org.elasticsearch.xpack.ml.inference.action.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.action.InferModelAction; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.junit.Before; From f8877815ca75411a1a9ce9784882b97bf9dd29b4 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Fri, 4 Oct 2019 11:06:17 -0400 Subject: [PATCH 06/13] fixing transport handler --- .../xpack/ml/action/TransportInferModelAction.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java index cedd4bbfc6219..597caa41fc557 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -26,12 +26,11 @@ public class TransportInferModelAction extends HandledTransportAction Date: Sun, 6 Oct 2019 15:35:04 -0400 Subject: [PATCH 07/13] addressing PR comments --- .../core/ml/action/InferModelAction.java | 56 +++++++------------ .../ClassificationInferenceResults.java | 37 +++++------- .../action/InferModelActionRequestTests.java | 24 ++++++-- .../ml/action/TransportInferModelAction.java | 2 +- .../inference/ingest/InferenceProcessor.java | 1 + .../loadingservice/LocalModelTests.java | 4 +- .../integration/ModelInferenceActionIT.java | 24 ++++---- 7 files changed, 68 insertions(+), 80 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java index 7b5d55285f92d..940e892c48c42 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -6,13 +6,12 @@ package org.elasticsearch.xpack.core.ml.action; import org.elasticsearch.action.ActionRequest; -import org.elasticsearch.action.ActionRequestBuilder; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; -import org.elasticsearch.client.ElasticsearchClient; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; @@ -39,27 +38,23 @@ public static class Request extends ActionRequest { private final String modelId; private final long modelVersion; private final List> objectsToInfer; - private final boolean cacheModel; - private final Integer topClasses; + private final int topClasses; public Request(String modelId, long modelVersion) { this(modelId, modelVersion, Collections.emptyList(), null); } public Request(String modelId, long modelVersion, List> objectsToInfer, Integer topClasses) { - this.modelId = modelId; + this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); this.modelVersion = modelVersion; - this.objectsToInfer = objectsToInfer == null ? - Collections.emptyList() : - Collections.unmodifiableList(objectsToInfer); - this.cacheModel = true; - this.topClasses = topClasses; + this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer")); + this.topClasses = topClasses == null ? 0 : topClasses; } public Request(String modelId, long modelVersion, Map objectToInfer, Integer topClasses) { this(modelId, modelVersion, - objectToInfer == null ? null : Arrays.asList(objectToInfer), + Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")), topClasses); } @@ -68,8 +63,7 @@ public Request(StreamInput in) throws IOException { this.modelId = in.readString(); this.modelVersion = in.readVLong(); this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap)); - this.topClasses = in.readOptionalInt(); - this.cacheModel = in.readBoolean(); + this.topClasses = in.readInt(); } public String getModelId() { @@ -84,11 +78,7 @@ public List> getObjectsToInfer() { return objectsToInfer; } - public boolean isCacheModel() { - return cacheModel; - } - - public Integer getTopClasses() { + public int getTopClasses() { return topClasses; } @@ -103,8 +93,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); out.writeVLong(modelVersion); out.writeCollection(objectsToInfer, StreamOutput::writeMap); - out.writeOptionalInt(topClasses); - out.writeBoolean(cacheModel); + out.writeInt(topClasses); } @Override @@ -115,32 +104,25 @@ public boolean equals(Object o) { return Objects.equals(modelId, that.modelId) && Objects.equals(modelVersion, that.modelVersion) && Objects.equals(topClasses, that.topClasses) - && Objects.equals(cacheModel, that.cacheModel) && Objects.equals(objectsToInfer, that.objectsToInfer); } @Override public int hashCode() { - return Objects.hash(modelId, modelVersion, objectsToInfer, topClasses, cacheModel); + return Objects.hash(modelId, modelVersion, objectsToInfer, topClasses); } } - public static class RequestBuilder extends ActionRequestBuilder { - public RequestBuilder(ElasticsearchClient client, Request request) { - super(client, INSTANCE, request); - } - } - public static class Response extends ActionResponse { - private final List> inferenceResponse; + private final List> inferenceResults; private final String resultsType; public Response(List> inferenceResponse, String resultsType) { super(); this.resultsType = ExceptionsHelper.requireNonNull(resultsType, "resultsType"); - this.inferenceResponse = inferenceResponse == null ? + this.inferenceResults = inferenceResponse == null ? Collections.emptyList() : Collections.unmodifiableList(inferenceResponse); } @@ -149,16 +131,16 @@ public Response(StreamInput in) throws IOException { super(in); this.resultsType = in.readString(); if(resultsType.equals(ClassificationInferenceResults.RESULT_TYPE)) { - this.inferenceResponse = Collections.unmodifiableList(in.readList(ClassificationInferenceResults::new)); + this.inferenceResults = Collections.unmodifiableList(in.readList(ClassificationInferenceResults::new)); } else if (this.resultsType.equals(RegressionInferenceResults.RESULT_TYPE)) { - this.inferenceResponse = Collections.unmodifiableList(in.readList(RegressionInferenceResults::new)); + this.inferenceResults = Collections.unmodifiableList(in.readList(RegressionInferenceResults::new)); } else { throw new IOException("Unrecognized result type [" + resultsType + "]"); } } - public List> getInferenceResponse() { - return inferenceResponse; + public List> getInferenceResults() { + return inferenceResults; } public String getResultsType() { @@ -168,7 +150,7 @@ public String getResultsType() { @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(resultsType); - out.writeCollection(inferenceResponse); + out.writeCollection(inferenceResults); } @Override @@ -176,12 +158,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; InferModelAction.Response that = (InferModelAction.Response) o; - return Objects.equals(resultsType, that.resultsType) && Objects.equals(inferenceResponse, that.inferenceResponse); + return Objects.equals(resultsType, that.resultsType) && Objects.equals(inferenceResults, that.inferenceResults); } @Override public int hashCode() { - return Objects.hash(resultsType, inferenceResponse); + return Objects.hash(resultsType, inferenceResults); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java index a1e77c6bc122d..497cdab06bff8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java @@ -30,17 +30,13 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults public ClassificationInferenceResults(double value, String classificationLabel, List topClasses) { super(value); this.classificationLabel = classificationLabel; - this.topClasses = topClasses == null ? null : Collections.unmodifiableList(topClasses); + this.topClasses = topClasses == null ? Collections.emptyList() : Collections.unmodifiableList(topClasses); } public ClassificationInferenceResults(StreamInput in) throws IOException { super(in); this.classificationLabel = in.readOptionalString(); - if (in.readBoolean()) { - this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new)); - } else { - this.topClasses = null; - } + this.topClasses = Collections.unmodifiableList(in.readList(TopClassEntry::new)); } public String getClassificationLabel() { @@ -55,10 +51,7 @@ public List getTopClasses() { public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeOptionalString(classificationLabel); - out.writeBoolean(topClasses != null); - if (topClasses != null) { - out.writeCollection(topClasses); - } + out.writeCollection(topClasses); } @Override @@ -66,7 +59,7 @@ XContentBuilder innerToXContent(XContentBuilder builder, Params params) throws I if (classificationLabel != null) { builder.field(CLASSIFICATION_LABEL.getPreferredName(), classificationLabel); } - if (topClasses != null) { + if (topClasses.isEmpty() == false) { builder.field(TOP_CLASSES.getPreferredName(), topClasses); } return builder; @@ -99,24 +92,24 @@ public String valueAsString() { public static class TopClassEntry implements ToXContentObject, Writeable { - public final ParseField LABEL = new ParseField("label"); + public final ParseField CLASSIFICATION = new ParseField("classification"); public final ParseField PROBABILITY = new ParseField("probability"); - private final String label; + private final String classification; private final double probability; - public TopClassEntry(String label, Double probability) { - this.label = ExceptionsHelper.requireNonNull(label, LABEL); + public TopClassEntry(String classification, Double probability) { + this.classification = ExceptionsHelper.requireNonNull(classification, CLASSIFICATION); this.probability = ExceptionsHelper.requireNonNull(probability, PROBABILITY); } public TopClassEntry(StreamInput in) throws IOException { - this.label = in.readString(); + this.classification = in.readString(); this.probability = in.readDouble(); } - public String getLabel() { - return label; + public String getClassification() { + return classification; } public double getProbability() { @@ -125,14 +118,14 @@ public double getProbability() { @Override public void writeTo(StreamOutput out) throws IOException { - out.writeString(label); + out.writeString(classification); out.writeDouble(probability); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(LABEL.getPreferredName(), label); + builder.field(CLASSIFICATION.getPreferredName(), classification); builder.field(PROBABILITY.getPreferredName(), probability); builder.endObject(); return builder; @@ -143,13 +136,13 @@ public boolean equals(Object object) { if (object == this) { return true; } if (object == null || getClass() != object.getClass()) { return false; } TopClassEntry that = (TopClassEntry) object; - return Objects.equals(label, that.label) && + return Objects.equals(classification, that.classification) && Objects.equals(probability, that.probability); } @Override public int hashCode() { - return Objects.hash(label, probability); + return Objects.hash(classification, probability); } } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java index 25142bc7acab9..d816939d57a59 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.action.InferModelAction.Request; +import java.util.Map; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -17,12 +18,23 @@ public class InferModelActionRequestTests extends AbstractWireSerializingTestCas @Override protected Request createTestInstance() { - return new Request(randomAlphaOfLength(10), - randomLongBetween(1, 100), - randomBoolean() ? null : Stream.generate(()-> randomAlphaOfLength(10)) - .limit(randomInt(10)) - .collect(Collectors.toMap(Function.identity(), (v) -> randomAlphaOfLength(10))), - randomBoolean() ? null : randomIntBetween(-1, 100)); + return randomBoolean() ? + new Request( + randomAlphaOfLength(10), + randomLongBetween(1, 100), + Stream.generate(InferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()), + randomBoolean() ? null : randomIntBetween(-1, 100)) : + new Request( + randomAlphaOfLength(10), + randomLongBetween(1, 100), + randomMap(), + randomBoolean() ? null : randomIntBetween(-1, 100)); + } + + private static Map randomMap() { + return Stream.generate(()-> randomAlphaOfLength(10)) + .limit(randomInt(10)) + .collect(Collectors.toMap(Function.identity(), (v) -> randomAlphaOfLength(10))); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java index 597caa41fc557..ed487ed761b34 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -46,7 +46,7 @@ protected void doExecute(Task task, InferModelAction.Request request, ActionList r -> true, // Always fail immediately and return an error ex -> true); - if (request.getTopClasses() != null) { + if (request.getTopClasses() != 0) { request.getObjectsToInfer().forEach(stringObjectMap -> typedChainTaskExecutor.add(chainedTask -> model.classificationProbability(stringObjectMap, request.getTopClasses(), chainedTask)) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 5ea4c77a4c0ce..59f6c62a7f55e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -17,6 +17,7 @@ public class InferenceProcessor extends AbstractProcessor { public static final String MODEL_ID = "model_id"; private final Client client; + public InferenceProcessor(Client client, String tag) { super(tag); this.client = client; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index 4cfc782a4658b..2ca8d03103a7c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -62,7 +62,7 @@ public void testClassificationInfer() throws Exception { model.classificationProbability(fields, 1, future); ClassificationInferenceResults classificationResult = (ClassificationInferenceResults)future.get(); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); - assertThat(classificationResult.getTopClasses().get(0).getLabel(), equalTo("0")); + assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0")); // Test with labels definition = new TrainedModelDefinition.Builder() @@ -87,7 +87,7 @@ public void testClassificationInfer() throws Exception { result = future.get(); classificationResult = (ClassificationInferenceResults)result; assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); - assertThat(classificationResult.getTopClasses().get(0).getLabel(), equalTo("not_to_be")); + assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); future = new PlainActionFuture<>(); model.classificationProbability(fields, 2, future); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 34ac5732dc362..60606536686a0 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -110,19 +110,19 @@ public void testInferModels() throws Exception { // Test regression InferModelAction.Request request = new InferModelAction.Request(modelId1, 0, toInfer, null); InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getInferenceResponse().stream().map(InferenceResults::value).collect(Collectors.toList()), + assertThat(response.getInferenceResults().stream().map(InferenceResults::value).collect(Collectors.toList()), contains(1.3, 1.25)); request = new InferModelAction.Request(modelId1, 0, toInfer2, null); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getInferenceResponse().stream().map(InferenceResults::value).collect(Collectors.toList()), + assertThat(response.getInferenceResults().stream().map(InferenceResults::value).collect(Collectors.toList()), contains(1.65, 1.55)); // Test classification request = new InferModelAction.Request(modelId2, 0, toInfer, null); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getInferenceResponse().stream().map(InferenceResults::valueAsString).collect(Collectors.toList()), + assertThat(response.getInferenceResults().stream().map(InferenceResults::valueAsString).collect(Collectors.toList()), contains("not_to_be", "to_be")); // Get top classes @@ -131,16 +131,17 @@ public void testInferModels() throws Exception { assertThat(response.getResultsType(), equalTo(ClassificationInferenceResults.RESULT_TYPE)); ClassificationInferenceResults classificationInferenceResults = - (ClassificationInferenceResults)response.getInferenceResponse().get(0); + (ClassificationInferenceResults)response.getInferenceResults().get(0); - assertThat(classificationInferenceResults.getTopClasses().get(0).getLabel(), equalTo("not_to_be")); - assertThat(classificationInferenceResults.getTopClasses().get(1).getLabel(), equalTo("to_be")); + assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); + assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("to_be")); assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(), greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); - classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResponse().get(1); - assertThat(classificationInferenceResults.getTopClasses().get(0).getLabel(), equalTo("to_be")); - assertThat(classificationInferenceResults.getTopClasses().get(1).getLabel(), equalTo("not_to_be")); + classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(1); + assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be")); + assertThat(classificationInferenceResults.getTopClasses().get(1).getClassification(), equalTo("not_to_be")); + // they should always be in order of Most probable to least assertThat(classificationInferenceResults.getTopClasses().get(0).getProbability(), greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); @@ -149,9 +150,9 @@ public void testInferModels() throws Exception { response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat(response.getResultsType(), equalTo(ClassificationInferenceResults.RESULT_TYPE)); - classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResponse().get(0); + classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0); assertThat(classificationInferenceResults.getTopClasses(), hasSize(1)); - assertThat(classificationInferenceResults.getTopClasses().get(0).getLabel(), equalTo("to_be")); + assertThat(classificationInferenceResults.getTopClasses().get(0).getClassification(), equalTo("to_be")); } public void testInferMissingModel() { @@ -271,7 +272,6 @@ public NamedXContentRegistry xContentRegistry() { namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new SearchModule(Settings.EMPTY, Collections.emptyList()).getNamedXContents()); return new NamedXContentRegistry(namedXContent); - } } From d92886b1b2eb304c0153efa1323a7772a98d7520 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 7 Oct 2019 09:51:26 -0400 Subject: [PATCH 08/13] partial commit --- .../MlInferenceNamedXContentProvider.java | 11 +++ .../ClassificationInferenceResults.java | 32 ++++++++- .../inference/results/InferenceResults.java | 13 ++-- .../results/RegressionInferenceResults.java | 24 +++++-- .../results/SingleValueInferenceResults.java | 4 +- .../trainedmodel/InferenceHelpers.java | 53 +++++++++++++++ .../trainedmodel/InferenceParams.java | 58 ++++++++++++++++ .../inference/trainedmodel/TrainedModel.java | 30 +-------- .../trainedmodel/ensemble/Ensemble.java | 10 +-- .../ml/inference/trainedmodel/tree/Tree.java | 67 ++++++++++++------- 10 files changed, 226 insertions(+), 76 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index 7fff4d6abbd3b..ee4eb0e9c1280 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -8,6 +8,9 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.plugins.spi.NamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; @@ -100,6 +103,14 @@ public List getNamedWriteables() { WeightedMode.NAME.getPreferredName(), WeightedMode::new)); + // Inference Results + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, + ClassificationInferenceResults.RESULT_TYPE, + ClassificationInferenceResults::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, + RegressionInferenceResults.RESULT_TYPE, + RegressionInferenceResults::new)); + return namedWriteables; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java index 497cdab06bff8..a9f8b3fc66627 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java @@ -11,11 +11,14 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; public class ClassificationInferenceResults extends SingleValueInferenceResults { @@ -81,13 +84,29 @@ public int hashCode() { } @Override - public String resultType() { + public String valueAsString() { + return classificationLabel == null ? super.valueAsString() : classificationLabel; + } + + @Override + public void writeResult(IngestDocument document, String resultField) { + ExceptionsHelper.requireNonNull(document, "document"); + ExceptionsHelper.requireNonNull(resultField, "resultField"); + if (topClasses.isEmpty()) { + document.setFieldValue(resultField, valueAsString()); + } else { + document.setFieldValue(resultField, topClasses.stream().map(TopClassEntry::asValueMap)); + } + } + + @Override + public String getWriteableName() { return RESULT_TYPE; } @Override - public String valueAsString() { - return classificationLabel == null ? super.valueAsString() : classificationLabel; + public String getName() { + return RESULT_TYPE; } public static class TopClassEntry implements ToXContentObject, Writeable { @@ -116,6 +135,13 @@ public double getProbability() { return probability; } + public Map asValueMap() { + Map map = new HashMap<>(2); + map.put(CLASSIFICATION.getPreferredName(), classification); + map.put(PROBABILITY.getPreferredName(), probability); + return map; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(classification); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java index 2c349945e1ced..00744f6982f46 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java @@ -5,15 +5,12 @@ */ package org.elasticsearch.xpack.core.ml.inference.results; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; -public interface InferenceResults extends ToXContentObject, Writeable { +public interface InferenceResults extends NamedXContentObject, NamedWriteable { - String resultType(); - - T value(); - - String valueAsString(); + void writeResult(IngestDocument document, String resultField); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java index e4bb0b0136a25..1b01a2dff1f08 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java @@ -8,6 +8,8 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.ingest.IngestDocument; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.Objects; @@ -34,11 +36,6 @@ XContentBuilder innerToXContent(XContentBuilder builder, Params params) { return builder; } - @Override - public String resultType() { - return RESULT_TYPE; - } - @Override public boolean equals(Object object) { if (object == this) { return true; } @@ -51,4 +48,21 @@ public boolean equals(Object object) { public int hashCode() { return Objects.hash(value()); } + + @Override + public void writeResult(IngestDocument document, String resultField) { + ExceptionsHelper.requireNonNull(document, "document"); + ExceptionsHelper.requireNonNull(resultField, "resultField"); + document.setFieldValue(resultField, value()); + } + + @Override + public String getWriteableName() { + return RESULT_TYPE; + } + + @Override + public String getName() { + return RESULT_TYPE; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java index f3d006a4caa6a..2905a6679584c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java @@ -12,7 +12,7 @@ import java.io.IOException; -public abstract class SingleValueInferenceResults implements InferenceResults { +public abstract class SingleValueInferenceResults implements InferenceResults { public final ParseField VALUE = new ParseField("value"); @@ -26,12 +26,10 @@ public abstract class SingleValueInferenceResults implements InferenceResults topClasses(List probabilities, + List classificationLabels, + int numToInclude) { + if (numToInclude == 0) { + return Collections.emptyList(); + } + int[] sortedIndices = IntStream.range(0, probabilities.size()) + .boxed() + .sorted(Comparator.comparing(probabilities::get).reversed()) + .mapToInt(i -> i) + .toArray(); + + List labels = classificationLabels == null ? + // If we don't have the labels we should return the top classification values anyways, they will just be numeric + IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) : + classificationLabels; + + if (probabilities.size() != labels.size()) { + throw ExceptionsHelper + .badRequestException( + "model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]", + probabilities.size(), + classificationLabels); + } + + + int count = numToInclude < 0 ? probabilities.size() : numToInclude; + List topClassEntries = new ArrayList<>(count); + for(int i = 0; i < count; i++) { + int idx = sortedIndices[i]; + topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx))); + } + + return topClassEntries; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java new file mode 100644 index 0000000000000..9eeff10d76bd6 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java @@ -0,0 +1,58 @@ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class InferenceParams implements ToXContentObject, Writeable { + + public static ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + + private final int numTopClasses; + + public InferenceParams(Integer numTopClasses) { + this.numTopClasses = numTopClasses == null ? 0 : numTopClasses; + } + + public InferenceParams(StreamInput in) throws IOException { + this.numTopClasses = in.readInt(); + } + + public int getNumTopClasses() { + return numTopClasses; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeInt(numTopClasses); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InferenceParams that = (InferenceParams) o; + return Objects.equals(numTopClasses, that.numTopClasses); + } + + @Override + public int hashCode() { + return Objects.hash(numTopClasses); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (numTopClasses != 0) { + builder.field(NUM_TOP_CLASSES.getPreferredName(), numTopClasses); + } + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java index cad5a6c0a8c74..a6c6f1eff011d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java @@ -7,6 +7,7 @@ import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; import java.util.List; @@ -26,13 +27,7 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable { * @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0). * For regression this is continuous. */ - double infer(Map fields); - - /** - * @param fields similar to {@link TrainedModel#infer(Map)}, but fields are already in order and doubles - * @return The predicted value. - */ - double infer(List fields); + InferenceResults infer(Map fields, InferenceParams params); /** * @return {@link TargetType} for the model. @@ -40,26 +35,7 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable { TargetType targetType(); /** - * This gathers the probabilities for each potential classification value. - * - * The probabilities are indexed by classification ordinal label encoding. - * The length of this list is equal to the number of classification labels. - * - * This only should return if the implementation model is inferring classification values and not regression - * @param fields The fields and their values to infer against - * @return The probabilities of each classification value - */ - List classificationProbability(Map fields); - - /** - * @param fields similar to {@link TrainedModel#classificationProbability(Map)} but the fields are already in order and doubles - * @return The probabilities of each classification value - */ - List classificationProbability(List fields); - - /** - * The ordinal encoded list of the classification labels. - * @return Oridinal encoded list of classification labels. + * @return Ordinal encoded list of classification labels. */ @Nullable List classificationLabels(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index add3a8d7c151c..30dea57a2815b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -12,6 +12,8 @@ import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; @@ -106,17 +108,11 @@ public List getFeatureNames() { } @Override - public double infer(Map fields) { + public InferenceResults infer(Map fields, InferenceParams params) { List features = featureNames.stream().map(f -> ((Number)fields.get(f)).doubleValue()).collect(Collectors.toList()); return infer(features); } - @Override - public double infer(List fields) { - List processedInferences = inferAndProcess(fields); - return outputAggregator.aggregate(processedInferences); - } - @Override public TargetType targetType() { return targetType; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index 5dca29d58437e..55651fda7b46c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -13,6 +13,11 @@ import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; @@ -105,18 +110,34 @@ public List getNodes() { } @Override - public double infer(Map fields) { + public InferenceResults infer(Map fields, InferenceParams params) { + if (targetType != TargetType.CLASSIFICATION && params.getNumTopClasses() != 0) { + throw new UnsupportedOperationException( + "Cannot return top classes for target_type [" + targetType.toString() + "]"); + } List features = featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList()); - return infer(features); + return infer(features, params); } - @Override - public double infer(List features) { + private InferenceResults infer(List features, InferenceParams params) { TreeNode node = nodes.get(0); while(node.isLeaf() == false) { node = nodes.get(node.compare(features)); } - return node.getLeafValue(); + return buildResult(node.getLeafValue(), params); + } + + private InferenceResults buildResult(Double value, InferenceParams params) { + switch (targetType) { + case CLASSIFICATION: + List topClasses = + InferenceHelpers.topClasses(classificationProbability(value), classificationLabels, params.getNumTopClasses()); + return new ClassificationInferenceResults(value, classificationLabel(value), topClasses); + case REGRESSION: + return new RegressionInferenceResults(value); + default: + throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model"); + } } /** @@ -140,33 +161,33 @@ public TargetType targetType() { return targetType; } - @Override - public List classificationProbability(Map fields) { - if ((targetType == TargetType.CLASSIFICATION) == false) { - throw new UnsupportedOperationException( - "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); - } - return classificationProbability(featureNames.stream().map(f -> (Double) fields.get(f)).collect(Collectors.toList())); - } - - @Override - public List classificationProbability(List fields) { - if ((targetType == TargetType.CLASSIFICATION) == false) { - throw new UnsupportedOperationException( - "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); - } - double label = infer(fields); + private List classificationProbability(double inferenceValue) { // If we are classification, we should assume that the inference return value is whole. - assert label == Math.rint(label); + assert inferenceValue == Math.rint(inferenceValue); double maxCategory = this.highestOrderCategory.get(); // If we are classification, we should assume that the largest leaf value is whole. assert maxCategory == Math.rint(maxCategory); List list = new ArrayList<>(Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0)); // TODO, eventually have TreeNodes contain confidence levels - list.set(Double.valueOf(label).intValue(), 1.0); + list.set(Double.valueOf(inferenceValue).intValue(), 1.0); return list; } + private String classificationLabel(double inferenceValue) { + assert inferenceValue == Math.rint(inferenceValue); + if (classificationLabels == null) { + return String.valueOf(inferenceValue); + } + int label = Double.valueOf(inferenceValue).intValue(); + if (label < 0 || label >= classificationLabels.size()) { + throw ExceptionsHelper.badRequestException( + "model returned classification value of [{}] which is not a valid index in classification labels [{}]", + label, + classificationLabels); + } + return classificationLabels.get(label); + } + @Override public List classificationLabels() { return classificationLabels; From f753a224bf1d5fb04b3a7b414f37b15b9e11f7ae Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 7 Oct 2019 15:47:48 -0400 Subject: [PATCH 09/13] adjusting underlying modeling and addressing PR comments --- .../xpack/core/XPackClientPlugin.java | 18 ++++ .../core/ml/action/InferModelAction.java | 31 +++---- .../ml/inference/TrainedModelDefinition.java | 12 +++ .../ClassificationInferenceResults.java | 3 +- .../trainedmodel/InferenceHelpers.java | 23 ++++- .../trainedmodel/InferenceParams.java | 7 ++ .../trainedmodel/ensemble/Ensemble.java | 48 +++++++---- .../ml/inference/trainedmodel/tree/Tree.java | 23 ++--- .../action/InferModelActionRequestTests.java | 6 +- .../action/InferModelActionResponseTests.java | 2 +- .../ClassificationInferenceResultsTests.java | 45 ++++++++++ .../RegressionInferenceResultsTests.java | 13 +++ .../trainedmodel/InferenceParamsTests.java | 27 ++++++ .../trainedmodel/ensemble/EnsembleTests.java | 56 ++++++++----- .../trainedmodel/tree/TreeTests.java | 47 +++++++---- .../ml/action/TransportInferModelAction.java | 16 ++-- .../inference/loadingservice/LocalModel.java | 84 ++----------------- .../ml/inference/loadingservice/Model.java | 4 +- .../loadingservice/ModelLoadingService.java | 73 +++++++++++----- .../loadingservice/LocalModelTests.java | 64 ++++++-------- .../integration/ModelInferenceActionIT.java | 18 ++-- 21 files changed, 376 insertions(+), 244 deletions(-) create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index e83f7d63016c8..d6561379c0cfc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -140,7 +140,14 @@ import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult; import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; @@ -453,6 +460,17 @@ public List getNamedWriteables() { new NamedWriteableRegistry.Entry(PreProcessor.class, TargetMeanEncoding.NAME.getPreferredName(), TargetMeanEncoding::new), // ML - Inference models new NamedWriteableRegistry.Entry(TrainedModel.class, Tree.NAME.getPreferredName(), Tree::new), + new NamedWriteableRegistry.Entry(TrainedModel.class, Ensemble.NAME.getPreferredName(), Ensemble::new), + // ML - Inference aggregators + new NamedWriteableRegistry.Entry(OutputAggregator.class, WeightedSum.NAME.getPreferredName(), WeightedSum::new), + new NamedWriteableRegistry.Entry(OutputAggregator.class, WeightedMode.NAME.getPreferredName(), WeightedMode::new), + // ML - Inference Results + new NamedWriteableRegistry.Entry(InferenceResults.class, + ClassificationInferenceResults.RESULT_TYPE, + ClassificationInferenceResults::new), + new NamedWriteableRegistry.Entry(InferenceResults.class, + RegressionInferenceResults.RESULT_TYPE, + RegressionInferenceResults::new), // monitoring new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java index 940e892c48c42..a576c1935c4ad 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; @@ -38,24 +39,24 @@ public static class Request extends ActionRequest { private final String modelId; private final long modelVersion; private final List> objectsToInfer; - private final int topClasses; + private final InferenceParams params; public Request(String modelId, long modelVersion) { - this(modelId, modelVersion, Collections.emptyList(), null); + this(modelId, modelVersion, Collections.emptyList(), InferenceParams.EMPTY_PARAMS); } - public Request(String modelId, long modelVersion, List> objectsToInfer, Integer topClasses) { + public Request(String modelId, long modelVersion, List> objectsToInfer, InferenceParams inferenceParams) { this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); this.modelVersion = modelVersion; this.objectsToInfer = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(objectsToInfer, "objects_to_infer")); - this.topClasses = topClasses == null ? 0 : topClasses; + this.params = inferenceParams == null ? InferenceParams.EMPTY_PARAMS : inferenceParams; } - public Request(String modelId, long modelVersion, Map objectToInfer, Integer topClasses) { + public Request(String modelId, long modelVersion, Map objectToInfer, InferenceParams params) { this(modelId, modelVersion, Arrays.asList(ExceptionsHelper.requireNonNull(objectToInfer, "objects_to_infer")), - topClasses); + params); } public Request(StreamInput in) throws IOException { @@ -63,7 +64,7 @@ public Request(StreamInput in) throws IOException { this.modelId = in.readString(); this.modelVersion = in.readVLong(); this.objectsToInfer = Collections.unmodifiableList(in.readList(StreamInput::readMap)); - this.topClasses = in.readInt(); + this.params = new InferenceParams(in); } public String getModelId() { @@ -78,8 +79,8 @@ public List> getObjectsToInfer() { return objectsToInfer; } - public int getTopClasses() { - return topClasses; + public InferenceParams getParams() { + return params; } @Override @@ -93,7 +94,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); out.writeVLong(modelVersion); out.writeCollection(objectsToInfer, StreamOutput::writeMap); - out.writeInt(topClasses); + params.writeTo(out); } @Override @@ -103,23 +104,23 @@ public boolean equals(Object o) { InferModelAction.Request that = (InferModelAction.Request) o; return Objects.equals(modelId, that.modelId) && Objects.equals(modelVersion, that.modelVersion) - && Objects.equals(topClasses, that.topClasses) + && Objects.equals(params, that.params) && Objects.equals(objectsToInfer, that.objectsToInfer); } @Override public int hashCode() { - return Objects.hash(modelId, modelVersion, objectsToInfer, topClasses); + return Objects.hash(modelId, modelVersion, objectsToInfer, params); } } public static class Response extends ActionResponse { - private final List> inferenceResults; + private final List inferenceResults; private final String resultsType; - public Response(List> inferenceResponse, String resultsType) { + public Response(List inferenceResponse, String resultsType) { super(); this.resultsType = ExceptionsHelper.requireNonNull(resultsType, "resultsType"); this.inferenceResults = inferenceResponse == null ? @@ -139,7 +140,7 @@ public Response(StreamInput in) throws IOException { } } - public List> getInferenceResults() { + public List getInferenceResults() { return inferenceResults; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java index f85c184646e1f..e936d60bf87b7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java @@ -18,6 +18,8 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; +import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; @@ -27,6 +29,7 @@ import java.io.IOException; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Objects; public class TrainedModelDefinition implements ToXContentObject, Writeable { @@ -118,6 +121,15 @@ public Input getInput() { return input; } + private void preProcess(Map fields) { + preProcessors.forEach(preProcessor -> preProcessor.process(fields)); + } + + public InferenceResults infer(Map fields, InferenceParams params) { + preProcess(fields); + return trainedModel.infer(fields, params); + } + @Override public String toString() { return Strings.toString(this); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java index a9f8b3fc66627..1aeb8d34cdd00 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java @@ -20,6 +20,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; public class ClassificationInferenceResults extends SingleValueInferenceResults { @@ -95,7 +96,7 @@ public void writeResult(IngestDocument document, String resultField) { if (topClasses.isEmpty()) { document.setFieldValue(resultField, valueAsString()); } else { - document.setFieldValue(resultField, topClasses.stream().map(TopClassEntry::asValueMap)); + document.setFieldValue(resultField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList())); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java index 133d681978f44..18884ef252f87 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java @@ -1,5 +1,11 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; +import org.elasticsearch.common.Nullable; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -7,7 +13,6 @@ import java.util.Collections; import java.util.Comparator; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -50,4 +55,20 @@ public static List topClasses(List return topClassEntries; } + + + public static String classificationLabel(double inferenceValue, @Nullable List classificationLabels) { + assert inferenceValue == Math.rint(inferenceValue); + if (classificationLabels == null) { + return String.valueOf(inferenceValue); + } + int label = Double.valueOf(inferenceValue).intValue(); + if (label < 0 || label >= classificationLabels.size()) { + throw ExceptionsHelper.badRequestException( + "model returned classification value of [{}] which is not a valid index in classification labels [{}]", + label, + classificationLabels); + } + return classificationLabels.get(label); + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java index 9eeff10d76bd6..150bf3d483f26 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParams.java @@ -1,3 +1,8 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.common.ParseField; @@ -14,6 +19,8 @@ public class InferenceParams implements ToXContentObject, Writeable { public static ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); + public static InferenceParams EMPTY_PARAMS = new InferenceParams(0); + private final int numTopClasses; public InferenceParams(Integer numTopClasses) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index b513eedfcfd9f..22326e1d17e13 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -12,7 +12,11 @@ import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel; @@ -28,6 +32,8 @@ import java.util.Objects; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; + public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel { // TODO should we have regression/classification sub-classes that accept the builder? @@ -109,24 +115,41 @@ public List getFeatureNames() { @Override public InferenceResults infer(Map fields, InferenceParams params) { - List features = featureNames.stream().map(f -> ((Number)fields.get(f)).doubleValue()).collect(Collectors.toList()); - return infer(features); + if ((targetType != TargetType.CLASSIFICATION || outputAggregator instanceof WeightedMode == false) && + params.getNumTopClasses() != 0) { + throw ExceptionsHelper.badRequestException( + "Cannot return top classes for target_type [{}] and aggregate_output [{}]", + targetType, + outputAggregator.getName()); + } + List inferenceResults = this.models.stream().map(model -> { + InferenceResults results = model.infer(fields, InferenceParams.EMPTY_PARAMS); + assert results instanceof SingleValueInferenceResults; + return ((SingleValueInferenceResults)results).value(); + }).collect(Collectors.toList()); + List processed = outputAggregator.processValues(inferenceResults); + return buildResults(processed, params); } - @Override public TargetType targetType() { return targetType; } - private List classificationProbability(Map fields) { - if ((targetType == TargetType.CLASSIFICATION) == false) { - throw new UnsupportedOperationException( - "Cannot determine classification probability with target_type [" + targetType.toString() + "]"); + private InferenceResults buildResults(List processedInferences, InferenceParams params) { + switch(targetType) { + case REGRESSION: + return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences)); + case CLASSIFICATION: + List topClasses = + InferenceHelpers.topClasses(processedInferences, classificationLabels, params.getNumTopClasses()); + double value = outputAggregator.aggregate(processedInferences); + return new ClassificationInferenceResults(outputAggregator.aggregate(processedInferences), + classificationLabel(value, classificationLabels), + topClasses); + default: + throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on ensemble model"); } - List features = featureNames.stream().map(f -> ((Number)fields.get(f)).doubleValue()).collect(Collectors.toList()); - return classificationProbability(features); - return inferAndProcess(fields); } @Override @@ -134,11 +157,6 @@ public List classificationLabels() { return classificationLabels; } - private List inferAndProcess(Map fields) { - List modelInferences = models.stream().map(m -> m.infer(fields)).collect(Collectors.toList()); - return outputAggregator.processValues(modelInferences); - } - @Override public String getWriteableName() { return NAME.getPreferredName(); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index 1b90920dc7aa2..bce6b08b6ed4b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -36,6 +36,8 @@ import java.util.Set; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; + public class Tree implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel { // TODO should we have regression/classification sub-classes that accept the builder? @@ -112,8 +114,8 @@ public List getNodes() { @Override public InferenceResults infer(Map fields, InferenceParams params) { if (targetType != TargetType.CLASSIFICATION && params.getNumTopClasses() != 0) { - throw new UnsupportedOperationException( - "Cannot return top classes for target_type [" + targetType.toString() + "]"); + throw ExceptionsHelper.badRequestException( + "Cannot return top classes for target_type [{}]", targetType.toString()); } List features = featureNames.stream().map(f -> fields.get(f) instanceof Number ? ((Number) fields.get(f)).doubleValue() : null @@ -134,7 +136,7 @@ private InferenceResults buildResult(Double value, InferenceParams params) { case CLASSIFICATION: List topClasses = InferenceHelpers.topClasses(classificationProbability(value), classificationLabels, params.getNumTopClasses()); - return new ClassificationInferenceResults(value, classificationLabel(value), topClasses); + return new ClassificationInferenceResults(value, classificationLabel(value, classificationLabels), topClasses); case REGRESSION: return new RegressionInferenceResults(value); default: @@ -175,21 +177,6 @@ private List classificationProbability(double inferenceValue) { return list; } - private String classificationLabel(double inferenceValue) { - assert inferenceValue == Math.rint(inferenceValue); - if (classificationLabels == null) { - return String.valueOf(inferenceValue); - } - int label = Double.valueOf(inferenceValue).intValue(); - if (label < 0 || label >= classificationLabels.size()) { - throw ExceptionsHelper.badRequestException( - "model returned classification value of [{}] which is not a valid index in classification labels [{}]", - label, - classificationLabels); - } - return classificationLabels.get(label); - } - @Override public List classificationLabels() { return classificationLabels; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java index d816939d57a59..a49643d081957 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionRequestTests.java @@ -14,6 +14,8 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParamsTests.randomInferenceParams; + public class InferModelActionRequestTests extends AbstractWireSerializingTestCase { @Override @@ -23,12 +25,12 @@ protected Request createTestInstance() { randomAlphaOfLength(10), randomLongBetween(1, 100), Stream.generate(InferModelActionRequestTests::randomMap).limit(randomInt(10)).collect(Collectors.toList()), - randomBoolean() ? null : randomIntBetween(-1, 100)) : + randomBoolean() ? null : randomInferenceParams()) : new Request( randomAlphaOfLength(10), randomLongBetween(1, 100), randomMap(), - randomBoolean() ? null : randomIntBetween(-1, 100)); + randomBoolean() ? null : randomInferenceParams()); } private static Map randomMap() { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java index 4261b6b05aff5..4b80c927bed30 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java @@ -29,7 +29,7 @@ protected Response createTestInstance() { resultType); } - private static InferenceResults randomInferenceResult(String resultType) { + private static InferenceResults randomInferenceResult(String resultType) { if (resultType.equals(ClassificationInferenceResults.RESULT_TYPE)) { return ClassificationInferenceResultsTests.createRandomResults(); } else if (resultType.equals(RegressionInferenceResults.RESULT_TYPE)) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java index c446c4d8225bc..ba90fece02f2a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java @@ -6,11 +6,19 @@ package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.hamcrest.Matchers.equalTo; + public class ClassificationInferenceResultsTests extends AbstractWireSerializingTestCase { public static ClassificationInferenceResults createRandomResults() { @@ -26,6 +34,43 @@ private static ClassificationInferenceResults.TopClassEntry createRandomClassEnt return new ClassificationInferenceResults.TopClassEntry(randomAlphaOfLength(10), randomDouble()); } + public void testWriteResultsWithClassificationLabel() { + ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, "foo", Collections.emptyList()); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + assertThat(document.getFieldValue("result_field", String.class), equalTo("foo")); + } + + public void testWriteResultsWithoutClassificationLabel() { + ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, null, Collections.emptyList()); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + assertThat(document.getFieldValue("result_field", String.class), equalTo("1.0")); + } + + @SuppressWarnings("unchecked") + public void testWriteResultsWithTopClasses() { + List entries = Arrays.asList( + new ClassificationInferenceResults.TopClassEntry("foo", 0.7), + new ClassificationInferenceResults.TopClassEntry("bar", 0.2), + new ClassificationInferenceResults.TopClassEntry("baz", 0.1)); + ClassificationInferenceResults result = new ClassificationInferenceResults(1.0, + "foo", + entries); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + List list = document.getFieldValue("result_field", List.class); + assertThat(list.size(), equalTo(3)); + + for(int i = 0; i < 3; i++) { + Map map = (Map)list.get(i); + assertThat(map, equalTo(entries.get(i).asValueMap())); + } + } + @Override protected ClassificationInferenceResults createTestInstance() { return createRandomResults(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java index 479aca9d08c1b..4f2d5926c84dc 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java @@ -6,9 +6,14 @@ package org.elasticsearch.xpack.core.ml.inference.results; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import java.util.HashMap; + +import static org.hamcrest.Matchers.equalTo; + public class RegressionInferenceResultsTests extends AbstractWireSerializingTestCase { @@ -16,6 +21,14 @@ public static RegressionInferenceResults createRandomResults() { return new RegressionInferenceResults(randomDouble()); } + public void testWriteResults() { + RegressionInferenceResults result = new RegressionInferenceResults(0.3); + IngestDocument document = new IngestDocument(new HashMap<>(), new HashMap<>()); + result.writeResult(document, "result_field"); + + assertThat(document.getFieldValue("result_field", Double.class), equalTo(0.3)); + } + @Override protected RegressionInferenceResults createTestInstance() { return createRandomResults(); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java new file mode 100644 index 0000000000000..2586cdd75de47 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceParamsTests.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; + +public class InferenceParamsTests extends AbstractWireSerializingTestCase { + + public static InferenceParams randomInferenceParams() { + return randomBoolean() ? InferenceParams.EMPTY_PARAMS : new InferenceParams(randomIntBetween(-1, 100)); + } + + @Override + protected InferenceParams createTestInstance() { + return randomInferenceParams(); + } + + @Override + protected Writeable.Reader instanceReader() { + return InferenceParams::new; + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index eb537e247e994..8da0c15718f24 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -15,6 +15,9 @@ import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; @@ -239,27 +242,30 @@ public void testClassificationProbability() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - List expected = Arrays.asList(0.231475216, 0.768524783); + List expected = Arrays.asList(0.768524783, 0.231475216); double eps = 0.000001; - List probabilities = ensemble.classificationProbability(featureMap); + List probabilities = + ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - expected = Arrays.asList(0.3100255188, 0.689974481); - probabilities = ensemble.classificationProbability(featureMap); + expected = Arrays.asList(0.689974481, 0.3100255188); + probabilities = + ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } featureVector = Arrays.asList(0.0, 1.0); featureMap = zipObjMap(featureNames, featureVector); - expected = Arrays.asList(0.231475216, 0.768524783); - probabilities = ensemble.classificationProbability(featureMap); + expected = Arrays.asList(0.768524783, 0.231475216); + probabilities = + ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } // This should handle missing values and take the default_left path @@ -268,9 +274,10 @@ public void testClassificationProbability() { put("bar", null); }}; expected = Arrays.asList(0.6899744811, 0.3100255188); - probabilities = ensemble.classificationProbability(featureMap); + probabilities = + ((ClassificationInferenceResults)ensemble.infer(featureMap, new InferenceParams(2))).getTopClasses(); for(int i = 0; i < expected.size(); i++) { - assertThat(probabilities.get(i), closeTo(expected.get(i), eps)); + assertThat(probabilities.get(i).getProbability(), closeTo(expected.get(i), eps)); } } @@ -320,21 +327,25 @@ public void testClassificationInference() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureVector = Arrays.asList(0.0, 1.0); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; - assertEquals(0.0, ensemble.infer(featureMap), 0.00001); + assertThat(0.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); } public void testRegressionInference() { @@ -373,11 +384,13 @@ public void testRegressionInference() { List featureVector = Arrays.asList(0.4, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - assertEquals(0.9, ensemble.infer(featureMap), 0.00001); + assertThat(0.9, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(0.5, ensemble.infer(featureMap), 0.00001); + assertThat(0.5, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); // Test with NO aggregator supplied, verifies default behavior of non-weighted sum ensemble = Ensemble.builder() @@ -388,17 +401,20 @@ public void testRegressionInference() { featureVector = Arrays.asList(0.4, 0.0); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.8, ensemble.infer(featureMap), 0.00001); + assertThat(1.8, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureVector = Arrays.asList(2.0, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(1.0, ensemble.infer(featureMap), 0.00001); + assertThat(1.0, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; - assertEquals(1.8, ensemble.infer(featureMap), 0.00001); + assertThat(1.8, + closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); } private static Map zipObjMap(List keys, List values) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index 81030585f1889..c362c17fd579d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -10,6 +10,9 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.test.AbstractSerializingTestCase; +import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.junit.Before; @@ -120,26 +123,30 @@ public void testInfer() { // This feature vector should hit the right child of the root node List featureVector = Arrays.asList(0.6, 0.0); Map featureMap = zipObjMap(featureNames, featureVector); - assertThat(0.3, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.3, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); // This should hit the left child of the left child of the root node // i.e. it takes the path left, left featureVector = Arrays.asList(0.3, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.1, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); // This should hit the right child of the left child of the root node // i.e. it takes the path left, right featureVector = Arrays.asList(0.3, 0.9); featureMap = zipObjMap(featureNames, featureVector); - assertThat(0.2, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.2, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); // This should handle missing values and take the default_left path featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; - assertThat(0.1, closeTo(tree.infer(featureMap), 0.00001)); + assertThat(0.1, + closeTo(((SingleValueInferenceResults)tree.infer(featureMap, InferenceParams.EMPTY_PARAMS)).value(), 0.00001)); } public void testTreeClassificationProbability() { @@ -153,31 +160,43 @@ public void testTreeClassificationProbability() { builder.addLeaf(leftChildNode.getRightChild(), 0.0); List featureNames = Arrays.asList("foo", "bar"); - Tree tree = builder.setFeatureNames(featureNames).build(); + Tree tree = builder.setFeatureNames(featureNames).setClassificationLabels(Arrays.asList("cat", "dog")).build(); + double eps = 0.000001; // This feature vector should hit the right child of the root node List featureVector = Arrays.asList(0.6, 0.0); + List expectedProbs = Arrays.asList(1.0, 0.0); + List expectedFields = Arrays.asList("dog", "cat"); Map featureMap = zipObjMap(featureNames, featureVector); - assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap)); + List probabilities = + ((ClassificationInferenceResults)tree.infer(featureMap, new InferenceParams(2))).getTopClasses(); + for(int i = 0; i < expectedProbs.size(); i++) { + assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); + assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); + } // This should hit the left child of the left child of the root node // i.e. it takes the path left, left featureVector = Arrays.asList(0.3, 0.7); featureMap = zipObjMap(featureNames, featureVector); - assertEquals(Arrays.asList(0.0, 1.0), tree.classificationProbability(featureMap)); - - // This should hit the right child of the left child of the root node - // i.e. it takes the path left, right - featureVector = Arrays.asList(0.3, 0.9); - featureMap = zipObjMap(featureNames, featureVector); - assertEquals(Arrays.asList(1.0, 0.0), tree.classificationProbability(featureMap)); + probabilities = + ((ClassificationInferenceResults)tree.infer(featureMap, new InferenceParams(2))).getTopClasses(); + for(int i = 0; i < expectedProbs.size(); i++) { + assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); + assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); + } // This should handle missing values and take the default_left path featureMap = new HashMap<>(2) {{ put("foo", 0.3); put("bar", null); }}; - assertEquals(1.0, tree.infer(featureMap), 0.00001); + probabilities = + ((ClassificationInferenceResults)tree.infer(featureMap, new InferenceParams(2))).getTopClasses(); + for(int i = 0; i < expectedProbs.size(); i++) { + assertThat(probabilities.get(i).getProbability(), closeTo(expectedProbs.get(i), eps)); + assertThat(probabilities.get(i).getClassification(), equalTo(expectedFields.get(i))); + } } public void testTreeWithNullRoot() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java index ed487ed761b34..60df68eae7480 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -40,22 +40,16 @@ protected void doExecute(Task task, InferModelAction.Request request, ActionList ActionListener getModelListener = ActionListener.wrap( model -> { - TypedChainTaskExecutor> typedChainTaskExecutor = + TypedChainTaskExecutor typedChainTaskExecutor = new TypedChainTaskExecutor<>(client.threadPool().executor(ThreadPool.Names.SAME), // run through all tasks r -> true, // Always fail immediately and return an error ex -> true); - if (request.getTopClasses() != 0) { - request.getObjectsToInfer().forEach(stringObjectMap -> - typedChainTaskExecutor.add(chainedTask -> - model.classificationProbability(stringObjectMap, request.getTopClasses(), chainedTask)) - ); - } else { - request.getObjectsToInfer().forEach(stringObjectMap -> - typedChainTaskExecutor.add(chainedTask -> model.infer(stringObjectMap, chainedTask)) - ); - } + request.getObjectsToInfer().forEach(stringObjectMap -> + typedChainTaskExecutor.add(chainedTask -> + model.infer(stringObjectMap, request.getParams(), chainedTask))); + typedChainTaskExecutor.execute(ActionListener.wrap( inferenceResultsInterfaces -> listener.onResponse(new InferModelAction.Response(inferenceResultsInterfaces, model.getResultsType())), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index a19283d49c9fc..0019f452321c8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -5,27 +5,21 @@ */ package org.elasticsearch.xpack.ml.inference.loadingservice; -import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; -import java.util.ArrayList; -import java.util.Comparator; -import java.util.List; import java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.IntStream; public class LocalModel implements Model { private final TrainedModelDefinition trainedModelDefinition; private final String modelId; + public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition) { this.trainedModelDefinition = trainedModelDefinition; this.modelId = modelId; @@ -46,76 +40,12 @@ public String getResultsType() { } @Override - public void infer(Map fields, ActionListener> listener) { - trainedModelDefinition.getPreProcessors().forEach(preProcessor -> preProcessor.process(fields)); - double value = trainedModelDefinition.getTrainedModel().infer(fields); - InferenceResults inferenceResults; - if (trainedModelDefinition.getTrainedModel().targetType() == TargetType.CLASSIFICATION) { - String classificationLabel = null; - if (trainedModelDefinition.getTrainedModel().classificationLabels() != null) { - assert value == Math.rint(value); - int classIndex = Double.valueOf(value).intValue(); - if (classIndex < 0 || classIndex >= trainedModelDefinition.getTrainedModel().classificationLabels().size()) { - listener.onFailure(new ElasticsearchStatusException( - "model returned classification [{}] which is invalid given labels {}", - RestStatus.INTERNAL_SERVER_ERROR, - classIndex, - trainedModelDefinition.getTrainedModel().classificationLabels())); - return; - } - classificationLabel = trainedModelDefinition.getTrainedModel().classificationLabels().get(classIndex); - } - inferenceResults = new ClassificationInferenceResults(value, classificationLabel, null); - } else { - inferenceResults = new RegressionInferenceResults(value); + public void infer(Map fields, InferenceParams params, ActionListener listener) { + try { + listener.onResponse(trainedModelDefinition.infer(fields, params)); + } catch (Exception e) { + listener.onFailure(e); } - listener.onResponse(inferenceResults); } - @Override - public void classificationProbability(Map fields, int topN, ActionListener> listener) { - if (topN == 0) { - infer(fields, listener); - return; - } - if (trainedModelDefinition.getTrainedModel().targetType() != TargetType.CLASSIFICATION) { - listener.onFailure(ExceptionsHelper - .badRequestException("top result probabilities is only available for classification models")); - return; - } - trainedModelDefinition.getPreProcessors().forEach(preProcessor -> preProcessor.process(fields)); - List probabilities = trainedModelDefinition.getTrainedModel().classificationProbability(fields); - int[] sortedIndices = IntStream.range(0, probabilities.size()) - .boxed() - .sorted(Comparator.comparing(probabilities::get).reversed()) - .mapToInt(i -> i) - .toArray(); - if (trainedModelDefinition.getTrainedModel().classificationLabels() != null) { - if (probabilities.size() != trainedModelDefinition.getTrainedModel().classificationLabels().size()) { - listener.onFailure(ExceptionsHelper - .badRequestException( - "model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]", - probabilities.size(), - trainedModelDefinition.getTrainedModel().classificationLabels())); - return; - } - } - List labels = trainedModelDefinition.getTrainedModel().classificationLabels() == null ? - // If we don't have the labels we should return the top classification values anyways, they will just be numeric - IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) : - trainedModelDefinition.getTrainedModel().classificationLabels(); - - int count = topN < 0 ? probabilities.size() : topN; - List topClassEntries = new ArrayList<>(count); - for(int i = 0; i < count; i++) { - int idx = sortedIndices[i]; - topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx))); - } - - listener.onResponse(new ClassificationInferenceResults(((Number)sortedIndices[0]).doubleValue(), - trainedModelDefinition.getTrainedModel().classificationLabels() == null ? - null : - trainedModelDefinition.getTrainedModel().classificationLabels().get(sortedIndices[0]), - topClassEntries)); - } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java index 398382d7d19e7..27924a47aa153 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/Model.java @@ -7,6 +7,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import java.util.Map; @@ -14,7 +15,6 @@ public interface Model { String getResultsType(); - void infer(Map fields, ActionListener> listener); + void infer(Map fields, InferenceParams inferenceParams, ActionListener listener); - void classificationProbability(Map fields, int topN, ActionListener> listener); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 989d87a7e79f2..9b5ddb3c1e857 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -23,21 +23,18 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import java.util.ArrayDeque; +import java.util.HashMap; import java.util.HashSet; -import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Queue; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; public class ModelLoadingService implements ClusterStateListener { private static final Logger logger = LogManager.getLogger(ModelLoadingService.class); - // TODO should these be ConcurrentHashMaps if all interactions are synchronized? - private final ConcurrentHashMap> loadedModels = new ConcurrentHashMap<>(); - private final ConcurrentHashMap>> loadingListeners = new ConcurrentHashMap<>(); + private final Map loadedModels = new HashMap<>(); + private final Map>> loadingListeners = new HashMap<>(); private final TrainedModelProvider provider; private final ThreadPool threadPool; @@ -49,10 +46,10 @@ public ModelLoadingService(TrainedModelProvider trainedModelProvider, ThreadPool public void getModel(String modelId, long modelVersion, ActionListener modelActionListener) { String key = modelKey(modelId, modelVersion); - Optional cachedModel = loadedModels.get(key); + MaybeModel cachedModel = loadedModels.get(key); if (cachedModel != null) { - if (cachedModel.isPresent()) { - modelActionListener.onResponse(cachedModel.get()); + if (cachedModel.isSuccess()) { + modelActionListener.onResponse(cachedModel.getModel()); return; } } @@ -77,10 +74,10 @@ public void getModel(String modelId, long modelVersion, ActionListener mo */ private boolean loadModelIfNecessary(String key, String modelId, long modelVersion, ActionListener modelActionListener) { synchronized (loadingListeners) { - Optional cachedModel = loadedModels.get(key); + MaybeModel cachedModel = loadedModels.get(key); if (cachedModel != null) { - if (cachedModel.isPresent()) { - modelActionListener.onResponse(cachedModel.get()); + if (cachedModel.isSuccess()) { + modelActionListener.onResponse(cachedModel.getModel()); return true; } // If the loaded model entry is there but is not present, that means the previous load attempt ran into an issue @@ -121,11 +118,11 @@ private void handleLoadSuccess(String modelKey, TrainedModelConfig trainedModelC // If there is no loadingListener that means the loading was canceled and the listener was already notified as such // Consequently, we should not store the retrieved model if (listeners != null) { - loadedModels.put(modelKey, Optional.of(loadedModel)); + loadedModels.put(modelKey, MaybeModel.of(loadedModel)); } } if (listeners != null) { - for(ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { listener.onResponse(loadedModel); } } @@ -138,11 +135,11 @@ private void handleLoadFailure(String modelKey, Exception failure) { if (listeners != null) { // If we failed to load and there were listeners present, that means that this model is referenced by a processor // Add an empty entry here so that we can attempt to load and cache the model again when it is accessed again. - loadedModels.computeIfAbsent(modelKey, (key) -> Optional.empty()); + loadedModels.computeIfAbsent(modelKey, (key) -> MaybeModel.of(failure)); } } if (listeners != null) { - for(ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { + for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { listener.onFailure(failure); } } @@ -159,9 +156,7 @@ public void clusterChanged(ClusterChangedEvent event) { synchronized (loadingListeners) { // If we had models still loading here but are no longer referenced // we should remove them from loadingListeners and alert the listeners - Iterator keyIterator = loadingListeners.keys().asIterator(); - while(keyIterator.hasNext()) { - String modelKey = keyIterator.next(); + for(String modelKey : loadingListeners.keySet()) { if (allReferencedModelKeys.contains(modelKey) == false) { drainWithFailure.addAll(loadingListeners.remove(modelKey)); } @@ -170,14 +165,12 @@ public void clusterChanged(ClusterChangedEvent event) { // Remove all cached models that are not referenced by any processors loadedModels.keySet().retainAll(allReferencedModelKeys); - // After removing the unreferenced models, now we need to know what referenced models should be loaded - // Remove all that are currently being loaded allReferencedModelKeys.removeAll(loadingListeners.keySet()); // Remove all that are fully loaded, will attempt empty model loading again loadedModels.forEach((id, optionalModel) -> { - if(optionalModel.isPresent()) { + if(optionalModel.isSuccess()) { allReferencedModelKeys.remove(id); } }); @@ -248,4 +241,40 @@ private static Set getReferencedModelKeys(IngestMetadata ingestMetadata) return allReferencedModelKeys; } + private static class MaybeModel { + + private final Model model; + private final Exception exception; + + static MaybeModel of(Model model) { + return new MaybeModel(model, null); + } + + static MaybeModel of(Exception exception) { + return new MaybeModel(null, exception); + } + + private MaybeModel(Model model, Exception exception) { + this.model = model; + this.exception = exception; + } + + Model getModel() { + return model; + } + + Exception getException() { + return exception; + } + + boolean isSuccess() { + return this.model != null; + } + + boolean isFailure() { + return this.exception != null; + } + + } + } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index 2ca8d03103a7c..a537b1fc4cc1c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -9,6 +9,8 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; @@ -35,6 +37,7 @@ public class LocalModelTests extends ESTestCase { public void testClassificationInfer() throws Exception { String modelId = "classification_model"; TrainedModelDefinition definition = new TrainedModelDefinition.Builder() + .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildClassification(false)) .build(); @@ -46,64 +49,40 @@ public void testClassificationInfer() throws Exception { put("categorical", "dog"); }}; - PlainActionFuture> future = new PlainActionFuture<>(); - model.infer(fields, future); - InferenceResults result = future.get(); + SingleValueInferenceResults result = getSingleValue(model, fields, InferenceParams.EMPTY_PARAMS); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), is("0.0")); - future = new PlainActionFuture<>(); - model.classificationProbability(fields, 0, future); - result = future.get(); - assertThat(result.value(), equalTo(0.0)); - assertThat(result.valueAsString(), is("0.0")); - - future = new PlainActionFuture<>(); - model.classificationProbability(fields, 1, future); - ClassificationInferenceResults classificationResult = (ClassificationInferenceResults)future.get(); + ClassificationInferenceResults classificationResult = + (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(1)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("0")); // Test with labels definition = new TrainedModelDefinition.Builder() + .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildClassification(true)) .build(); model = new LocalModel(modelId, definition); - future = new PlainActionFuture<>(); - model.infer(fields, future); - result = future.get(); - assertThat(result.value(), equalTo(0.0)); - assertThat(result.valueAsString(), equalTo("not_to_be")); - - future = new PlainActionFuture<>(); - model.classificationProbability(fields, 0, future); - result = future.get(); + result = getSingleValue(model, fields, InferenceParams.EMPTY_PARAMS); assertThat(result.value(), equalTo(0.0)); assertThat(result.valueAsString(), equalTo("not_to_be")); - future = new PlainActionFuture<>(); - model.classificationProbability(fields, 1, future); - result = future.get(); - classificationResult = (ClassificationInferenceResults)result; + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(1)); assertThat(classificationResult.getTopClasses().get(0).getProbability(), closeTo(0.5498339973124778, 0.0000001)); assertThat(classificationResult.getTopClasses().get(0).getClassification(), equalTo("not_to_be")); - future = new PlainActionFuture<>(); - model.classificationProbability(fields, 2, future); - result = future.get(); - classificationResult = (ClassificationInferenceResults)result; + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(2)); assertThat(classificationResult.getTopClasses(), hasSize(2)); - future = new PlainActionFuture<>(); - model.classificationProbability(fields, -1, future); - result = future.get(); - classificationResult = (ClassificationInferenceResults)result; + classificationResult = (ClassificationInferenceResults)getSingleValue(model, fields, new InferenceParams(-1)); assertThat(classificationResult.getTopClasses(), hasSize(2)); } public void testRegression() throws Exception { TrainedModelDefinition trainedModelDefinition = new TrainedModelDefinition.Builder() + .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotMap()))) .setTrainedModel(buildRegression()) .build(); @@ -115,15 +94,22 @@ public void testRegression() throws Exception { put("categorical", "dog"); }}; - PlainActionFuture> future = new PlainActionFuture<>(); - model.infer(fields, future); - InferenceResults results = future.get(); + SingleValueInferenceResults results = getSingleValue(model, fields, InferenceParams.EMPTY_PARAMS); assertThat(results.value(), equalTo(1.3)); - PlainActionFuture> failedFuture = new PlainActionFuture<>(); - model.classificationProbability(fields, -1, failedFuture); + PlainActionFuture failedFuture = new PlainActionFuture<>(); + model.infer(fields, new InferenceParams(2), failedFuture); ExecutionException ex = expectThrows(ExecutionException.class, failedFuture::get); - assertThat(ex.getCause().getMessage(), equalTo("top result probabilities is only available for classification models")); + assertThat(ex.getCause().getMessage(), + equalTo("Cannot return top classes for target_type [regression] and aggregate_output [weighted_sum]")); + } + + private static SingleValueInferenceResults getSingleValue(Model model, + Map fields, + InferenceParams params) throws Exception { + PlainActionFuture future = new PlainActionFuture<>(); + model.infer(fields, params, future); + return (SingleValueInferenceResults)future.get(); } private static Map oneHotMap() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 60606536686a0..9ba32352fab0e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -15,6 +15,8 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; @@ -26,7 +28,6 @@ import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.action.InferModelAction; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.junit.Before; @@ -65,11 +66,13 @@ public void testInferModels() throws Exception { TrainedModelConfig config1 = buildTrainedModelConfigBuilder(modelId2, 0) .setDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) + .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) .setTrainedModel(buildClassification())) .build(Version.CURRENT); TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1, 0) .setDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) + .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) .setTrainedModel(buildRegression())) .build(Version.CURRENT); AtomicReference putConfigHolder = new AtomicReference<>(); @@ -110,23 +113,26 @@ public void testInferModels() throws Exception { // Test regression InferModelAction.Request request = new InferModelAction.Request(modelId1, 0, toInfer, null); InferModelAction.Response response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getInferenceResults().stream().map(InferenceResults::value).collect(Collectors.toList()), + assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), contains(1.3, 1.25)); request = new InferModelAction.Request(modelId1, 0, toInfer2, null); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getInferenceResults().stream().map(InferenceResults::value).collect(Collectors.toList()), + assertThat(response.getInferenceResults().stream().map(i -> ((SingleValueInferenceResults)i).value()).collect(Collectors.toList()), contains(1.65, 1.55)); // Test classification request = new InferModelAction.Request(modelId2, 0, toInfer, null); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getInferenceResults().stream().map(InferenceResults::valueAsString).collect(Collectors.toList()), + assertThat(response.getInferenceResults() + .stream() + .map(i -> ((SingleValueInferenceResults)i).valueAsString()) + .collect(Collectors.toList()), contains("not_to_be", "to_be")); // Get top classes - request = new InferModelAction.Request(modelId2, 0, toInfer, 2); + request = new InferModelAction.Request(modelId2, 0, toInfer, new InferenceParams(2)); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat(response.getResultsType(), equalTo(ClassificationInferenceResults.RESULT_TYPE)); @@ -146,7 +152,7 @@ public void testInferModels() throws Exception { greaterThan(classificationInferenceResults.getTopClasses().get(1).getProbability())); // Test that top classes restrict the number returned - request = new InferModelAction.Request(modelId2, 0, toInfer2, 1); + request = new InferModelAction.Request(modelId2, 0, toInfer2, new InferenceParams(1)); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); assertThat(response.getResultsType(), equalTo(ClassificationInferenceResults.RESULT_TYPE)); From 94b83e41765328c0cc30f4bf8b1d47493f076d50 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 7 Oct 2019 16:03:47 -0400 Subject: [PATCH 10/13] removing unnecessary result type, sticking with named writables --- .../core/ml/action/InferModelAction.java | 30 ++++--------------- .../action/InferModelActionResponseTests.java | 15 ++++++++-- .../ml/action/TransportInferModelAction.java | 2 +- .../integration/ModelInferenceActionIT.java | 2 -- 4 files changed, 20 insertions(+), 29 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java index a576c1935c4ad..248d6180d3256 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InferModelAction.java @@ -12,9 +12,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; -import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -118,40 +116,24 @@ public int hashCode() { public static class Response extends ActionResponse { private final List inferenceResults; - private final String resultsType; - public Response(List inferenceResponse, String resultsType) { + public Response(List inferenceResults) { super(); - this.resultsType = ExceptionsHelper.requireNonNull(resultsType, "resultsType"); - this.inferenceResults = inferenceResponse == null ? - Collections.emptyList() : - Collections.unmodifiableList(inferenceResponse); + this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults")); } public Response(StreamInput in) throws IOException { super(in); - this.resultsType = in.readString(); - if(resultsType.equals(ClassificationInferenceResults.RESULT_TYPE)) { - this.inferenceResults = Collections.unmodifiableList(in.readList(ClassificationInferenceResults::new)); - } else if (this.resultsType.equals(RegressionInferenceResults.RESULT_TYPE)) { - this.inferenceResults = Collections.unmodifiableList(in.readList(RegressionInferenceResults::new)); - } else { - throw new IOException("Unrecognized result type [" + resultsType + "]"); - } + this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableList(InferenceResults.class)); } public List getInferenceResults() { return inferenceResults; } - public String getResultsType() { - return resultsType; - } - @Override public void writeTo(StreamOutput out) throws IOException { - out.writeString(resultsType); - out.writeCollection(inferenceResults); + out.writeNamedWriteableList(inferenceResults); } @Override @@ -159,12 +141,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; InferModelAction.Response that = (InferModelAction.Response) o; - return Objects.equals(resultsType, that.resultsType) && Objects.equals(inferenceResults, that.inferenceResults); + return Objects.equals(inferenceResults, that.inferenceResults); } @Override public int hashCode() { - return Objects.hash(resultsType, inferenceResults); + return Objects.hash(inferenceResults); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java index 4b80c927bed30..c0c6ea3719fc6 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java @@ -5,15 +5,19 @@ */ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.core.ml.action.InferModelAction.Response; +import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResultsTests; +import java.util.ArrayList; +import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -25,8 +29,7 @@ protected Response createTestInstance() { return new Response( Stream.generate(() -> randomInferenceResult(resultType)) .limit(randomIntBetween(0, 10)) - .collect(Collectors.toList()), - resultType); + .collect(Collectors.toList())); } private static InferenceResults randomInferenceResult(String resultType) { @@ -44,4 +47,12 @@ private static InferenceResults randomInferenceResult(String resultType) { protected Writeable.Reader instanceReader() { return Response::new; } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(); + entries.addAll(new MlInferenceNamedXContentProvider().getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java index 60df68eae7480..a2f79fc9d0437 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInferModelAction.java @@ -52,7 +52,7 @@ protected void doExecute(Task task, InferModelAction.Request request, ActionList typedChainTaskExecutor.execute(ActionListener.wrap( inferenceResultsInterfaces -> - listener.onResponse(new InferModelAction.Response(inferenceResultsInterfaces, model.getResultsType())), + listener.onResponse(new InferModelAction.Response(inferenceResultsInterfaces)), listener::onFailure )); }, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 9ba32352fab0e..c00206a072fd2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -134,7 +134,6 @@ public void testInferModels() throws Exception { // Get top classes request = new InferModelAction.Request(modelId2, 0, toInfer, new InferenceParams(2)); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getResultsType(), equalTo(ClassificationInferenceResults.RESULT_TYPE)); ClassificationInferenceResults classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0); @@ -154,7 +153,6 @@ public void testInferModels() throws Exception { // Test that top classes restrict the number returned request = new InferModelAction.Request(modelId2, 0, toInfer2, new InferenceParams(1)); response = client().execute(InferModelAction.INSTANCE, request).actionGet(); - assertThat(response.getResultsType(), equalTo(ClassificationInferenceResults.RESULT_TYPE)); classificationInferenceResults = (ClassificationInferenceResults)response.getInferenceResults().get(0); assertThat(classificationInferenceResults.getTopClasses(), hasSize(1)); From d626a0d29a7e2624859f2eb5ab593674c814a86c Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 7 Oct 2019 16:38:38 -0400 Subject: [PATCH 11/13] reusing some method helpers in tests --- .../loadingservice/LocalModelTests.java | 4 +- .../integration/ModelInferenceActionIT.java | 102 +----------------- 2 files changed, 5 insertions(+), 101 deletions(-) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java index a537b1fc4cc1c..e66a5790d85e5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModelTests.java @@ -119,7 +119,7 @@ private static Map oneHotMap() { return oneHotEncoding; } - private static TrainedModel buildClassification(boolean includeLabels) { + public static TrainedModel buildClassification(boolean includeLabels) { List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); Tree tree1 = Tree.builder() .setFeatureNames(featureNames) @@ -165,7 +165,7 @@ private static TrainedModel buildClassification(boolean includeLabels) { .build(); } - private static TrainedModel buildRegression() { + public static TrainedModel buildRegression() { List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); Tree tree1 = Tree.builder() .setFeatureNames(featureNames) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index c00206a072fd2..032f84d16b52d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -17,13 +17,6 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding; import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceParams; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; -import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; @@ -40,6 +33,8 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; +import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildClassification; +import static org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests.buildRegression; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.equalTo; @@ -67,7 +62,7 @@ public void testInferModels() throws Exception { .setDefinition(new TrainedModelDefinition.Builder() .setPreProcessors(Arrays.asList(new OneHotEncoding("categorical", oneHotEncoding))) .setInput(new TrainedModelDefinition.Input(Arrays.asList("field1", "field2"))) - .setTrainedModel(buildClassification())) + .setTrainedModel(buildClassification(true))) .build(Version.CURRENT); TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1, 0) .setDefinition(new TrainedModelDefinition.Builder() @@ -169,97 +164,6 @@ public void testInferMissingModel() { } } - private static TrainedModel buildClassification() { - List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); - Tree tree1 = Tree.builder() - .setFeatureNames(featureNames) - .setRoot(TreeNode.builder(0) - .setLeftChild(1) - .setRightChild(2) - .setSplitFeature(0) - .setThreshold(0.5)) - .addNode(TreeNode.builder(1).setLeafValue(1.0)) - .addNode(TreeNode.builder(2) - .setThreshold(0.8) - .setSplitFeature(1) - .setLeftChild(3) - .setRightChild(4)) - .addNode(TreeNode.builder(3).setLeafValue(0.0)) - .addNode(TreeNode.builder(4).setLeafValue(1.0)).build(); - Tree tree2 = Tree.builder() - .setFeatureNames(featureNames) - .setRoot(TreeNode.builder(0) - .setLeftChild(1) - .setRightChild(2) - .setSplitFeature(3) - .setThreshold(1.0)) - .addNode(TreeNode.builder(1).setLeafValue(0.0)) - .addNode(TreeNode.builder(2).setLeafValue(1.0)) - .build(); - Tree tree3 = Tree.builder() - .setFeatureNames(featureNames) - .setRoot(TreeNode.builder(0) - .setLeftChild(1) - .setRightChild(2) - .setSplitFeature(0) - .setThreshold(1.0)) - .addNode(TreeNode.builder(1).setLeafValue(1.0)) - .addNode(TreeNode.builder(2).setLeafValue(0.0)) - .build(); - return Ensemble.builder() - .setClassificationLabels(Arrays.asList("not_to_be", "to_be")) - .setTargetType(TargetType.CLASSIFICATION) - .setFeatureNames(featureNames) - .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) - .setOutputAggregator(new WeightedMode(Arrays.asList(0.7, 0.5, 1.0))) - .build(); - } - - private static TrainedModel buildRegression() { - List featureNames = Arrays.asList("foo", "bar", "animal_cat", "animal_dog"); - Tree tree1 = Tree.builder() - .setFeatureNames(featureNames) - .setRoot(TreeNode.builder(0) - .setLeftChild(1) - .setRightChild(2) - .setSplitFeature(0) - .setThreshold(0.5)) - .addNode(TreeNode.builder(1).setLeafValue(0.3)) - .addNode(TreeNode.builder(2) - .setThreshold(0.0) - .setSplitFeature(3) - .setLeftChild(3) - .setRightChild(4)) - .addNode(TreeNode.builder(3).setLeafValue(0.1)) - .addNode(TreeNode.builder(4).setLeafValue(0.2)).build(); - Tree tree2 = Tree.builder() - .setFeatureNames(featureNames) - .setRoot(TreeNode.builder(0) - .setLeftChild(1) - .setRightChild(2) - .setSplitFeature(2) - .setThreshold(1.0)) - .addNode(TreeNode.builder(1).setLeafValue(1.5)) - .addNode(TreeNode.builder(2).setLeafValue(0.9)) - .build(); - Tree tree3 = Tree.builder() - .setFeatureNames(featureNames) - .setRoot(TreeNode.builder(0) - .setLeftChild(1) - .setRightChild(2) - .setSplitFeature(1) - .setThreshold(0.2)) - .addNode(TreeNode.builder(1).setLeafValue(1.5)) - .addNode(TreeNode.builder(2).setLeafValue(0.9)) - .build(); - return Ensemble.builder() - .setTargetType(TargetType.REGRESSION) - .setFeatureNames(featureNames) - .setTrainedModels(Arrays.asList(tree1, tree2, tree3)) - .setOutputAggregator(new WeightedSum(Arrays.asList(0.5, 0.5, 0.5))) - .build(); - } - private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId, long modelVersion) { return TrainedModelConfig.builder() .setCreatedBy("ml_test") From 77509060b5bb576f142dbaad0410d685ed3727ac Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 8 Oct 2019 07:36:14 -0400 Subject: [PATCH 12/13] minor changes for PR --- .../trainedmodel/InferenceHelpers.java | 1 - .../trainedmodel/ensemble/Ensemble.java | 4 ++-- .../ensemble/OutputAggregator.java | 2 ++ .../trainedmodel/ensemble/WeightedMode.java | 5 +++++ .../trainedmodel/ensemble/WeightedSum.java | 5 +++++ .../core/ml/inference/utils/Statistics.java | 18 ++++++++++-------- 6 files changed, 24 insertions(+), 11 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java index 18884ef252f87..b8a89018e8483 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java @@ -45,7 +45,6 @@ public static List topClasses(List classificationLabels); } - int count = numToInclude < 0 ? probabilities.size() : numToInclude; List topClassEntries = new ArrayList<>(count); for(int i = 0; i < count; i++) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 22326e1d17e13..09e418cec916d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -115,8 +115,8 @@ public List getFeatureNames() { @Override public InferenceResults infer(Map fields, InferenceParams params) { - if ((targetType != TargetType.CLASSIFICATION || outputAggregator instanceof WeightedMode == false) && - params.getNumTopClasses() != 0) { + if (params.getNumTopClasses() != 0 && + (targetType != TargetType.CLASSIFICATION || outputAggregator.providesProbabilities() == false)) { throw ExceptionsHelper.badRequestException( "Cannot return top classes for target_type [{}] and aggregate_output [{}]", targetType, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java index 1f882b724ee94..012f474ab0618 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/OutputAggregator.java @@ -44,4 +44,6 @@ public interface OutputAggregator extends NamedXContentObject, NamedWriteable { * @return The name of the output aggregator */ String getName(); + + boolean providesProbabilities(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java index 739a4e13d8659..0689d748b0ccb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -158,4 +158,9 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(weights); } + + @Override + public boolean providesProbabilities() { + return true; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java index f5812dabf88f2..9c5c2bf582e54 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java @@ -135,4 +135,9 @@ public int hashCode() { public Integer expectedValueSize() { return weights == null ? null : this.weights.size(); } + + @Override + public boolean providesProbabilities() { + return false; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java index cb44d03e22bb2..44cca308ea794 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/utils/Statistics.java @@ -5,6 +5,8 @@ */ package org.elasticsearch.xpack.core.ml.inference.utils; +import org.elasticsearch.common.Numbers; + import java.util.List; import java.util.stream.Collectors; @@ -22,31 +24,31 @@ private Statistics(){} */ public static List softMax(List values) { Double expSum = 0.0; - Double max = values.stream().filter(v -> isInvalid(v) == false).max(Double::compareTo).orElse(null); + Double max = values.stream().filter(Statistics::isValid).max(Double::compareTo).orElse(null); if (max == null) { throw new IllegalArgumentException("no valid values present"); } - List exps = values.stream().map(v -> isInvalid(v) ? Double.NEGATIVE_INFINITY : v - max) + List exps = values.stream().map(v -> isValid(v) ? v - max : Double.NEGATIVE_INFINITY) .collect(Collectors.toList()); for (int i = 0; i < exps.size(); i++) { - if (isInvalid(exps.get(i)) == false) { + if (isValid(exps.get(i))) { Double exp = Math.exp(exps.get(i)); expSum += exp; exps.set(i, exp); } } for (int i = 0; i < exps.size(); i++) { - if (isInvalid(exps.get(i))) { - exps.set(i, 0.0); - } else { + if (isValid(exps.get(i))) { exps.set(i, exps.get(i)/expSum); + } else { + exps.set(i, 0.0); } } return exps; } - public static boolean isInvalid(Double v) { - return v == null || Double.isInfinite(v) || Double.isNaN(v); + private static boolean isValid(Double v) { + return v != null && Numbers.isValidDouble(v); } } From a2414103810b879a898630e31386c09680f224bc Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 8 Oct 2019 14:14:48 -0400 Subject: [PATCH 13/13] addressing PR comments --- .../xpack/core/XPackClientPlugin.java | 4 +-- .../MlInferenceNamedXContentProvider.java | 4 +-- .../ClassificationInferenceResults.java | 6 ++--- .../results/RegressionInferenceResults.java | 6 ++--- .../trainedmodel/InferenceHelpers.java | 20 +++++++------- .../xpack/core/ml/utils/ExceptionsHelper.java | 4 +++ .../action/InferModelActionResponseTests.java | 6 ++--- .../inference/loadingservice/LocalModel.java | 4 +-- .../loadingservice/ModelLoadingService.java | 27 ++++++++++++------- 9 files changed, 47 insertions(+), 34 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index d6561379c0cfc..2c8553d68fe98 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -466,10 +466,10 @@ public List getNamedWriteables() { new NamedWriteableRegistry.Entry(OutputAggregator.class, WeightedMode.NAME.getPreferredName(), WeightedMode::new), // ML - Inference Results new NamedWriteableRegistry.Entry(InferenceResults.class, - ClassificationInferenceResults.RESULT_TYPE, + ClassificationInferenceResults.NAME, ClassificationInferenceResults::new), new NamedWriteableRegistry.Entry(InferenceResults.class, - RegressionInferenceResults.RESULT_TYPE, + RegressionInferenceResults.NAME, RegressionInferenceResults::new), // monitoring diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java index ee4eb0e9c1280..7b56a4c3b4da3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java @@ -105,10 +105,10 @@ public List getNamedWriteables() { // Inference Results namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, - ClassificationInferenceResults.RESULT_TYPE, + ClassificationInferenceResults.NAME, ClassificationInferenceResults::new)); namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class, - RegressionInferenceResults.RESULT_TYPE, + RegressionInferenceResults.NAME, RegressionInferenceResults::new)); return namedWriteables; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java index 1aeb8d34cdd00..662585bedf51d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java @@ -24,7 +24,7 @@ public class ClassificationInferenceResults extends SingleValueInferenceResults { - public static final String RESULT_TYPE = "classification"; + public static final String NAME = "classification"; public static final ParseField CLASSIFICATION_LABEL = new ParseField("classification_label"); public static final ParseField TOP_CLASSES = new ParseField("top_classes"); @@ -102,12 +102,12 @@ public void writeResult(IngestDocument document, String resultField) { @Override public String getWriteableName() { - return RESULT_TYPE; + return NAME; } @Override public String getName() { - return RESULT_TYPE; + return NAME; } public static class TopClassEntry implements ToXContentObject, Writeable { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java index 1b01a2dff1f08..e186489b91dab 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java @@ -16,7 +16,7 @@ public class RegressionInferenceResults extends SingleValueInferenceResults { - public static final String RESULT_TYPE = "regression"; + public static final String NAME = "regression"; public RegressionInferenceResults(double value) { super(value); @@ -58,11 +58,11 @@ public void writeResult(IngestDocument document, String resultField) { @Override public String getWriteableName() { - return RESULT_TYPE; + return NAME; } @Override public String getName() { - return RESULT_TYPE; + return NAME; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java index b8a89018e8483..5e37b237e9f79 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java @@ -32,20 +32,21 @@ public static List topClasses(List .mapToInt(i -> i) .toArray(); - List labels = classificationLabels == null ? - // If we don't have the labels we should return the top classification values anyways, they will just be numeric - IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) : - classificationLabels; - - if (probabilities.size() != labels.size()) { + if (classificationLabels != null && probabilities.size() != classificationLabels.size()) { throw ExceptionsHelper - .badRequestException( + .serverError( "model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]", + null, probabilities.size(), classificationLabels); } - int count = numToInclude < 0 ? probabilities.size() : numToInclude; + List labels = classificationLabels == null ? + // If we don't have the labels we should return the top classification values anyways, they will just be numeric + IntStream.range(0, probabilities.size()).boxed().map(String::valueOf).collect(Collectors.toList()) : + classificationLabels; + + int count = numToInclude < 0 ? probabilities.size() : Math.min(numToInclude, probabilities.size()); List topClassEntries = new ArrayList<>(count); for(int i = 0; i < count; i++) { int idx = sortedIndices[i]; @@ -63,8 +64,9 @@ public static String classificationLabel(double inferenceValue, @Nullable List= classificationLabels.size()) { - throw ExceptionsHelper.badRequestException( + throw ExceptionsHelper.serverError( "model returned classification value of [{}] which is not a valid index in classification labels [{}]", + null, label, classificationLabels); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java index 320eace983590..8dfcc5fc59977 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/ExceptionsHelper.java @@ -51,6 +51,10 @@ public static ElasticsearchException serverError(String msg, Throwable cause) { return new ElasticsearchException(msg, cause); } + public static ElasticsearchException serverError(String msg, Throwable cause, Object... args) { + return new ElasticsearchException(msg, cause, args); + } + public static ElasticsearchStatusException conflictStatusException(String msg, Throwable cause, Object... args) { return new ElasticsearchStatusException(msg, RestStatus.CONFLICT, cause, args); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java index c0c6ea3719fc6..9e72d1c4e682a 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InferModelActionResponseTests.java @@ -25,7 +25,7 @@ public class InferModelActionResponseTests extends AbstractWireSerializingTestCa @Override protected Response createTestInstance() { - String resultType = randomFrom(ClassificationInferenceResults.RESULT_TYPE, RegressionInferenceResults.RESULT_TYPE); + String resultType = randomFrom(ClassificationInferenceResults.NAME, RegressionInferenceResults.NAME); return new Response( Stream.generate(() -> randomInferenceResult(resultType)) .limit(randomIntBetween(0, 10)) @@ -33,9 +33,9 @@ protected Response createTestInstance() { } private static InferenceResults randomInferenceResult(String resultType) { - if (resultType.equals(ClassificationInferenceResults.RESULT_TYPE)) { + if (resultType.equals(ClassificationInferenceResults.NAME)) { return ClassificationInferenceResultsTests.createRandomResults(); - } else if (resultType.equals(RegressionInferenceResults.RESULT_TYPE)) { + } else if (resultType.equals(RegressionInferenceResults.NAME)) { return RegressionInferenceResultsTests.createRandomResults(); } else { fail("unexpected result type [" + resultType + "]"); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java index 0019f452321c8..e5253b3d5b173 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/LocalModel.java @@ -29,9 +29,9 @@ public LocalModel(String modelId, TrainedModelDefinition trainedModelDefinition) public String getResultsType() { switch (trainedModelDefinition.getTrainedModel().targetType()) { case CLASSIFICATION: - return ClassificationInferenceResults.RESULT_TYPE; + return ClassificationInferenceResults.NAME; case REGRESSION: - return RegressionInferenceResults.RESULT_TYPE; + return RegressionInferenceResults.NAME; default: throw ExceptionsHelper.badRequestException("Model [{}] has unsupported target type [{}]", modelId, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index 9b5ddb3c1e857..b4fc552ba5f93 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -38,7 +39,9 @@ public class ModelLoadingService implements ClusterStateListener { private final TrainedModelProvider provider; private final ThreadPool threadPool; - public ModelLoadingService(TrainedModelProvider trainedModelProvider, ThreadPool threadPool, ClusterService clusterService) { + public ModelLoadingService(TrainedModelProvider trainedModelProvider, + ThreadPool threadPool, + ClusterService clusterService) { this.provider = trainedModelProvider; this.threadPool = threadPool; clusterService.addListener(this); @@ -152,13 +155,13 @@ public void clusterChanged(ClusterChangedEvent event) { IngestMetadata currentIngestMetadata = state.metaData().custom(IngestMetadata.TYPE); Set allReferencedModelKeys = getReferencedModelKeys(currentIngestMetadata); // The listeners still waiting for a model and we are canceling the load? - Queue> drainWithFailure = new ArrayDeque<>(); + List>>> drainWithFailure = new ArrayList<>(); synchronized (loadingListeners) { // If we had models still loading here but are no longer referenced // we should remove them from loadingListeners and alert the listeners - for(String modelKey : loadingListeners.keySet()) { + for (String modelKey : loadingListeners.keySet()) { if (allReferencedModelKeys.contains(modelKey) == false) { - drainWithFailure.addAll(loadingListeners.remove(modelKey)); + drainWithFailure.add(Tuple.tuple(splitModelKey(modelKey).v1(), new ArrayList<>(loadingListeners.remove(modelKey)))); } } @@ -170,18 +173,22 @@ public void clusterChanged(ClusterChangedEvent event) { // Remove all that are fully loaded, will attempt empty model loading again loadedModels.forEach((id, optionalModel) -> { - if(optionalModel.isSuccess()) { + if (optionalModel.isSuccess()) { allReferencedModelKeys.remove(id); } }); // Populate loadingListeners key so we know that we are currently loading the model - for(String modelId : allReferencedModelKeys) { + for (String modelId : allReferencedModelKeys) { loadingListeners.put(modelId, new ArrayDeque<>()); } } - for(ActionListener listener = drainWithFailure.poll(); listener != null; listener = drainWithFailure.poll()) { - listener.onFailure( - new ElasticsearchException("Cancelling model load and inference as it is no longer referenced by a pipeline")); + for (Tuple>> modelAndListeners : drainWithFailure) { + final String msg = new ParameterizedMessage( + "Cancelling load of model [{}] as it is no longer referenced by a pipeline", + modelAndListeners.v1()).getFormat(); + for (ActionListener listener : modelAndListeners.v2()) { + listener.onFailure(new ElasticsearchException(msg)); + } } loadModels(allReferencedModelKeys); } @@ -193,7 +200,7 @@ private void loadModels(Set modelKeys) { } // Execute this on a utility thread as when the callbacks occur we don't want them tying up the cluster listener thread pool threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> { - for(String modelKey : modelKeys) { + for (String modelKey : modelKeys) { Tuple modelIdAndVersion = splitModelKey(modelKey); this.loadModel(modelKey, modelIdAndVersion.v1(), modelIdAndVersion.v2()); }