From 9aa845e86381f46cd4eccbbd29d9bd375e62976a Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Thu, 18 Feb 2021 09:41:50 -0500 Subject: [PATCH 1/3] [ML] adds new trained model alias API to simplify trained model updates and deployments (#68922) A `model_alias` allows trained models to be referred by a user defined moniker. This not only improves the readability and simplicity of numerous API calls, but it allows for simpler deployment and upgrade procedures for trained models. Previously, if you referenced a model ID directly within an ingest pipeline, when you have a new model that performs better than an earlier referenced model, you have to update the pipeline itself. If this model was used in numerous pipelines, ALL those pipelines would have to be updated. When using a `model_alias` in an ingest pipeline, only that `model_alias` needs to be updated. Then, the underlying referenced model will change in place for all ingest pipelines automatically. An additional benefit is that the model referenced is not changed until it is fully loaded into cache, this way throughput is not hampered by changing models. 
--- .../apis/get-trained-models-stats.asciidoc | 2 +- .../apis/get-trained-models.asciidoc | 2 +- .../ml/df-analytics/apis/index.asciidoc | 1 + .../apis/ml-df-analytics-apis.asciidoc | 3 +- .../apis/put-trained-models-aliases.asciidoc | 89 ++++++ docs/reference/ml/ml-shared.asciidoc | 4 + .../elasticsearch/common/util/set/Sets.java | 6 + .../xpack/core/XPackClientPlugin.java | 8 + .../action/GetTrainedModelsStatsAction.java | 14 +- .../ml/action/InternalInferModelAction.java | 30 +- .../ml/action/PutTrainedModelAliasAction.java | 119 ++++++++ .../core/ml/inference/ModelAliasMetadata.java | 221 ++++++++++++++ .../core/ml/inference/TrainedModelConfig.java | 39 ++- .../xpack/core/ml/job/messages/Messages.java | 5 + ...InternalInferModelActionResponseTests.java | 1 + ...utTrainedModelAliasActionRequestTests.java | 68 +++++ .../integration/MlRestTestStateCleaner.java | 25 ++ .../ml/qa/ml-with-security/build.gradle | 4 + .../ml/integration/InferenceIngestIT.java | 84 ++++++ .../ml/integration/MlNativeIntegTestCase.java | 3 + .../ml/integration/InferenceProcessorIT.java | 260 ++++++++++++---- .../ChunkedTrainedModelPersisterIT.java | 16 +- .../xpack/ml/MachineLearning.java | 17 +- .../TransportDeleteTrainedModelAction.java | 65 +++- .../TransportGetTrainedModelsAction.java | 14 +- .../TransportGetTrainedModelsStatsAction.java | 47 ++- .../TransportInternalInferModelAction.java | 4 +- .../TransportPutTrainedModelAction.java | 8 + .../TransportPutTrainedModelAliasAction.java | 207 +++++++++++++ .../inference/ingest/InferenceProcessor.java | 7 +- .../loadingservice/ModelLoadingService.java | 285 ++++++++++++++---- .../persistence/TrainedModelProvider.java | 90 +++++- .../RestPutTrainedModelAliasAction.java | 57 ++++ ...sportGetTrainedModelsStatsActionTests.java | 5 +- .../ingest/InferenceProcessorTests.java | 50 ++- .../ModelLoadingServiceTests.java | 149 ++++++++- .../xpack/security/operator/Constants.java | 1 + .../test/rest/AbstractXPackRestTest.java | 2 +- 
.../api/ml.put_trained_model_alias.json | 40 +++ .../test/ml/get_trained_model_stats.yml | 23 ++ .../rest-api-spec/test/ml/inference_crud.yml | 110 +++++-- 41 files changed, 1966 insertions(+), 219 deletions(-) create mode 100644 docs/reference/ml/df-analytics/apis/put-trained-models-aliases.asciidoc create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasAction.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/ModelAliasMetadata.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasActionRequestTests.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAliasAction.java create mode 100644 x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model_alias.json diff --git a/docs/reference/ml/df-analytics/apis/get-trained-models-stats.asciidoc b/docs/reference/ml/df-analytics/apis/get-trained-models-stats.asciidoc index 03533a018cb96..4afba286808bf 100644 --- a/docs/reference/ml/df-analytics/apis/get-trained-models-stats.asciidoc +++ b/docs/reference/ml/df-analytics/apis/get-trained-models-stats.asciidoc @@ -47,7 +47,7 @@ request by using a comma-separated list of model IDs or a wildcard expression. 
``:: (Optional, string) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id-or-alias] [[ml-get-trained-models-stats-query-params]] diff --git a/docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc b/docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc index 8c1f4cc9be523..c04cc820a8553 100644 --- a/docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc +++ b/docs/reference/ml/df-analytics/apis/get-trained-models.asciidoc @@ -50,7 +50,7 @@ using a comma-separated list of model IDs or a wildcard expression. ``:: (Optional, string) -include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id] +include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id-or-alias] [[ml-get-trained-models-query-params]] diff --git a/docs/reference/ml/df-analytics/apis/index.asciidoc b/docs/reference/ml/df-analytics/apis/index.asciidoc index 63a46480ce757..958298f027874 100644 --- a/docs/reference/ml/df-analytics/apis/index.asciidoc +++ b/docs/reference/ml/df-analytics/apis/index.asciidoc @@ -2,6 +2,7 @@ include::ml-df-analytics-apis.asciidoc[leveloffset=+1] //CREATE include::put-dfanalytics.asciidoc[leveloffset=+2] include::put-trained-models.asciidoc[leveloffset=+2] +include::put-trained-models-aliases.asciidoc[leveloffset=+2] //UPDATE include::update-dfanalytics.asciidoc[leveloffset=+2] //DELETE diff --git a/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc b/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc index dae8757275ee4..7c485f6c35f48 100644 --- a/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc +++ b/docs/reference/ml/df-analytics/apis/ml-df-analytics-apis.asciidoc @@ -22,8 +22,9 @@ You can use the following APIs to perform {infer} operations. * <> * <> * <> +* <> -You can deploy a trained model to make predictions in an ingest pipeline or in +You can deploy a trained model to make predictions in an ingest pipeline or in an aggregation. 
Refer to the following documentation to learn more. * <> diff --git a/docs/reference/ml/df-analytics/apis/put-trained-models-aliases.asciidoc b/docs/reference/ml/df-analytics/apis/put-trained-models-aliases.asciidoc new file mode 100644 index 0000000000000..6869af8fbc427 --- /dev/null +++ b/docs/reference/ml/df-analytics/apis/put-trained-models-aliases.asciidoc @@ -0,0 +1,89 @@ +[role="xpack"] +[testenv="platinum"] +[[put-trained-models-aliases]] += Put Trained Models Aliases API +[subs="attributes"] +++++ +Put Trained Models Aliases +++++ + +Creates a trained models alias. These model aliases can be used instead of the trained model ID +when referencing the model in the stack. Model aliases must be unique, and a trained model can have +more than one model alias referring to it. But a model alias can only refer to a single trained model. + +beta::[] + +[[ml-put-trained-models-aliases-request]] +== {api-request-title} + +`PUT _ml/trained_models/<model_id>/model_aliases/<model_alias>` + + +[[ml-put-trained-models-aliases-prereq]] +== {api-prereq-title} + +If the {es} {security-features} are enabled, you must have the following +built-in roles and privileges: + +* `machine_learning_admin` + +For more information, see <>, <>, and +{ml-docs-setup-privileges}. + +[[ml-put-trained-models-aliases-desc]] +== {api-description-title} + +This API creates a new model alias to refer to trained models, or updates an existing +trained model's alias. + +When updating an existing model alias to a new model ID, this API will return an error if the models +are of different inference types. For example, if you attempt to reassign the model alias +`flights-delay-prediction` from a regression model to a classification model, the API returns an error. + +The API will return a warning if there are very few input fields in common between the old +and new models for the model alias. 
+ +[[ml-put-trained-models-aliases-path-params]] +== {api-path-parms-title} + +`model_id`:: +(Required, string) +The trained model ID to which the model alias should refer. + +`model_alias`:: +(Required, string) +The model alias to create or update. The model_alias cannot end in numbers. + +[[ml-put-trained-models-aliases-query-params]] +== {api-query-parms-title} + +`reassign`:: +(Optional, boolean) +Specifies whether the `model_alias` is reassigned to the provided `model_id` if it is already +assigned to a model. Defaults to `false`. The API returns an error if the `model_alias` +is already assigned to a model and this parameter is `false`. + +[[ml-put-trained-models-aliases-example]] +== {api-examples-title} + +[[ml-put-trained-models-aliases-example-new-alias]] +=== Creating a new model alias + +The following example shows how to create a new model alias for a trained model ID. + +[source,console] +-------------------------------------------------- +PUT _ml/trained_models/flight-delay-prediction-1574775339910/model_aliases/flight_delay_model +-------------------------------------------------- +// TEST[skip:setup kibana sample data] + +[[ml-put-trained-models-aliases-example-put-alias]] +=== Updating an existing model alias + +The following example shows how to reassign an existing model alias for a trained model ID. + +[source,console] +-------------------------------------------------- +PUT _ml/trained_models/flight-delay-prediction-1580004349800/model_aliases/flight_delay_model?reassign=true +-------------------------------------------------- +// TEST[skip:setup kibana sample data] diff --git a/docs/reference/ml/ml-shared.asciidoc b/docs/reference/ml/ml-shared.asciidoc index 1aaea50d327ca..2bf435c308014 100644 --- a/docs/reference/ml/ml-shared.asciidoc +++ b/docs/reference/ml/ml-shared.asciidoc @@ -1149,6 +1149,10 @@ tag::model-id[] The unique identifier of the trained model. 
end::model-id[] +tag::model-id-or-alias[] +The unique identifier of the trained model or a model alias. +end::model-id-or-alias[] + tag::model-memory-limit[] The approximate maximum amount of memory resources that are required for analytical processing. Once this limit is approached, data pruning becomes diff --git a/server/src/main/java/org/elasticsearch/common/util/set/Sets.java b/server/src/main/java/org/elasticsearch/common/util/set/Sets.java index 30776b27c7281..a2504393ceff9 100644 --- a/server/src/main/java/org/elasticsearch/common/util/set/Sets.java +++ b/server/src/main/java/org/elasticsearch/common/util/set/Sets.java @@ -60,6 +60,12 @@ public static boolean haveEmptyIntersection(Set left, Set right) { return left.stream().noneMatch(right::contains); } + public static boolean haveNonEmptyIntersection(Set left, Set right) { + Objects.requireNonNull(left); + Objects.requireNonNull(right); + return left.stream().anyMatch(right::contains); + } + /** * The relative complement, or difference, of the specified left and right set. Namely, the resulting set contains all the elements that * are in the left set but not in the right set. Neither input is mutated by this operation, an entirely new set is returned. 
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java index 774f13f901d1b..0f24c3625e778 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java @@ -108,6 +108,7 @@ import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction; import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction; +import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.core.rollup.action.RollupIndexerAction; import org.elasticsearch.xpack.core.ml.action.FlushJobAction; import org.elasticsearch.xpack.core.ml.action.ForecastJobAction; @@ -534,6 +535,8 @@ public List getNamedWriteables() { // logstash new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.LOGSTASH, LogstashFeatureSetUsage::new), // ML - Custom metadata + new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelAliasMetadata.NAME, ModelAliasMetadata::new), + new NamedWriteableRegistry.Entry(NamedDiff.class, ModelAliasMetadata.NAME, ModelAliasMetadata::readDiffFrom), new NamedWriteableRegistry.Entry(Metadata.Custom.class, "ml", MlMetadata::new), new NamedWriteableRegistry.Entry(NamedDiff.class, "ml", MlMetadata.MlMetadataDiff::new), // ML - Persistent action requests @@ -712,6 +715,11 @@ public List getNamedXContent() { // ML - Custom metadata new NamedXContentRegistry.Entry(Metadata.Custom.class, new ParseField("ml"), parser -> MlMetadata.LENIENT_PARSER.parse(parser, null).build()), + new NamedXContentRegistry.Entry( + Metadata.Custom.class, + new ParseField(ModelAliasMetadata.NAME), + ModelAliasMetadata::fromXContent + ), // ML - Persistent action requests new 
NamedXContentRegistry.Entry(PersistentTaskParams.class, new ParseField(MlTasks.DATAFEED_TASK_NAME), StartDatafeedAction.DatafeedParams::fromXContent), diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java index cf3e52338340a..b4ad5d4c865db 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java @@ -191,7 +191,7 @@ protected Reader getReader() { public static class Builder { private long totalModelCount; - private Set expandedIds; + private Map> expandedIdsWithAliases; private Map ingestStatsMap; private Map inferenceStatsMap; @@ -200,13 +200,13 @@ public Builder setTotalModelCount(long totalModelCount) { return this; } - public Builder setExpandedIds(Set expandedIds) { - this.expandedIds = expandedIds; + public Builder setExpandedIdsWithAliases(Map> expandedIdsWithAliases) { + this.expandedIdsWithAliases = expandedIdsWithAliases; return this; } - public Set getExpandedIds() { - return this.expandedIds; + public Map> getExpandedIdsWithAliases() { + return this.expandedIdsWithAliases; } public Builder setIngestStatsByModelId(Map ingestStatsByModelId) { @@ -220,8 +220,8 @@ public Builder setInferenceStatsByModelId(Map infereceSt } public Response build() { - List trainedModelStats = new ArrayList<>(expandedIds.size()); - expandedIds.forEach(id -> { + List trainedModelStats = new ArrayList<>(expandedIdsWithAliases.size()); + expandedIdsWithAliases.keySet().forEach(id -> { IngestStats ingestStats = ingestStatsMap.get(id); InferenceStats inferenceStats = inferenceStatsMap.get(id); trainedModelStats.add(new TrainedModelStats( diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java index b178e4c239405..67b02848d6b2c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelAction.java @@ -160,18 +160,25 @@ public String toString() { public static class Response extends ActionResponse { private final List inferenceResults; + private final String modelId; private final boolean isLicensed; - public Response(List inferenceResults, boolean isLicensed) { + public Response(List inferenceResults, String modelId, boolean isLicensed) { super(); this.inferenceResults = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(inferenceResults, "inferenceResults")); this.isLicensed = isLicensed; + this.modelId = modelId; } public Response(StreamInput in) throws IOException { super(in); this.inferenceResults = Collections.unmodifiableList(in.readNamedWriteableList(InferenceResults.class)); this.isLicensed = in.readBoolean(); + if (in.getVersion().onOrAfter(Version.V_7_13_0)) { + this.modelId = in.readOptionalString(); + } else { + this.modelId = null; + } } public List getInferenceResults() { @@ -182,10 +189,17 @@ public boolean isLicensed() { return isLicensed; } + public String getModelId() { + return modelId; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteableList(inferenceResults); out.writeBoolean(isLicensed); + if (out.getVersion().onOrAfter(Version.V_7_13_0)) { + out.writeOptionalString(modelId); + } } @Override @@ -193,12 +207,14 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; InternalInferModelAction.Response that = (InternalInferModelAction.Response) o; - return 
isLicensed == that.isLicensed && Objects.equals(inferenceResults, that.inferenceResults); + return isLicensed == that.isLicensed + && Objects.equals(inferenceResults, that.inferenceResults) + && Objects.equals(modelId, that.modelId); } @Override public int hashCode() { - return Objects.hash(inferenceResults, isLicensed); + return Objects.hash(inferenceResults, isLicensed, modelId); } public static Builder builder() { @@ -207,6 +223,7 @@ public static Builder builder() { public static class Builder { private List inferenceResults; + private String modelId; private boolean isLicensed; public Builder setInferenceResults(List inferenceResults) { @@ -219,8 +236,13 @@ public Builder setLicensed(boolean licensed) { return this; } + public Builder setModelId(String modelId) { + this.modelId = modelId; + return this; + } + public Response build() { - return new Response(inferenceResults, isLicensed); + return new Response(inferenceResults, modelId, isLicensed); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasAction.java new file mode 100644 index 0000000000000..d5078c1dadd21 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasAction.java @@ -0,0 +1,119 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Locale; +import java.util.Objects; +import java.util.regex.Pattern; + +import static org.elasticsearch.action.ValidateActions.addValidationError; +import static org.elasticsearch.xpack.core.ml.job.messages.Messages.INVALID_MODEL_ALIAS; + +public class PutTrainedModelAliasAction extends ActionType { + + // NOTE this is similar to our valid ID check. 
The difference here is that model_aliases cannot end in numbers + // This is to protect our automatic model naming conventions from hitting weird model_alias conflicts + private static final Pattern VALID_MODEL_ALIAS_CHAR_PATTERN = Pattern.compile("[a-z0-9](?:[a-z0-9_\\-\\.]*[a-z])?"); + + public static final PutTrainedModelAliasAction INSTANCE = new PutTrainedModelAliasAction(); + public static final String NAME = "cluster:admin/xpack/ml/inference/model_aliases/put"; + + private PutTrainedModelAliasAction() { + super(NAME, AcknowledgedResponse::readFrom); + } + + public static class Request extends AcknowledgedRequest { + + public static final String MODEL_ALIAS = "model_alias"; + public static final String REASSIGN = "reassign"; + + private final String modelAlias; + private final String modelId; + private final boolean reassign; + + public Request(String modelAlias, String modelId, boolean reassign) { + this.modelAlias = ExceptionsHelper.requireNonNull(modelAlias, MODEL_ALIAS); + this.modelId = ExceptionsHelper.requireNonNull(modelId, TrainedModelConfig.MODEL_ID); + this.reassign = reassign; + } + + public Request(StreamInput in) throws IOException { + super(in); + this.modelAlias = in.readString(); + this.modelId = in.readString(); + this.reassign = in.readBoolean(); + } + + public String getModelAlias() { + return modelAlias; + } + + public String getModelId() { + return modelId; + } + + public boolean isReassign() { + return reassign; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelAlias); + out.writeString(modelId); + out.writeBoolean(reassign); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if (modelAlias.equals(modelId)) { + validationException = addValidationError( + String.format( + Locale.ROOT, + "model_alias [%s] cannot equal model_id [%s]", + modelAlias, + modelId + ), + 
validationException + ); + } + if (VALID_MODEL_ALIAS_CHAR_PATTERN.matcher(modelAlias).matches() == false) { + validationException = addValidationError(Messages.getMessage(INVALID_MODEL_ALIAS, modelAlias), validationException); + } + return validationException; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(modelAlias, request.modelAlias) + && Objects.equals(modelId, request.modelId) + && Objects.equals(reassign, request.reassign); + } + + @Override + public int hashCode() { + return Objects.hash(modelAlias, modelId, reassign); + } + + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/ModelAliasMetadata.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/ModelAliasMetadata.java new file mode 100644 index 0000000000000..917277be853a5 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/ModelAliasMetadata.java @@ -0,0 +1,221 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference; + +import org.elasticsearch.Version; +import org.elasticsearch.cluster.AbstractDiffable; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.Diff; +import org.elasticsearch.cluster.DiffableUtils; +import org.elasticsearch.cluster.NamedDiff; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * Custom {@link Metadata} implementation for storing a map of model aliases that point to model IDs + */ +public class ModelAliasMetadata implements Metadata.Custom { + + public static final String NAME = "trained_model_alias"; + + public static final ModelAliasMetadata EMPTY = new ModelAliasMetadata(new HashMap<>()); + + public static ModelAliasMetadata fromState(ClusterState cs) { + ModelAliasMetadata modelAliasMetadata = cs.metadata().custom(NAME); + return modelAliasMetadata == null ? 
EMPTY : modelAliasMetadata; + } + + public static NamedDiff readDiffFrom(StreamInput in) throws IOException { + return new ModelAliasMetadataDiff(in); + } + + private static final ParseField MODEL_ALIASES = new ParseField("model_aliases"); + private static final ParseField MODEL_ID = new ParseField("model_id"); + + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + NAME, + // to protect BWC serialization + true, + args -> new ModelAliasMetadata((Map)args[0]) + ); + + static { + PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> { + Map modelAliases = new HashMap<>(); + while (p.nextToken() != XContentParser.Token.END_OBJECT) { + String modelAlias = p.currentName(); + modelAliases.put(modelAlias, ModelAliasEntry.fromXContent(p)); + } + return modelAliases; + }, MODEL_ALIASES); + } + + public static ModelAliasMetadata fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + private final Map modelAliases; + + public ModelAliasMetadata(Map modelAliases) { + this.modelAliases = Collections.unmodifiableMap(modelAliases); + } + + public ModelAliasMetadata(StreamInput in) throws IOException { + this.modelAliases = Collections.unmodifiableMap(in.readMap(StreamInput::readString, ModelAliasEntry::new)); + } + + public Map modelAliases() { + return modelAliases; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(MODEL_ALIASES.getPreferredName()); + for (Map.Entry modelAliasEntry : modelAliases.entrySet()) { + builder.field(modelAliasEntry.getKey(), modelAliasEntry.getValue()); + } + builder.endObject(); + return builder; + } + + @Override + public Diff diff(Metadata.Custom previousState) { + return new ModelAliasMetadataDiff((ModelAliasMetadata) previousState, this); + } + + @Override + public EnumSet context() { + return Metadata.ALL_CONTEXTS; + } + + @Override + public 
String getWriteableName() { + return NAME; + } + + @Override + public Version getMinimalSupportedVersion() { + return Version.V_7_13_0; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeMap(this.modelAliases, StreamOutput::writeString, (stream, val) -> val.writeTo(stream)); + } + + public String getModelId(String modelAlias) { + ModelAliasEntry entry = this.modelAliases.get(modelAlias); + if (entry == null) { + return null; + } + return entry.modelId; + } + + static class ModelAliasMetadataDiff implements NamedDiff { + + final Diff> modelAliasesDiff; + + ModelAliasMetadataDiff(ModelAliasMetadata before, ModelAliasMetadata after) { + this.modelAliasesDiff = DiffableUtils.diff(before.modelAliases, after.modelAliases, DiffableUtils.getStringKeySerializer()); + } + + ModelAliasMetadataDiff(StreamInput in) throws IOException { + this.modelAliasesDiff = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(), + ModelAliasEntry::new, ModelAliasEntry::readDiffFrom); + } + + @Override + public Metadata.Custom apply(Metadata.Custom part) { + return new ModelAliasMetadata(modelAliasesDiff.apply(((ModelAliasMetadata) part).modelAliases)); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + modelAliasesDiff.writeTo(out); + } + } + + public static class ModelAliasEntry extends AbstractDiffable implements ToXContentObject { + private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "model_alias_metadata_alias_entry", + // to protect BWC serialization + true, + args -> new ModelAliasEntry((String)args[0]) + ); + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID); + } + + private static Diff readDiffFrom(StreamInput in) throws IOException { + return readDiffFrom(ModelAliasEntry::new, in); + } + + private static ModelAliasEntry fromXContent(XContentParser parser) { 
+ return PARSER.apply(parser, null); + } + + private final String modelId; + + public ModelAliasEntry(String modelId) { + this.modelId = modelId; + } + + ModelAliasEntry(StreamInput in) throws IOException { + this.modelId = in.readString(); + } + + public String getModelId() { + return modelId; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MODEL_ID.getPreferredName(), modelId); + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ModelAliasEntry modelAliasEntry = (ModelAliasEntry) o; + return Objects.equals(modelId, modelAliasEntry.modelId); + } + + @Override + public int hashCode() { + return Objects.hash(modelId); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 60e259197236f..bf85533533d08 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -43,6 +43,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.stream.Collectors; import static org.elasticsearch.action.ValidateActions.addValidationError; @@ -58,6 +59,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { public static final String TOTAL_FEATURE_IMPORTANCE = "total_feature_importance"; public static final String FEATURE_IMPORTANCE_BASELINE = "feature_importance_baseline"; public static final String HYPERPARAMETERS = "hyperparameters"; 
+ public static final String MODEL_ALIASES = "model_aliases"; private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage"; @@ -482,34 +484,41 @@ public Builder setFeatureImportance(List totalFeatureImp if (totalFeatureImportance == null) { return this; } - if (this.metadata == null) { - this.metadata = new HashMap<>(); - } - this.metadata.put(TOTAL_FEATURE_IMPORTANCE, - totalFeatureImportance.stream().map(TotalFeatureImportance::asMap).collect(Collectors.toList())); - return this; + return addToMetadata( + TOTAL_FEATURE_IMPORTANCE, + totalFeatureImportance.stream().map(TotalFeatureImportance::asMap).collect(Collectors.toList()) + ); } public Builder setBaselineFeatureImportance(FeatureImportanceBaseline featureImportanceBaseline) { if (featureImportanceBaseline == null) { return this; } - if (this.metadata == null) { - this.metadata = new HashMap<>(); - } - this.metadata.put(FEATURE_IMPORTANCE_BASELINE, featureImportanceBaseline.asMap()); - return this; + return addToMetadata(FEATURE_IMPORTANCE_BASELINE, featureImportanceBaseline.asMap()); } public Builder setHyperparameters(List hyperparameters) { if (hyperparameters == null) { return this; } + return addToMetadata( + HYPERPARAMETERS, + hyperparameters.stream().map(Hyperparameters::asMap).collect(Collectors.toList()) + ); + } + + public Builder setModelAliases(Set modelAliases) { + if (modelAliases == null || modelAliases.isEmpty()) { + return this; + } + return addToMetadata(MODEL_ALIASES, modelAliases.stream().sorted().collect(Collectors.toList())); + } + + private Builder addToMetadata(String fieldName, Object value) { if (this.metadata == null) { this.metadata = new HashMap<>(); } - this.metadata.put(HYPERPARAMETERS, - hyperparameters.stream().map(Hyperparameters::asMap).collect(Collectors.toList())); + this.metadata.put(fieldName, value); return this; } @@ -674,6 +683,10 @@ public Builder validate(boolean forCreation) { metadata.get(TOTAL_FEATURE_IMPORTANCE), 
METADATA.getPreferredName() + "." + TOTAL_FEATURE_IMPORTANCE, validationException); + validationException = checkIllegalSetting( + metadata.get(MODEL_ALIASES), + METADATA.getPreferredName() + "." + MODEL_ALIASES, + validationException); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 9b556052b9a25..401e6660f8917 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -120,6 +120,11 @@ public final class Messages { public static final String INFERENCE_TAGS_AND_MODEL_IDS_UNIQUE = "The provided tags {0} must not match existing model_ids."; public static final String INFERENCE_MODEL_ID_AND_TAGS_UNIQUE = "The provided model_id {0} must not match existing tags."; + public static final String INVALID_MODEL_ALIAS = "Invalid model_alias; ''{0}'' can contain lowercase alphanumeric (a-z and 0-9), " + + "hyphens or underscores; must start with alphanumeric and cannot end with numbers"; + public static final String TRAINED_MODEL_INPUTS_DIFFER_SIGNIFICANTLY = + "The input fields for new model [{0}] and for old model [{1}] differ significantly, model results may change drastically."; + public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_CREATED = "Job created"; public static final String JOB_AUDIT_UPDATED = "Job updated: {0}"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionResponseTests.java index 9d752e6b6799e..9732ff14963b0 100644 --- 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionResponseTests.java @@ -29,6 +29,7 @@ protected Response createTestInstance() { Stream.generate(() -> randomInferenceResult(resultType)) .limit(randomIntBetween(0, 10)) .collect(Collectors.toList()), + randomAlphaOfLength(10), randomBoolean()); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasActionRequestTests.java new file mode 100644 index 0000000000000..8f77b116d7e96 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAliasActionRequestTests.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.xpack.core.ml.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction.Request; +import org.junit.Before; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; + +public class PutTrainedModelAliasActionRequestTests extends AbstractWireSerializingTestCase { + + private String modelAlias; + + @Before + public void setupModelAlias() { + modelAlias = randomAlphaOfLength(10); + } + + @Override + protected Request createTestInstance() { + return new Request( + modelAlias, + randomAlphaOfLength(10), + randomBoolean() + ); + } + + @Override + protected Writeable.Reader instanceReader() { + return Request::new; + } + + public void testCtor() { + expectThrows(Exception.class, () -> new Request(null, randomAlphaOfLength(10), randomBoolean())); + expectThrows(Exception.class, () -> new Request(randomAlphaOfLength(10), null, randomBoolean())); + } + + public void testValidate() { + + { // model_alias equal to model Id + ActionRequestValidationException ex = new Request("foo", "foo", randomBoolean()).validate(); + assertThat(ex, not(nullValue())); + assertThat(ex.getMessage(), containsString("model_alias [foo] cannot equal model_id [foo]")); + } + { // model_alias cannot end in numbers + String modelAlias = randomAlphaOfLength(10) + randomIntBetween(0, Integer.MAX_VALUE); + ActionRequestValidationException ex = new Request(modelAlias, "foo", randomBoolean()).validate(); + assertThat(ex, not(nullValue())); + assertThat( + ex.getMessage(), + containsString( + "can contain lowercase alphanumeric (a-z and 0-9), hyphens or underscores; " + + "must start with alphanumeric and cannot end with numbers" + ) + ); + } + } + +} diff --git 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java index bbfca86e27eaf..6b6d76188cef3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/integration/MlRestTestStateCleaner.java @@ -14,11 +14,14 @@ import org.elasticsearch.test.rest.ESRestTestCase; import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Set; public class MlRestTestStateCleaner { + private static final Set NOT_DELETED_TRAINED_MODELS = Collections.singleton("lang_ident_model_1"); private final Logger logger; private final RestClient adminClient; @@ -28,12 +31,34 @@ public MlRestTestStateCleaner(Logger logger, RestClient adminClient) { } public void clearMlMetadata() throws IOException { + deleteAllTrainedModels(); deleteAllDatafeeds(); deleteAllJobs(); deleteAllDataFrameAnalytics(); // indices will be deleted by the ESRestTestCase class } + @SuppressWarnings("unchecked") + private void deleteAllTrainedModels() throws IOException { + final Request getTrainedModels = new Request("GET", "/_ml/trained_models"); + getTrainedModels.addParameter("size", "10000"); + final Response trainedModelsResponse = adminClient.performRequest(getTrainedModels); + final List> models = (List>) XContentMapValues.extractValue( + "trained_model_configs", + ESRestTestCase.entityAsMap(trainedModelsResponse) + ); + if (models == null || models.isEmpty()) { + return; + } + for (Map model : models) { + String modelId = (String) model.get("model_id"); + if (NOT_DELETED_TRAINED_MODELS.contains(modelId)) { + continue; + } + adminClient.performRequest(new Request("DELETE", "/_ml/trained_models/" + modelId)); + } + } + @SuppressWarnings("unchecked") private void 
deleteAllDatafeeds() throws IOException { final Request datafeedsRequest = new Request("GET", "/_ml/datafeeds"); diff --git a/x-pack/plugin/ml/qa/ml-with-security/build.gradle b/x-pack/plugin/ml/qa/ml-with-security/build.gradle index afd4453625995..e26ab4bbd5792 100644 --- a/x-pack/plugin/ml/qa/ml-with-security/build.gradle +++ b/x-pack/plugin/ml/qa/ml-with-security/build.gradle @@ -145,6 +145,10 @@ tasks.named("yamlRestTest").configure { 'ml/inference_crud/Test put ensemble with tree where tree has out of bounds feature_names index', 'ml/inference_crud/Test put model with empty input.field_names', 'ml/inference_crud/Test PUT model where target type and inference config mismatch', + 'ml/inference_crud/Test update model alias with model id referring to missing model', + 'ml/inference_crud/Test update model alias with bad alias', + 'ml/inference_crud/Test update model alias where alias exists but old model id is different inference type', + 'ml/inference_crud/Test update model alias where alias exists but reassign is false', 'ml/inference_processor/Test create processor with missing mandatory fields', 'ml/inference_stats_crud/Test get stats given missing trained model', 'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false', diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index d232739e16474..16ef2f22f193b 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -48,6 +48,7 @@ import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; import static 
org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; /** * This is a {@link ESRestTestCase} because the cleanup code in {@link ExternalTestCluster#ensureEstimatedStats()} causes problems @@ -202,6 +203,84 @@ public void testPipelineIngest() throws Exception { }, 30, TimeUnit.SECONDS); } + public void testPipelineIngestWithModelAliases() throws Exception { + String regressionModelId = "test_regression_1"; + putModel(regressionModelId, REGRESSION_CONFIG); + String regressionModelId2 = "test_regression_2"; + putModel(regressionModelId2, REGRESSION_CONFIG); + String modelAlias = "test_regression"; + putModelAlias(modelAlias, regressionModelId); + + client().performRequest(putPipeline("simple_regression_pipeline", pipelineDefinition(modelAlias, "regression"))); + + for (int i = 0; i < 10; i++) { + client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc())); + } + putModelAlias(modelAlias, regressionModelId2); + // Need to assert busy as loading the model and then switching the model alias can take time + assertBusy(() -> { + String source = "{\n" + + " \"docs\": [\n" + + " {\"_source\": {\n" + + " \"col1\": \"female\",\n" + + " \"col2\": \"M\",\n" + + " \"col3\": \"none\",\n" + + " \"col4\": 10\n" + + " }}]\n" + + "}"; + Request request = new Request("POST", "_ingest/pipeline/simple_regression_pipeline/_simulate"); + request.setJsonEntity(source); + Response response = client().performRequest(request); + String responseString = EntityUtils.toString(response.getEntity()); + assertThat(responseString, containsString("\"model_id\":\"test_regression_2\"")); + }, 30, TimeUnit.SECONDS); + + for (int i = 0; i < 10; i++) { + client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc())); + } + + client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline")); + + client().performRequest(new Request("POST", 
"index_for_inference_test/_refresh")); + + Response searchResponse = client().performRequest(searchRequest("index_for_inference_test", + QueryBuilders.boolQuery() + .filter( + QueryBuilders.existsQuery("ml.inference.regression.predicted_value")))); + // Verify we have 20 documents that contain a predicted value for regression + assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20")); + + + // Since this is a multi-node cluster, the model could be loaded and cached on one ingest node but not the other + // Consequently, we should only verify that some of the documents refer to the first regression model + // and some refer to the second. + searchResponse = client().performRequest(searchRequest("index_for_inference_test", + QueryBuilders.boolQuery() + .filter( + QueryBuilders.termQuery("ml.inference.regression.model_id.keyword", regressionModelId)))); + assertThat(EntityUtils.toString(searchResponse.getEntity()), not(containsString("\"value\":0"))); + + searchResponse = client().performRequest(searchRequest("index_for_inference_test", + QueryBuilders.boolQuery() + .filter( + QueryBuilders.termQuery("ml.inference.regression.model_id.keyword", regressionModelId2)))); + assertThat(EntityUtils.toString(searchResponse.getEntity()), not(containsString("\"value\":0"))); + + assertBusy(() -> { + try (XContentParser parser = createParser(JsonXContent.jsonXContent, client().performRequest(new Request("GET", + "_ml/trained_models/" + modelAlias + "/_stats")).getEntity().getContent())) { + GetTrainedModelsStatsResponse response = GetTrainedModelsStatsResponse.fromXContent(parser); + assertThat(response.toString(), response.getTrainedModelStats(), hasSize(1)); + TrainedModelStats trainedModelStats = response.getTrainedModelStats().get(0); + assertThat(trainedModelStats.getModelId(), equalTo(regressionModelId2)); + assertThat(trainedModelStats.getInferenceStats(), is(notNullValue())); + } catch (ResponseException ex) { + //this could just mean 
shard failures. + fail(ex.getMessage()); + } + }); + } + public void assertStatsWithCacheMisses(String modelId, long inferenceCount) throws IOException { Response statsResponse = client().performRequest(new Request("GET", "_ml/trained_models/" + modelId + "/_stats")); @@ -630,4 +709,9 @@ private void putModel(String modelId, String modelConfiguration) throws IOExcept client().performRequest(request); } + private void putModelAlias(String modelAlias, String newModel) throws IOException { + Request request = new Request("PUT", "_ml/trained_models/" + newModel + "/model_aliases/" + modelAlias + "?reassign=true"); + client().performRequest(request); + } + } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java index b1f14e9cce603..70ff006be022b 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlNativeIntegTestCase.java @@ -78,6 +78,7 @@ import org.elasticsearch.xpack.ilm.IndexLifecycle; import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.autoscaling.MlScalingReason; +import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; import java.io.IOException; import java.net.URISyntaxException; @@ -240,6 +241,8 @@ protected void ensureClusterStateConsistency() throws IOException { if (cluster() != null && cluster().size() > 0) { List entries = new ArrayList<>(ClusterModule.getNamedWriteables()); entries.addAll(new SearchModule(Settings.EMPTY, true, Collections.emptyList()).getNamedWriteables()); + entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelAliasMetadata.NAME, 
ModelAliasMetadata::new)); + entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelAliasMetadata.NAME, ModelAliasMetadata::readDiffFrom)); entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, "ml", MlMetadata::new)); entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, IndexLifecycleMetadata.TYPE, IndexLifecycleMetadata::new)); entries.add(new NamedWriteableRegistry.Entry(LifecycleType.class, TimeseriesLifecycleType.TYPE, diff --git a/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceProcessorIT.java b/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceProcessorIT.java index a13e6c38ad846..960514b6e1cc0 100644 --- a/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceProcessorIT.java +++ b/x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceProcessorIT.java @@ -13,19 +13,26 @@ import org.elasticsearch.client.ResponseException; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.test.rest.ESRestTestCase; +import org.junit.After; import org.junit.Before; import java.io.IOException; +import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; public class InferenceProcessorIT extends ESRestTestCase { + private static final Set NOT_DELETED_TRAINED_MODELS = Collections.singleton("lang_ident_model_1"); private static final String MODEL_ID = "a-perfect-regression-model"; + private final Set createdPipelines = new HashSet<>(); @Before public void enableLogging() throws IOException { @@ -36,8 +43,39 @@ public void enableLogging() throws IOException { 
assertThat(client().performRequest(setTrace).getStatusLine().getStatusCode(), equalTo(200)); } - private void putRegressionModel() throws IOException { + @SuppressWarnings("unchecked") + @After + public void cleanup() throws Exception { + for (String createdPipeline : createdPipelines) { + deletePipeline(createdPipeline); + } + createdPipelines.clear(); + waitForStats(); + final Request getTrainedModels = new Request("GET", "/_ml/trained_models"); + getTrainedModels.addParameter("size", "10000"); + final Response trainedModelsResponse = adminClient().performRequest(getTrainedModels); + final List> models = (List>) XContentMapValues.extractValue( + "trained_model_configs", + ESRestTestCase.entityAsMap(trainedModelsResponse) + ); + if (models == null || models.isEmpty()) { + return; + } + for (Map model : models) { + String modelId = (String) model.get("model_id"); + if (NOT_DELETED_TRAINED_MODELS.contains(modelId)) { + continue; + } + adminClient().performRequest(new Request("DELETE", "/_ml/trained_models/" + modelId)); + } + } + + private void putModelAlias(String modelAlias, String newModel) throws IOException { + Request request = new Request("PUT", "_ml/trained_models/" + newModel + "/model_aliases/" + modelAlias + "?reassign=true"); + client().performRequest(request); + } + private void putRegressionModel() throws IOException { Request model = new Request("PUT", "_ml/trained_models/" + MODEL_ID); model.setJsonEntity( " {\n" + @@ -66,24 +104,9 @@ private void putRegressionModel() throws IOException { @SuppressWarnings("unchecked") public void testCreateAndDeletePipelineWithInferenceProcessor() throws Exception { putRegressionModel(); - - Request putPipeline = new Request("PUT", "_ingest/pipeline/regression-model-pipeline"); - putPipeline.setJsonEntity( - " {\n" + - " \"processors\": [\n" + - " {\n" + - " \"inference\" : {\n" + - " \"model_id\" : \"" + MODEL_ID + "\",\n" + - " \"inference_config\": {\"regression\": {}},\n" + - " \"target_field\": 
\"regression_field\",\n" + - " \"field_map\": {}\n" + - " }\n" + - " }\n" + - " ]\n" + - " }" - ); - - assertThat(client().performRequest(putPipeline).getStatusLine().getStatusCode(), equalTo(200)); + String pipelineId = "regression-model-pipeline"; + createdPipelines.add(pipelineId); + putPipeline(MODEL_ID, pipelineId); Map statsAsMap = getStats(); List pipelineCount = @@ -100,8 +123,8 @@ public void testCreateAndDeletePipelineWithInferenceProcessor() throws Exception // using the model will ensure it is loaded and stats will be written before it is deleted infer("regression-model-pipeline"); - Request deletePipeline = new Request("DELETE", "_ingest/pipeline/regression-model-pipeline"); - assertThat(client().performRequest(deletePipeline).getStatusLine().getStatusCode(), equalTo(200)); + deletePipeline(pipelineId); + createdPipelines.remove(pipelineId); // check stats are updated assertBusy(() -> { @@ -129,9 +152,100 @@ public void testCreateAndDeletePipelineWithInferenceProcessor() throws Exception }); } + @SuppressWarnings("unchecked") + public void testCreateAndDeletePipelineWithInferenceProcessorByName() throws Exception { + putRegressionModel(); + + putModelAlias("regression_first", MODEL_ID); + putModelAlias("regression_second", MODEL_ID); + createdPipelines.add("first_pipeline"); + putPipeline("regression_first", "first_pipeline"); + createdPipelines.add("second_pipeline"); + putPipeline("regression_second", "second_pipeline"); + + Map statsAsMap = getStats(); + List pipelineCount = + (List)XContentMapValues.extractValue("trained_model_stats.pipeline_count", statsAsMap); + assertThat(pipelineCount.get(0), equalTo(2)); + + List> counts = + (List>)XContentMapValues.extractValue("trained_model_stats.ingest.total", statsAsMap); + assertThat(counts.get(0).get("count"), equalTo(0)); + assertThat(counts.get(0).get("time_in_millis"), equalTo(0)); + assertThat(counts.get(0).get("current"), equalTo(0)); + assertThat(counts.get(0).get("failed"), equalTo(0)); + + // 
using the model will ensure it is loaded and stats will be written before it is deleted + infer("first_pipeline"); + deletePipeline("first_pipeline"); + createdPipelines.remove("first_pipeline"); + + infer("second_pipeline"); + deletePipeline("second_pipeline"); + createdPipelines.remove("second_pipeline"); + + // check stats are updated + assertBusy(() -> { + Map updatedStatsMap = null; + try { + updatedStatsMap = getStats(); + } catch (ResponseException e) { + // the search may fail because the index is not ready yet in which case retry + if (e.getMessage().contains("search_phase_execution_exception")) { + fail("search failed- retry"); + } else { + throw e; + } + } + + List updatedPipelineCount = + (List) XContentMapValues.extractValue("trained_model_stats.pipeline_count", updatedStatsMap); + assertThat(updatedPipelineCount.get(0), equalTo(0)); + + List> inferenceStats = + (List>) XContentMapValues.extractValue("trained_model_stats.inference_stats", updatedStatsMap); + assertNotNull(inferenceStats); + assertThat(inferenceStats, hasSize(1)); + assertThat(inferenceStats.toString(), inferenceStats.get(0).get("inference_count"), equalTo(2)); + }); + } + + public void testDeleteModelWhileAliasReferencedByPipeline() throws Exception { + putRegressionModel(); + putModelAlias("regression_first", MODEL_ID); + createdPipelines.add("first_pipeline"); + putPipeline("regression_first", "first_pipeline"); + Exception ex = expectThrows(Exception.class, + () -> client().performRequest(new Request("DELETE", "_ml/trained_models/" + MODEL_ID))); + assertThat(ex.getMessage(), + containsString("Cannot delete model [" + + MODEL_ID + + "] as it has a model_alias [regression_first] that is still referenced by ingest processors")); + infer("first_pipeline"); + deletePipeline("first_pipeline"); + waitForStats(); + } + + public void testDeleteModelWhileReferencedByPipeline() throws Exception { + putRegressionModel(); + createdPipelines.add("first_pipeline"); + putPipeline(MODEL_ID, 
"first_pipeline"); + Exception ex = expectThrows(Exception.class, + () -> client().performRequest(new Request("DELETE", "_ml/trained_models/" + MODEL_ID))); + assertThat(ex.getMessage(), + containsString("Cannot delete model [" + + MODEL_ID + + "] as it is still referenced by ingest processors")); + infer("first_pipeline"); + deletePipeline("first_pipeline"); + waitForStats(); + } + + @SuppressWarnings("unchecked") public void testCreateProcessorWithDeprecatedFields() throws Exception { putRegressionModel(); + createdPipelines.add("regression-model-deprecated-pipeline"); Request putPipeline = new Request("PUT", "_ingest/pipeline/regression-model-deprecated-pipeline"); putPipeline.setJsonEntity( "{\n" + @@ -155,14 +269,35 @@ public void testCreateProcessorWithDeprecatedFields() throws Exception { // using the model will ensure it is loaded and stats will be written before it is deleted infer("regression-model-deprecated-pipeline"); - Request deletePipeline = new Request("DELETE", "_ingest/pipeline/regression-model-deprecated-pipeline"); - Response deleteResponse = client().performRequest(deletePipeline); - assertThat(deleteResponse.getStatusLine().getStatusCode(), equalTo(200)); + deletePipeline("regression-model-deprecated-pipeline"); + createdPipelines.remove("regression-model-deprecated-pipeline"); + waitForStats(); + assertBusy(() -> { + Map updatedStatsMap = null; + try { + updatedStatsMap = getStats(); + } catch (ResponseException e) { + // the search may fail because the index is not ready yet in which case retry + if (e.getMessage().contains("search_phase_execution_exception")) { + fail("search failed- retry"); + } else { + throw e; + } + } + + List updatedPipelineCount = + (List) XContentMapValues.extractValue("trained_model_stats.pipeline_count", updatedStatsMap); + assertThat(updatedPipelineCount.get(0), equalTo(0)); - waitForStatsDoc(); + List> inferenceStats = + (List>) XContentMapValues.extractValue("trained_model_stats.inference_stats", 
updatedStatsMap); + assertNotNull(inferenceStats); + assertThat(inferenceStats, hasSize(1)); + assertThat(inferenceStats.get(0).get("inference_count"), equalTo(1)); + }); } - public void infer(String pipelineId) throws IOException { + private void infer(String pipelineId) throws IOException { Request putDoc = new Request("POST", "any_index/_doc?pipeline=" + pipelineId); putDoc.setJsonEntity("{\"field1\": 1, \"field2\": 2}"); @@ -170,43 +305,56 @@ public void infer(String pipelineId) throws IOException { assertThat(response.getStatusLine().getStatusCode(), equalTo(201)); } - @SuppressWarnings("unchecked") - public void waitForStatsDoc() throws Exception { - assertBusy( () -> { - Request searchForStats = new Request("GET", ".ml-stats-*/_search?rest_total_hits_as_int"); - searchForStats.setJsonEntity( - "{\n" + - " \"query\": {\n" + - " \"bool\": {\n" + - " \"filter\": [\n" + - " {\n" + - " \"term\": {\n" + - " \"type\": \"inference_stats\"\n" + - " }\n" + - " },\n" + - " {\n" + - " \"term\": {\n" + - " \"model_id\": \"" + MODEL_ID + "\"\n" + - " }\n" + - " }\n" + - " ]\n" + - " }\n" + - " }\n" + - "}" - ); + private void putPipeline(String modelId, String pipelineName) throws IOException { + Request putPipeline = new Request("PUT", "_ingest/pipeline/" + pipelineName); + putPipeline.setJsonEntity( + " {\n" + + " \"processors\": [\n" + + " {\n" + + " \"inference\" : {\n" + + " \"model_id\" : \"" + modelId + "\",\n" + + " \"inference_config\": {\"regression\": {}},\n" + + " \"target_field\": \"regression_field\",\n" + + " \"field_map\": {}\n" + + " }\n" + + " }\n" + + " ]\n" + + " }" + ); - try { - Response searchResponse = client().performRequest(searchForStats); + assertThat(client().performRequest(putPipeline).getStatusLine().getStatusCode(), equalTo(200)); + } - Map responseAsMap = entityAsMap(searchResponse); - Map hits = (Map)responseAsMap.get("hits"); - assertThat(responseAsMap.toString(), hits.get("total"), equalTo(1)); + private void deletePipeline(String 
pipelineId) throws IOException { + try { + Request deletePipeline = new Request("DELETE", "_ingest/pipeline/" + pipelineId); + assertThat(client().performRequest(deletePipeline).getStatusLine().getStatusCode(), equalTo(200)); + } catch (ResponseException ex) { + if (ex.getResponse().getStatusLine().getStatusCode() != 404) { + throw ex; + } + } + } + + @SuppressWarnings("unchecked") + private void waitForStats() throws Exception { + assertBusy(() -> { + Map updatedStatsMap = null; + try { + ensureGreen(".ml-stats-*"); + updatedStatsMap = getStats(); } catch (ResponseException e) { // the search may fail because the index is not ready yet in which case retry - if (e.getMessage().contains("search_phase_execution_exception") == false) { + if (e.getMessage().contains("search_phase_execution_exception")) { + fail("search failed- retry"); + } else { throw e; } } + + List> inferenceStats = + (List>) XContentMapValues.extractValue("trained_model_stats.inference_stats", updatedStatsMap); + assertNotNull(inferenceStats); }); } diff --git a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java index 8dd6413eea435..5245d85eaccec 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -36,6 +36,7 @@ import org.elasticsearch.xpack.ml.extractor.DocValueField; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; +import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; import 
org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests; @@ -102,11 +103,18 @@ public void testStoreModelViaChunkedPersister() throws IOException { .collect(Collectors.toList())); persister.createAndIndexInferenceModelMetadata(modelMetadata); - PlainActionFuture>> getIdsFuture = new PlainActionFuture<>(); - trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture); - Tuple> ids = getIdsFuture.actionGet(); + PlainActionFuture>>> getIdsFuture = new PlainActionFuture<>(); + trainedModelProvider.expandIds( + modelId + "*", + false, + PageParams.defaultParams(), + Collections.emptySet(), + ModelAliasMetadata.EMPTY, + getIdsFuture + ); + Tuple>> ids = getIdsFuture.actionGet(); assertThat(ids.v1(), equalTo(1L)); - String inferenceModelId = ids.v2().iterator().next(); + String inferenceModelId = ids.v2().keySet().iterator().next(); PlainActionFuture getTrainedModelFuture = new PlainActionFuture<>(); trainedModelProvider.getTrainedModel(inferenceModelId, GetTrainedModelsAction.Includes.all(), getTrainedModelFuture); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index b53e621f208ba..718ebd68b052a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -19,10 +19,12 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.IndexTemplateMetadata; +import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.node.DiscoveryNodes; import 
org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ParseField; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.inject.Module; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -134,6 +136,7 @@ import org.elasticsearch.xpack.core.ml.action.UpdateJobAction; import org.elasticsearch.xpack.core.ml.action.UpdateModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.UpdateProcessAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction; import org.elasticsearch.xpack.core.ml.action.UpgradeJobModelSnapshotAction; import org.elasticsearch.xpack.core.ml.action.ValidateDetectorAction; import org.elasticsearch.xpack.core.ml.action.ValidateJobConfigAction; @@ -207,6 +210,7 @@ import org.elasticsearch.xpack.ml.action.TransportUpdateJobAction; import org.elasticsearch.xpack.ml.action.TransportUpdateModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportUpdateProcessAction; +import org.elasticsearch.xpack.ml.action.TransportPutTrainedModelAliasAction; import org.elasticsearch.xpack.ml.action.TransportUpgradeJobModelSnapshotAction; import org.elasticsearch.xpack.ml.action.TransportValidateDetectorAction; import org.elasticsearch.xpack.ml.action.TransportValidateJobConfigAction; @@ -226,6 +230,7 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; +import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder; import org.elasticsearch.xpack.ml.inference.aggs.InternalInferenceAggregation; @@ -300,6 +305,7 @@ import 
org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsAction; import org.elasticsearch.xpack.ml.rest.inference.RestGetTrainedModelsStatsAction; import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAction; +import org.elasticsearch.xpack.ml.rest.inference.RestPutTrainedModelAliasAction; import org.elasticsearch.xpack.ml.rest.job.RestCloseJobAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteForecastAction; import org.elasticsearch.xpack.ml.rest.job.RestDeleteJobAction; @@ -928,6 +934,7 @@ public List getRestHandlers(Settings settings, RestController restC new RestGetTrainedModelsStatsAction(), new RestPutTrainedModelAction(), new RestUpgradeJobModelSnapshotAction(), + new RestPutTrainedModelAliasAction(), // CAT Handlers new RestCatJobsAction(), new RestCatTrainedModelsAction(), @@ -1006,7 +1013,8 @@ public List getRestHandlers(Settings settings, RestController restC new ActionHandler<>(DeleteTrainedModelAction.INSTANCE, TransportDeleteTrainedModelAction.class), new ActionHandler<>(GetTrainedModelsStatsAction.INSTANCE, TransportGetTrainedModelsStatsAction.class), new ActionHandler<>(PutTrainedModelAction.INSTANCE, TransportPutTrainedModelAction.class), - new ActionHandler<>(UpgradeJobModelSnapshotAction.INSTANCE, TransportUpgradeJobModelSnapshotAction.class) + new ActionHandler<>(UpgradeJobModelSnapshotAction.INSTANCE, TransportUpgradeJobModelSnapshotAction.class), + new ActionHandler<>(PutTrainedModelAliasAction.INSTANCE, TransportPutTrainedModelAliasAction.class) ); } @@ -1116,6 +1124,13 @@ public List getNamedXContent() { namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers()); + namedXContent.add( + new NamedXContentRegistry.Entry( + Metadata.Custom.class, + new ParseField(ModelAliasMetadata.NAME), + 
ModelAliasMetadata::fromXContent + ) + ); return namedXContent; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java index b9cc7b0ded34a..087ddb95b6ec3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteTrainedModelAction.java @@ -14,10 +14,12 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction; +import org.elasticsearch.cluster.AckedClusterStateUpdateTask; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.ingest.IngestMetadata; @@ -28,11 +30,15 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction; +import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; +import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -78,13 +84,60 @@ protected void 
masterOperation(DeleteTrainedModelAction.Request request, return; } - trainedModelProvider.deleteTrainedModel(request.getId(), ActionListener.wrap( - r -> { - auditor.info(request.getId(), "trained model deleted"); - listener.onResponse(AcknowledgedResponse.TRUE); - }, + final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(state); + final List modelAliases = new ArrayList<>(); + for (Map.Entry modelAliasEntry : currentMetadata.modelAliases().entrySet()) { + if (modelAliasEntry.getValue().getModelId().equals(id)) { + modelAliases.add(modelAliasEntry.getKey()); + } + } + for (String modelAlias : modelAliases) { + if (referencedModels.contains(modelAlias)) { + listener.onFailure(new ElasticsearchStatusException( + "Cannot delete model [{}] as it has a model_alias [{}] that is still referenced by ingest processors", + RestStatus.CONFLICT, + id, + modelAlias)); + return; + } + } + + ActionListener nameDeletionListener = ActionListener.wrap( + ack -> trainedModelProvider.deleteTrainedModel(request.getId(), ActionListener.wrap( + r -> { + auditor.info(request.getId(), "trained model deleted"); + listener.onResponse(AcknowledgedResponse.TRUE); + }, + listener::onFailure + )), + listener::onFailure - )); + ); + + // No reason to update cluster state, simply delete the model + if (modelAliases.isEmpty()) { + nameDeletionListener.onResponse(AcknowledgedResponse.of(true)); + return; + } + + clusterService.submitStateUpdateTask("delete-trained-model-alias", new AckedClusterStateUpdateTask(request, nameDeletionListener) { + @Override + public ClusterState execute(final ClusterState currentState) { + final ClusterState.Builder builder = ClusterState.builder(currentState); + final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(currentState); + if (currentMetadata.modelAliases().isEmpty()) { + return currentState; + } + final Map newMetadata = new HashMap<>(currentMetadata.modelAliases()); + logger.info("[{}] delete model model_aliases {}", 
request.getId(), modelAliases); + modelAliases.forEach(newMetadata::remove); + final ModelAliasMetadata modelAliasMetadata = new ModelAliasMetadata(newMetadata); + builder.metadata(Metadata.builder(currentState.getMetadata()) + .putCustom(ModelAliasMetadata.NAME, modelAliasMetadata) + .build()); + return builder.build(); + } + }); } private Set getReferencedModelKeys(IngestMetadata ingestMetadata) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java index f8d78db353634..c5464acbb290c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsAction.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.tasks.Task; @@ -18,22 +19,27 @@ import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction.Response; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import java.util.Collections; import java.util.HashSet; +import java.util.Map; import java.util.Set; public class TransportGetTrainedModelsAction extends HandledTransportAction { private final TrainedModelProvider provider; + private final ClusterService clusterService; @Inject public TransportGetTrainedModelsAction(TransportService transportService, ActionFilters 
actionFilters, + ClusterService clusterService, TrainedModelProvider trainedModelProvider) { super(GetTrainedModelsAction.NAME, transportService, actionFilters, GetTrainedModelsAction.Request::new); this.provider = trainedModelProvider; + this.clusterService = clusterService; } @Override @@ -41,7 +47,7 @@ protected void doExecute(Task task, Request request, ActionListener li Response.Builder responseBuilder = Response.builder(); - ActionListener>> idExpansionListener = ActionListener.wrap( + ActionListener>>> idExpansionListener = ActionListener.wrap( totalAndIds -> { responseBuilder.setTotalCount(totalAndIds.v1()); @@ -58,8 +64,10 @@ protected void doExecute(Task task, Request request, ActionListener li } if (request.getIncludes().isIncludeModelDefinition()) { + Map.Entry> modelIdAndAliases = totalAndIds.v2().entrySet().iterator().next(); provider.getTrainedModel( - totalAndIds.v2().iterator().next(), + modelIdAndAliases.getKey(), + modelIdAndAliases.getValue(), request.getIncludes(), ActionListener.wrap( config -> listener.onResponse(responseBuilder.setModels(Collections.singletonList(config)).build()), @@ -80,11 +88,11 @@ protected void doExecute(Task task, Request request, ActionListener li }, listener::onFailure ); - provider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), new HashSet<>(request.getTags()), + ModelAliasMetadata.fromState(clusterService.state()), idExpansionListener); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java index 945e621daa601..80a1e5ac439fa 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.java @@ -20,6 +20,7 @@ import 
org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.metrics.CounterMetric; +import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.ingest.IngestStats; @@ -28,6 +29,7 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; +import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -42,6 +44,7 @@ import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -73,6 +76,7 @@ protected void doExecute(Task task, GetTrainedModelsStatsAction.Request request, ActionListener listener) { + final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(clusterService.state()); GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder(); ActionListener> inferenceStatsListener = ActionListener.wrap( @@ -84,20 +88,30 @@ protected void doExecute(Task task, ActionListener nodesStatsListener = ActionListener.wrap( nodesStatsResponse -> { + Set allPossiblePipelineReferences = responseBuilder.getExpandedIdsWithAliases() + .entrySet() + .stream() + .flatMap(entry -> Stream.concat(entry.getValue().stream(), Stream.of(entry.getKey()))) + .collect(Collectors.toSet()); + Map> pipelineIdsByModelIdsOrAliases = pipelineIdsByModelIdsOrAliases(clusterService.state(), + ingestService, + allPossiblePipelineReferences); Map 
modelIdIngestStats = inferenceIngestStatsByModelId(nodesStatsResponse, - pipelineIdsByModelIds(clusterService.state(), - ingestService, - responseBuilder.getExpandedIds())); + currentMetadata, + pipelineIdsByModelIdsOrAliases + ); responseBuilder.setIngestStatsByModelId(modelIdIngestStats); - trainedModelProvider.getInferenceStats(responseBuilder.getExpandedIds().toArray(new String[0]), inferenceStatsListener); + trainedModelProvider.getInferenceStats( + responseBuilder.getExpandedIdsWithAliases().keySet().toArray(new String[0]), + inferenceStatsListener + ); }, listener::onFailure ); - ActionListener>> idsListener = ActionListener.wrap( + ActionListener>>> idsListener = ActionListener.wrap( tuple -> { - responseBuilder.setExpandedIds(tuple.v2()) - .setTotalModelCount(tuple.v1()); + responseBuilder.setExpandedIdsWithAliases(tuple.v2()).setTotalModelCount(tuple.v1()); String[] ingestNodes = ingestNodes(clusterService.state()); NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes).clear() .addMetric(NodesStatsRequest.Metric.INGEST.metricName()); @@ -105,27 +119,36 @@ protected void doExecute(Task task, }, listener::onFailure ); - trainedModelProvider.expandIds(request.getResourceId(), request.isAllowNoResources(), request.getPageParams(), Collections.emptySet(), + currentMetadata, idsListener); } static Map inferenceIngestStatsByModelId(NodesStatsResponse response, + ModelAliasMetadata currentMetadata, Map> modelIdToPipelineId) { Map ingestStatsMap = new HashMap<>(); - - modelIdToPipelineId.forEach((modelId, pipelineIds) -> { + Map> trueModelIdToPipelines = modelIdToPipelineId.entrySet() + .stream() + .collect(Collectors.toMap( + entry -> { + String maybeModelId = currentMetadata.getModelId(entry.getKey()); + return maybeModelId == null ? 
entry.getKey() : maybeModelId; + }, + Map.Entry::getValue, + Sets::union + )); + trueModelIdToPipelines.forEach((modelId, pipelineIds) -> { List collectedStats = response.getNodes() .stream() .map(nodeStats -> ingestStatsForPipelineIds(nodeStats, pipelineIds)) .collect(Collectors.toList()); ingestStatsMap.put(modelId, mergeStats(collectedStats)); }); - return ingestStatsMap; } @@ -139,7 +162,7 @@ static String[] ingestNodes(final ClusterState clusterState) { return ingestNodes; } - static Map> pipelineIdsByModelIds(ClusterState state, IngestService ingestService, Set modelIds) { + static Map> pipelineIdsByModelIdsOrAliases(ClusterState state, IngestService ingestService, Set modelIds) { IngestMetadata ingestMetadata = state.metadata().custom(IngestMetadata.TYPE); Map> pipelineIdsByModelIds = new HashMap<>(); if (ingestMetadata == null) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java index d50ba9fd6addc..8ba42634b68f7 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java @@ -69,7 +69,9 @@ protected void doExecute(Task task, Request request, ActionListener li typedChainTaskExecutor.execute(ActionListener.wrap( inferenceResultsInterfaces -> { model.release(); - listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces).build()); + listener.onResponse(responseBuilder.setInferenceResults(inferenceResultsInterfaces) + .setModelId(model.getModelId()) + .build()); }, e -> { model.release(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java 
index 7a3def44914d8..f242abaa00a07 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java @@ -38,6 +38,7 @@ import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import java.io.IOException; @@ -131,6 +132,13 @@ protected void masterOperation(Request request, .setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed()) .setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations()) .build(); + if (ModelAliasMetadata.fromState(state).getModelId(trainedModelConfig.getModelId()) != null) { + listener.onFailure(ExceptionsHelper.badRequestException( + "requested model_id [{}] is the same as an existing model_alias. Model model_aliases and ids must be unique", + request.getTrainedModelConfig().getModelId() + )); + return; + } ActionListener tagsModelIdCheckListener = ActionListener.wrap( r -> trainedModelProvider.storeTrainedModel(trainedModelConfig, ActionListener.wrap( diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java new file mode 100644 index 0000000000000..901402b8e1bed --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAliasAction.java @@ -0,0 +1,207 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.action.support.master.AcknowledgedTransportMasterNodeAction; +import org.elasticsearch.cluster.AckedClusterStateUpdateTask; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.logging.HeaderWarning; +import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.license.LicenseUtils; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.XPackField; +import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.job.messages.Messages; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; +import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import 
java.util.Set; +import java.util.function.Predicate; + +import static org.elasticsearch.xpack.core.ml.job.messages.Messages.TRAINED_MODEL_INPUTS_DIFFER_SIGNIFICANTLY; + +public class TransportPutTrainedModelAliasAction extends AcknowledgedTransportMasterNodeAction { + + private static final Logger logger = LogManager.getLogger(TransportPutTrainedModelAliasAction.class); + + private final XPackLicenseState licenseState; + private final TrainedModelProvider trainedModelProvider; + private final InferenceAuditor auditor; + + @Inject + public TransportPutTrainedModelAliasAction( + TransportService transportService, + TrainedModelProvider trainedModelProvider, + ClusterService clusterService, + ThreadPool threadPool, + XPackLicenseState licenseState, + ActionFilters actionFilters, + InferenceAuditor auditor, + IndexNameExpressionResolver indexNameExpressionResolver) { + super( + PutTrainedModelAliasAction.NAME, + transportService, + clusterService, + threadPool, + actionFilters, + PutTrainedModelAliasAction.Request::new, + indexNameExpressionResolver, + ThreadPool.Names.SAME + ); + this.licenseState = licenseState; + this.trainedModelProvider = trainedModelProvider; + this.auditor = auditor; + } + + @Override + protected void masterOperation( + PutTrainedModelAliasAction.Request request, + ClusterState state, + ActionListener listener + ) throws Exception { + final boolean mlSupported = licenseState.checkFeature(XPackLicenseState.Feature.MACHINE_LEARNING); + final Predicate isLicensed = (model) -> mlSupported || licenseState.isAllowedByLicense(model.getLicenseLevel()); + final String oldModelId = ModelAliasMetadata.fromState(state).getModelId(request.getModelAlias()); + + if (oldModelId != null && (request.isReassign() == false)) { + listener.onFailure(ExceptionsHelper.badRequestException( + "cannot assign model_alias [{}] to model_id [{}] as model_alias already refers to [{}]. 
" + + "Set parameter [reassign] to [true] if model_alias should be reassigned.", + request.getModelAlias(), + request.getModelId(), + oldModelId)); + return; + } + Set modelIds = new HashSet<>(); + modelIds.add(request.getModelAlias()); + modelIds.add(request.getModelId()); + if (oldModelId != null) { + modelIds.add(oldModelId); + } + trainedModelProvider.getTrainedModels(modelIds, GetTrainedModelsAction.Includes.empty(), true, ActionListener.wrap( + models -> { + TrainedModelConfig newModel = null; + TrainedModelConfig oldModel = null; + for (TrainedModelConfig config : models) { + if (config.getModelId().equals(request.getModelId())) { + newModel = config; + } + if (config.getModelId().equals(oldModelId)) { + oldModel = config; + } + if (config.getModelId().equals(request.getModelAlias())) { + listener.onFailure( + ExceptionsHelper.badRequestException("model_alias cannot be the same as an existing trained model_id") + ); + return; + } + } + if (newModel == null) { + listener.onFailure( + ExceptionsHelper.missingTrainedModel(request.getModelId()) + ); + return; + } + if (isLicensed.test(newModel) == false) { + listener.onFailure(LicenseUtils.newComplianceException(XPackField.MACHINE_LEARNING)); + return; + } + // if old model is null, none of these validations matter + // we should still allow reassignment even if the old model was somehow deleted and the alias still refers to it + if (oldModel != null) { + // validate inference configs are the same type. 
Moving an alias from regression -> classification seems dangerous + if (newModel.getInferenceConfig() != null && oldModel.getInferenceConfig() != null) { + if (newModel.getInferenceConfig().getName().equals(oldModel.getInferenceConfig().getName()) == false) { + listener.onFailure( + ExceptionsHelper.badRequestException( + "cannot reassign model_alias [{}] to model [{}] " + + "with inference config type [{}] from model [{}] with type [{}]", + request.getModelAlias(), + newModel.getModelId(), + newModel.getInferenceConfig().getName(), + oldModel.getModelId(), + oldModel.getInferenceConfig().getName() + ) + ); + return; + } + } + + Set oldInputFields = new HashSet<>(oldModel.getInput().getFieldNames()); + Set newInputFields = new HashSet<>(newModel.getInput().getFieldNames()); + // TODO should we fail in this case??? + if (Sets.difference(oldInputFields, newInputFields).size() > (oldInputFields.size() / 2) + || Sets.intersection(newInputFields, oldInputFields).size() < (oldInputFields.size() / 2)) { + String warning = Messages.getMessage( + TRAINED_MODEL_INPUTS_DIFFER_SIGNIFICANTLY, + request.getModelId(), + oldModelId); + auditor.warning(oldModelId, warning); + logger.warn("[{}] {}", oldModelId, warning); + HeaderWarning.addWarning(warning); + } + } + clusterService.submitStateUpdateTask("update-model-alias", new AckedClusterStateUpdateTask(request, listener) { + @Override + public ClusterState execute(final ClusterState currentState) { + return updateModelAlias(currentState, request); + } + }); + + }, + listener::onFailure + )); + } + + static ClusterState updateModelAlias(final ClusterState currentState, final PutTrainedModelAliasAction.Request request) { + final ClusterState.Builder builder = ClusterState.builder(currentState); + final ModelAliasMetadata currentMetadata = ModelAliasMetadata.fromState(currentState); + String currentModelId = currentMetadata.getModelId(request.getModelAlias()); + final Map newMetadata = new HashMap<>(currentMetadata.modelAliases()); 
+ if (currentModelId == null) { + logger.info("creating new model_alias [{}] for model [{}]", request.getModelAlias(), request.getModelId()); + } else { + logger.info( + "updating model_alias [{}] to refer to model [{}] from model [{}]", + request.getModelAlias(), + request.getModelId(), + currentModelId + ); + } + newMetadata.put(request.getModelAlias(), new ModelAliasMetadata.ModelAliasEntry(request.getModelId())); + final ModelAliasMetadata modelAliasMetadata = new ModelAliasMetadata(newMetadata); + builder.metadata(Metadata.builder(currentState.getMetadata()).putCustom(ModelAliasMetadata.NAME, modelAliasMetadata).build()); + return builder.build(); + } + + @Override + protected ClusterBlockException checkBlock(PutTrainedModelAliasAction.Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index c1e6b69517a5c..761154fa54341 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -157,7 +157,12 @@ void mutateDocument(InternalInferModelAction.Response response, IngestDocument i throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR); } assert response.getInferenceResults().size() == 1; - InferenceResults.writeResult(response.getInferenceResults().get(0), ingestDocument, targetField, modelId); + InferenceResults.writeResult( + response.getInferenceResults().get(0), + ingestDocument, + targetField, + response.getModelId() != null ? 
response.getModelId() : modelId + ); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java index f9a922b286a79..40709ae218f17 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.cache.CacheBuilder; +import org.elasticsearch.common.cache.CacheLoader; import org.elasticsearch.common.cache.RemovalNotification; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; @@ -37,12 +38,14 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import java.util.ArrayDeque; +import java.util.Collections; import java.util.EnumSet; import java.util.HashMap; import java.util.HashSet; @@ -50,6 +53,7 @@ import java.util.Map; import java.util.Queue; import java.util.Set; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; /** @@ -108,11 +112,14 @@ private ModelAndConsumer(LocalModel model, Consumer consumer) { } } - private static final Logger logger = 
LogManager.getLogger(ModelLoadingService.class); private final TrainedModelStatsService modelStatsService; private final Cache localModelCache; + // Referenced models can be model aliases or IDs private final Set referencedModels = new HashSet<>(); + private final Map modelAliasToId = new HashMap<>(); + private final Map> modelIdToModelAliases = new HashMap<>(); + private final Map> modelIdToUpdatedModelAliases = new HashMap<>(); private final Map>> loadingListeners = new HashMap<>(); private final TrainedModelProvider provider; private final Set shouldNotAudit; @@ -148,8 +155,13 @@ public ModelLoadingService(TrainedModelProvider trainedModelProvider, this.trainedModelCircuitBreaker = ExceptionsHelper.requireNonNull(trainedModelCircuitBreaker, "trainedModelCircuitBreaker"); } + // for testing + String getModelId(String modelIdOrAlias) { + return modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias); + } + boolean isModelCached(String modelId) { - return localModelCache.get(modelId) != null; + return localModelCache.get(modelAliasToId.getOrDefault(modelId, modelId)) != null; } /** @@ -195,11 +207,12 @@ public void getModelForSearch(String modelId, ActionListener modelAc * The main difference being that models for search are always cached whereas pipeline models * are only cached if they are referenced by an ingest pipeline * - * @param modelId the model to get + * @param modelIdOrAlias the model id or model alias to get * @param consumer which feature is requesting the model * @param modelActionListener the listener to alert when the model has been retrieved. 
*/ - private void getModel(String modelId, Consumer consumer, ActionListener modelActionListener) { + private void getModel(String modelIdOrAlias, Consumer consumer, ActionListener modelActionListener) { + final String modelId = modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias); ModelAndConsumer cachedModel = localModelCache.get(modelId); if (cachedModel != null) { cachedModel.consumers.add(consumer); @@ -210,12 +223,16 @@ private void getModel(String modelId, Consumer consumer, ActionListener new ParameterizedMessage("[{}] loaded from cache", modelId)); + logger.trace(() -> new ParameterizedMessage("[{}] (model_alias [{}]) loaded from cache", modelId, modelIdOrAlias)); return; } - if (loadModelIfNecessary(modelId, consumer, modelActionListener)) { - logger.trace(() -> new ParameterizedMessage("[{}] is loading or loaded, added new listener to queue", modelId)); + if (loadModelIfNecessary(modelIdOrAlias, consumer, modelActionListener)) { + logger.trace(() -> new ParameterizedMessage( + "[{}] (model_alias [{}]) is loading or loaded, added new listener to queue", + modelId, + modelIdOrAlias + )); } } @@ -224,14 +241,15 @@ private void getModel(String modelId, Consumer consumer, ActionListener modelActionListener) { + private boolean loadModelIfNecessary(String modelIdOrAlias, Consumer consumer, ActionListener modelActionListener) { synchronized (loadingListeners) { + final String modelId = modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias); ModelAndConsumer cachedModel = localModelCache.get(modelId); if (cachedModel != null) { cachedModel.consumers.add(consumer); @@ -257,13 +275,21 @@ private boolean loadModelIfNecessary(String modelId, Consumer consumer, ActionLi if (Consumer.PIPELINE == consumer && referencedModels.contains(modelId) == false) { // The model is requested by a pipeline but not referenced by any ingest pipelines. 
// This means it is a simulate call and the model should not be cached + logger.trace(() -> new ParameterizedMessage( + "[{}] (model_alias [{}]) not actively loading, eager loading without cache", + modelId, + modelIdOrAlias + )); loadWithoutCaching(modelId, modelActionListener); } else { - logger.trace(() -> new ParameterizedMessage("[{}] attempting to load and cache", modelId)); + logger.trace(() -> new ParameterizedMessage( + "[{}] (model_alias [{}]) attempting to load and cache", + modelId, + modelIdOrAlias + )); loadingListeners.put(modelId, addFluently(new ArrayDeque<>(), modelActionListener)); loadModel(modelId, consumer); } - return false; } // synchronized (loadingListeners) } @@ -304,7 +330,6 @@ private void loadModel(String modelId, Consumer consumer) { private void loadWithoutCaching(String modelId, ActionListener modelActionListener) { // If we the model is not loaded and we did not kick off a new loading attempt, this means that we may be getting called // by a simulated pipeline - logger.trace(() -> new ParameterizedMessage("[{}] not actively loading, eager loading without cache", modelId)); provider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.empty(), ActionListener.wrap( trainedModelConfig -> { // Verify we can pull the model into memory without causing OOM @@ -377,34 +402,41 @@ private void handleLoadSuccess(String modelId, trainedModelConfig.getLicenseLevel(), modelStatsService, trainedModelCircuitBreaker); - boolean modelAcquired = false; + final ModelAndConsumerLoader modelAndConsumerLoader = new ModelAndConsumerLoader(new ModelAndConsumer(loadedModel, consumer)); synchronized (loadingListeners) { - listeners = loadingListeners.remove(modelId); - // if there are no listeners, simply release and leave - if (listeners == null) { - loadedModel.release(); - return; - } - + populateNewModelAlias(modelId); // If the model is referenced, that means it is currently in a pipeline somewhere // Also, if the consume is a search consumer, we 
should always cache it - if (referencedModels.contains(modelId) || consumer.equals(Consumer.SEARCH)) { - // temporarily increase the reference count before adding to - // the cache in case the model is evicted before the listeners - // are called in which case acquire() would throw. - loadedModel.acquire(); - localModelCache.put(modelId, new ModelAndConsumer(loadedModel, consumer)); + if (referencedModels.contains(modelId) + || Sets.haveNonEmptyIntersection(modelIdToModelAliases.getOrDefault(modelId, new HashSet<>()), referencedModels) + || consumer.equals(Consumer.SEARCH)) { + try { + // The local model may already be in cache. If it is, we don't bother adding it to cache. + // If it isn't, we flip an `isLoaded` flag, and increment the model counter to make sure if it is evicted + // between now and when the listeners access it, the circuit breaker reflects actual usage. + localModelCache.computeIfAbsent(modelId, modelAndConsumerLoader); + } catch (ExecutionException ee) { + logger.warn(() -> new ParameterizedMessage("[{}] threw when attempting add to cache", modelId), ee); + } shouldNotAudit.remove(modelId); - modelAcquired = true; + } + listeners = loadingListeners.remove(modelId); + // if there are no listeners, we should just exit + if (listeners == null) { + // If we newly added it into cache, release the model so that the circuit breaker can still accurately keep track + // of memory + if(modelAndConsumerLoader.isLoaded()) { + loadedModel.release(); + } + return; } } // synchronized (loadingListeners) for (ActionListener listener = listeners.poll(); listener != null; listener = listeners.poll()) { loadedModel.acquire(); listener.onResponse(loadedModel); } - // account for the acquire in the synchronized block above - // We cannot simply utilize the same conditionals as `referencedModels` could have changed once we exited the synchronized block - if (modelAcquired) { + // account for the acquire in the synchronized block above if the model was loaded into the 
cache + if (modelAndConsumerLoader.isLoaded()) { loadedModel.release(); } } @@ -413,6 +445,7 @@ private void handleLoadFailure(String modelId, Exception failure) { Queue> listeners; synchronized (loadingListeners) { listeners = loadingListeners.remove(modelId); + populateNewModelAlias(modelId); if (listeners == null) { return; } @@ -424,6 +457,20 @@ private void handleLoadFailure(String modelId, Exception failure) { } } + private void populateNewModelAlias(String modelId) { + Set newModelAliases = modelIdToUpdatedModelAliases.remove(modelId); + if (newModelAliases != null && newModelAliases.isEmpty() == false) { + logger.trace(() -> new ParameterizedMessage( + "[{}] model is now loaded, setting new model_aliases {}", + modelId, + newModelAliases + )); + for (String modelAlias: newModelAliases) { + modelAliasToId.put(modelAlias, modelId); + } + } + } + private void cacheEvictionListener(RemovalNotification notification) { try { if (notification.getRemovalReason() == RemovalNotification.RemovalReason.EVICTED) { @@ -438,12 +485,15 @@ private void cacheEvictionListener(RemovalNotification INFERENCE_MODEL_CACHE_TTL.getKey()); auditIfNecessary(notification.getKey(), msg); } - - logger.trace(() -> new ParameterizedMessage("Persisting stats for evicted model [{}]", - notification.getValue().model.getModelId())); + String modelId = modelAliasToId.getOrDefault(notification.getKey(), notification.getKey()); + logger.trace(() -> new ParameterizedMessage( + "Persisting stats for evicted model [{}] (model_aliases {})", + modelId, + modelIdToModelAliases.getOrDefault(modelId, new HashSet<>()) + )); // If the model is no longer referenced, flush the stats to persist as soon as possible - notification.getValue().model.persistStats(referencedModels.contains(notification.getKey()) == false); + notification.getValue().model.persistStats(referencedModels.contains(modelId) == false); } finally { notification.getValue().model.release(); } @@ -451,46 +501,112 @@ private void 
cacheEvictionListener(RemovalNotification @Override public void clusterChanged(ClusterChangedEvent event) { - // If ingest data has not changed or if the current node is not an ingest node, don't bother caching models - if (event.changedCustomMetadataSet().contains(IngestMetadata.TYPE) == false || - event.state().nodes().getLocalNode().isIngestNode() == false) { + final boolean prefetchModels = event.state().nodes().getLocalNode().isIngestNode(); + // If we are not prefetching models and there were no model alias changes, don't bother handling the changes + if ((prefetchModels == false) + && (event.changedCustomMetadataSet().contains(IngestMetadata.TYPE) == false) + && (event.changedCustomMetadataSet().contains(ModelAliasMetadata.NAME) == false)) { return; } ClusterState state = event.state(); IngestMetadata currentIngestMetadata = state.metadata().custom(IngestMetadata.TYPE); - Set allReferencedModelKeys = getReferencedModelKeys(currentIngestMetadata); - if (allReferencedModelKeys.equals(referencedModels)) { - return; - } - Set referencedModelsBeforeClusterState = null; + Set allReferencedModelKeys = event.changedCustomMetadataSet().contains(IngestMetadata.TYPE) ? + getReferencedModelKeys(currentIngestMetadata) : + new HashSet<>(referencedModels); + Set referencedModelsBeforeClusterState; Set loadingModelBeforeClusterState = null; - Set removedModels = null; + Set removedModels; + Map> addedModelViaAliases = new HashMap<>(); + Map> oldIdToAliases; synchronized (loadingListeners) { + oldIdToAliases = new HashMap<>(modelIdToModelAliases); + Map changedAliases = gatherLazyChangedAliasesAndUpdateModelAliases( + event, + prefetchModels, + allReferencedModelKeys + ); + + // if we are not prefetching, exit now. 
+ if (prefetchModels == false) { + return; + } + referencedModelsBeforeClusterState = new HashSet<>(referencedModels); if (logger.isTraceEnabled()) { loadingModelBeforeClusterState = new HashSet<>(loadingListeners.keySet()); } removedModels = Sets.difference(referencedModelsBeforeClusterState, allReferencedModelKeys); - // Remove all cached models that are not referenced by any processors - // and are not used in search - removedModels.forEach(modelId -> { - ModelAndConsumer modelAndConsumer = localModelCache.get(modelId); - if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) { - localModelCache.invalidate(modelId); - } - }); // Remove the models that are no longer referenced referencedModels.removeAll(removedModels); shouldNotAudit.removeAll(removedModels); + // Remove all cached models that are not referenced by any processors + // and are not used in search + for (String modelAliasOrId : removedModels) { + String modelId = changedAliases.getOrDefault(modelAliasOrId, modelAliasToId.getOrDefault(modelAliasOrId, modelAliasOrId)); + // If the "old" model_alias is referenced, we don't want to invalidate. This way the model that now has the model_alias + // can be loaded in first + boolean oldModelAliasesNotReferenced = Sets.haveEmptyIntersection(referencedModels, + oldIdToAliases.getOrDefault(modelId, Collections.emptySet())); + // If the model itself is referenced, we shouldn't evict. 
+ boolean modelIsNotReferenced = referencedModels.contains(modelId) == false; + // If a model_alias change causes it to NOW be referenced, we shouldn't attempt to evict it + boolean newModelAliasesNotReferenced = Sets.haveEmptyIntersection(referencedModels, + modelIdToModelAliases.getOrDefault(modelId, Collections.emptySet())); + if (oldModelAliasesNotReferenced && newModelAliasesNotReferenced && modelIsNotReferenced) { + ModelAndConsumer modelAndConsumer = localModelCache.get(modelId); + if (modelAndConsumer != null && modelAndConsumer.consumers.contains(Consumer.SEARCH) == false) { + logger.trace("[{} ({})] invalidated from cache", modelId, modelAliasOrId); + localModelCache.invalidate(modelId); + } + } + } // Remove all that are still referenced, i.e. the intersection of allReferencedModelKeys and referencedModels allReferencedModelKeys.removeAll(referencedModels); + for (String newlyReferencedModel : allReferencedModelKeys) { + // check if the model_alias has changed in this round + String modelId = changedAliases.getOrDefault( + newlyReferencedModel, + // If the model_alias hasn't changed, get the model id IF it is a model_alias, otherwise we assume it is an id + modelAliasToId.getOrDefault( + newlyReferencedModel, + newlyReferencedModel + ) + ); + // Verify that it isn't an old model id but just a new model_alias + if (referencedModels.contains(modelId) == false) { + addedModelViaAliases.computeIfAbsent(modelId, k -> new HashSet<>()).add(newlyReferencedModel); + } + } + // For any previously referenced model, the model_alias COULD have changed, so it is actually a NEWLY referenced model + for (Map.Entry modelAliasAndId : changedAliases.entrySet()) { + String modelAlias = modelAliasAndId.getKey(); + String modelId = modelAliasAndId.getValue(); + if (referencedModels.contains(modelAlias)) { + // we need to load the underlying model since its model_alias is referenced + addedModelViaAliases.computeIfAbsent(modelId, k -> new HashSet<>()).add(modelAlias); + // If 
we are in cache, keep the old translation for now, it will be updated later + String oldModelId = modelAliasToId.get(modelAlias); + if (oldModelId != null && localModelCache.get(oldModelId) != null) { + modelIdToUpdatedModelAliases.computeIfAbsent(modelId, k -> new HashSet<>()).add(modelAlias); + } else { + // If we are not cached, might as well add the translation right away as new callers will have to load + // from disk anyways. + modelAliasToId.put(modelAlias, modelId); + } + } else { + // Add model_alias and id here, since the model_alias wasn't previously referenced, + // no reason to wait on updating the model_alias -> model_id mapping + modelAliasToId.put(modelAlias, modelId); + } + } + // Gather ALL currently referenced model ids referencedModels.addAll(allReferencedModelKeys); // Populate loadingListeners key so we know that we are currently loading the model - for (String modelId : allReferencedModelKeys) { + for (String modelId : addedModelViaAliases.keySet()) { loadingListeners.computeIfAbsent(modelId, (s) -> new ArrayDeque<>()); } } // synchronized (loadingListeners) @@ -503,9 +619,51 @@ public void clusterChanged(ClusterChangedEvent event) { logger.trace("cluster state event changed referenced models: before {} after {}", referencedModelsBeforeClusterState, referencedModels); } + if (oldIdToAliases.equals(modelIdToModelAliases) == false) { + logger.trace("model id to alias mappings changed. before {} after {}. 
Model alias to IDs {}", + oldIdToAliases, + modelIdToModelAliases, + modelAliasToId); + } + if (addedModelViaAliases.isEmpty() == false) { + logger.trace("adding new models via model_aliases and ids: {}", addedModelViaAliases); + } + if (modelIdToUpdatedModelAliases.isEmpty() == false) { + logger.trace("delayed model aliases to update {}", modelIdToUpdatedModelAliases); + } } removedModels.forEach(this::auditUnreferencedModel); - loadModelsForPipeline(allReferencedModelKeys); + loadModelsForPipeline(addedModelViaAliases.keySet()); + } + + private Map gatherLazyChangedAliasesAndUpdateModelAliases(ClusterChangedEvent event, + boolean prefetchModels, + Set allReferencedModelKeys) { + Map changedAliases = new HashMap<>(); + if (event.changedCustomMetadataSet().contains(ModelAliasMetadata.NAME)) { + final Map modelAliasesToIds = new HashMap<>( + ModelAliasMetadata.fromState(event.state()).modelAliases() + ); + modelIdToModelAliases.clear(); + for (Map.Entry aliasToId : modelAliasesToIds.entrySet()) { + modelIdToModelAliases.computeIfAbsent(aliasToId.getValue().getModelId(), k -> new HashSet<>()).add(aliasToId.getKey()); + String modelId = modelAliasToId.get(aliasToId.getKey()); + if (modelId != null + && modelId.equals(aliasToId.getValue().getModelId()) == false) { + if (prefetchModels && allReferencedModelKeys.contains(aliasToId.getKey())) { + changedAliases.put(aliasToId.getKey(), aliasToId.getValue().getModelId()); + } else { + modelAliasToId.put(aliasToId.getKey(), aliasToId.getValue().getModelId()); + } + } + if (modelId == null) { + modelAliasToId.put(aliasToId.getKey(), aliasToId.getValue().getModelId()); + } + } + Set removedAliases = Sets.difference(modelAliasToId.keySet(), modelAliasesToIds.keySet()); + modelAliasToId.keySet().removeAll(removedAliases); + } + return changedAliases; } private void auditIfNecessary(String modelId, MessageSupplier msg) { @@ -600,4 +758,25 @@ void addModelLoadedListener(String modelId, ActionListener modelLoad }); } } + + 
private static class ModelAndConsumerLoader implements CacheLoader { + + private boolean loaded; + private final ModelAndConsumer modelAndConsumer; + + ModelAndConsumerLoader(ModelAndConsumer modelAndConsumer) { + this.modelAndConsumer = modelAndConsumer; + } + + boolean isLoaded() { + return loaded; + } + + @Override + public ModelAndConsumer load(String key) throws Exception { + loaded = true; + modelAndConsumer.model.acquire(); + return modelAndConsumer; + } + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index c1dcee0b99b89..20ae6adb56b62 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -81,6 +81,7 @@ import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; +import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; import java.io.IOException; import java.io.InputStream; @@ -98,6 +99,7 @@ import java.util.Objects; import java.util.Set; import java.util.TreeSet; +import java.util.function.Function; import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; @@ -443,6 +445,13 @@ public void getTrainedModelForInference(final String modelId, final ActionListen public void getTrainedModel(final String modelId, final GetTrainedModelsAction.Includes includes, final ActionListener finalListener) { + getTrainedModel(modelId, Collections.emptySet(), includes, finalListener); + } + + public void getTrainedModel(final String modelId, + final Set modelAliases, + final GetTrainedModelsAction.Includes includes, + final 
ActionListener finalListener) { if (MODELS_STORED_AS_RESOURCE.contains(modelId)) { try { @@ -456,6 +465,7 @@ public void getTrainedModel(final String modelId, ActionListener getTrainedModelListener = ActionListener.wrap( modelBuilder -> { + modelBuilder.setModelAliases(modelAliases); if ((includes.isIncludeFeatureImportanceBaseline() || includes.isIncludeTotalFeatureImportance() || includes.isIncludeHyperparameters()) == false) { finalListener.onResponse(modelBuilder.build()); @@ -571,6 +581,18 @@ public void getTrainedModel(final String modelId, multiSearchResponseActionListener); } + public void getTrainedModels(Set modelIds, + GetTrainedModelsAction.Includes includes, + boolean allowNoResources, + final ActionListener> finalListener) { + getTrainedModels( + modelIds.stream().collect(Collectors.toMap(Function.identity(), _k -> Collections.emptySet())), + includes, + allowNoResources, + finalListener + ); + } + /** * Gets all the provided trained config model objects * @@ -578,11 +600,15 @@ public void getTrainedModel(final String modelId, * This does no expansion on the ids. * It assumes that there are fewer than 10k. 
*/ - public void getTrainedModels(Set modelIds, + public void getTrainedModels(Map> modelIds, GetTrainedModelsAction.Includes includes, boolean allowNoResources, final ActionListener> finalListener) { - QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders.idsQuery().addIds(modelIds.toArray(new String[0]))); + QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery( + QueryBuilders + .idsQuery() + .addIds(modelIds.keySet().toArray(new String[0])) + ); SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) .addSort(TrainedModelConfig.MODEL_ID.getPreferredName(), SortOrder.ASC) @@ -591,8 +617,8 @@ public void getTrainedModels(Set modelIds, .setSize(modelIds.size()) .request(); List configs = new ArrayList<>(modelIds.size()); - Set modelsInIndex = Sets.difference(modelIds, MODELS_STORED_AS_RESOURCE); - Set modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds); + Set modelsInIndex = Sets.difference(modelIds.keySet(), MODELS_STORED_AS_RESOURCE); + Set modelsAsResource = Sets.intersection(MODELS_STORED_AS_RESOURCE, modelIds.keySet()); for(String modelId : modelsAsResource) { try { configs.add(loadModelFromResource(modelId, true)); @@ -614,12 +640,12 @@ public void getTrainedModels(Set modelIds, if ((includes.isIncludeFeatureImportanceBaseline() || includes.isIncludeTotalFeatureImportance() || includes.isIncludeHyperparameters()) == false) { finalListener.onResponse(modelBuilders.stream() - .map(TrainedModelConfig.Builder::build) + .map(b -> b.setModelAliases(modelIds.get(b.getModelId())).build()) .sorted(Comparator.comparing(TrainedModelConfig::getModelId)) .collect(Collectors.toList())); return; } - this.getTrainedModelMetadata(modelIds, ActionListener.wrap( + this.getTrainedModelMetadata(modelIds.keySet(), ActionListener.wrap( metadata -> finalListener.onResponse(modelBuilders.stream() .map(builder -> { @@ -634,9 +660,8 @@ public void getTrainedModels(Set modelIds, if 
(includes.isIncludeHyperparameters()) { builder.setHyperparameters(modelMetadata.getHyperparameters()); } - } - return builder.build(); + return builder.setModelAliases(modelIds.get(builder.getModelId())).build(); }) .sorted(Comparator.comparing(TrainedModelConfig::getModelId)) .collect(Collectors.toList())), @@ -680,7 +705,7 @@ public void getTrainedModels(Set modelIds, // We previously expanded the IDs. // If the config has gone missing between then and now we should throw if allowNoResources is false // Otherwise, treat it as if it was never expanded to begin with. - Set missingConfigs = Sets.difference(modelIds, observedIds); + Set missingConfigs = Sets.difference(modelIds.keySet(), observedIds); if (missingConfigs.isEmpty() == false && allowNoResources == false) { getTrainedModelListener.onFailure(new ResourceNotFoundException(Messages.INFERENCE_NOT_FOUND_MULTIPLE, missingConfigs)); return; @@ -730,8 +755,23 @@ public void expandIds(String idExpression, boolean allowNoResources, PageParams pageParams, Set tags, - ActionListener>> idsListener) { + ModelAliasMetadata modelAliasMetadata, + ActionListener>>> idsListener) { String[] tokens = Strings.tokenizeToStringArray(idExpression, ","); + Set expandedIdsFromAliases = new HashSet<>(); + if (Strings.isAllOrWildcard(tokens) == false) { + for (String token : tokens) { + if (Regex.isSimpleMatchPattern(token)) { + for (String modelAlias : modelAliasMetadata.modelAliases().keySet()) { + if (Regex.simpleMatch(token, modelAlias)) { + expandedIdsFromAliases.add(modelAliasMetadata.getModelId(modelAlias)); + } + } + } else if (modelAliasMetadata.getModelId(token) != null) { + expandedIdsFromAliases.add(modelAliasMetadata.getModelId(token)); + } + } + } Set matchedResourceIds = matchedResourceIds(tokens); Set foundResourceIds; if (tags.isEmpty()) { @@ -745,12 +785,17 @@ public void expandIds(String idExpression, } } } + expandedIdsFromAliases.addAll(Arrays.asList(tokens)); + + // We need to include the translated model 
alias, and ANY tokens that were not translated + String[] tokensForQuery = expandedIdsFromAliases.toArray(new String[0]); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() .sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName()) // If there are no resources, there might be no mapping for the id field. // This makes sure we don't get an error if that happens. .unmappedType("long")) - .query(buildExpandIdsQuery(tokens, tags)) + .query(buildExpandIdsQuery(tokensForQuery, tags)) // We "buffer" the from and size to take into account models stored as resources. // This is so we handle the edge cases when the model that is stored as a resource is at the start/end of // a page. @@ -786,9 +831,28 @@ public void expandIds(String idExpression, foundFromDocs.add(idValue.toString()); } } - Set allFoundIds = collectIds(pageParams, foundResourceIds, foundFromDocs); + Map> allFoundIds = collectIds(pageParams, foundResourceIds, foundFromDocs) + .stream() + .collect(Collectors.toMap(Function.identity(), k -> new HashSet<>())); + + // We technically have matched on model tokens and any reversed referenced aliases + // We may end up with "over matching" on the aliases (matching on an alias that was not provided) + // But the expanded ID matcher does not care. 
+ Set matchedTokens = new HashSet<>(allFoundIds.keySet()); + + // We should gather ALL model aliases referenced by the given model IDs + // This way the callers have access to them + modelAliasMetadata.modelAliases().forEach((alias, modelIdEntry) -> { + final String modelId = modelIdEntry.getModelId(); + if (allFoundIds.containsKey(modelId)) { + allFoundIds.get(modelId).add(alias); + matchedTokens.add(alias); + } + }); + + // Reverse lookup to see what model aliases were matched by their found trained model IDs ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources); - requiredMatches.filterMatchedIds(allFoundIds); + requiredMatches.filterMatchedIds(matchedTokens); if (requiredMatches.hasUnmatchedIds()) { idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString())); } else { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAliasAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAliasAction.java new file mode 100644 index 0000000000000..7c440f2cc77a5 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestPutTrainedModelAliasAction.java @@ -0,0 +1,57 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.xpack.ml.rest.inference; + +import static java.util.Collections.singletonList; +import static org.elasticsearch.rest.RestRequest.Method.PUT; + +import java.io.IOException; +import java.util.List; + +import org.elasticsearch.client.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.action.RestToXContentListener; +import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAliasAction; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.ml.MachineLearning; + +public class RestPutTrainedModelAliasAction extends BaseRestHandler { + + @Override + public List routes() { + return singletonList( + new Route( + PUT, + MachineLearning.BASE_PATH + + "trained_models/{" + + TrainedModelConfig.MODEL_ID.getPreferredName() + + "}/model_aliases/{" + + PutTrainedModelAliasAction.Request.MODEL_ALIAS + + "}" + + ) + ); + } + + @Override + public String getName() { + return "ml_put_trained_model_alias_action"; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + String modelAlias = restRequest.param(PutTrainedModelAliasAction.Request.MODEL_ALIAS); + String modelId = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName()); + boolean reassign = restRequest.paramAsBoolean(PutTrainedModelAliasAction.Request.REASSIGN, false); + return channel -> client.execute( + PutTrainedModelAliasAction.INSTANCE, + new PutTrainedModelAliasAction.Request(modelAlias, modelId, reassign), + new RestToXContentListener<>(channel) + ); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java index 08b2e1e17a6f4..b373f9a02959d 100644 --- 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsActionTests.java @@ -36,6 +36,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.junit.Before; @@ -132,7 +133,6 @@ public void setUpVariables() { null, Collections.singletonList(SKINNY_INGEST_PLUGIN), client); } - public void testInferenceIngestStatsByModelId() { List nodeStatsList = Arrays.asList( buildNodeStats( @@ -201,6 +201,7 @@ public void testInferenceIngestStatsByModelId() { put("trained_model_2", new HashSet<>(Arrays.asList("pipeline1", "pipeline2"))); }}; Map ingestStatsMap = TransportGetTrainedModelsStatsAction.inferenceIngestStatsByModelId(response, + ModelAliasMetadata.EMPTY, pipelineIdsByModelIds); assertThat(ingestStatsMap.keySet(), equalTo(new HashSet<>(Arrays.asList("trained_model_1", "trained_model_2")))); @@ -241,7 +242,7 @@ public void testPipelineIdsByModelIds() throws IOException { ClusterState clusterState = buildClusterStateWithModelReferences(modelId1, modelId2, modelId3); Map> pipelineIdsByModelIds = - TransportGetTrainedModelsStatsAction.pipelineIdsByModelIds(clusterState, ingestService, modelIds); + TransportGetTrainedModelsStatsAction.pipelineIdsByModelIdsOrAliases(clusterState, ingestService, modelIds); assertThat(pipelineIdsByModelIds.keySet(), equalTo(modelIds)); assertThat(pipelineIdsByModelIds, diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index b33f6ad707678..9ee4a9fcdfbb0 100644 
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -74,6 +74,7 @@ public void testMutateDocumentWithClassification() { ClassificationConfig.EMPTY_PARAMS, 1.0, 1.0)), + null, true); inferenceProcessor.mutateDocument(response, document); @@ -110,6 +111,7 @@ public void testMutateDocumentClassificationTopNClasses() { classificationConfig, 0.6, 0.6)), + null, true); inferenceProcessor.mutateDocument(response, document); @@ -152,6 +154,7 @@ public void testMutateDocumentClassificationFeatureInfluence() { classificationConfig, 0.6, 0.6)), + null, true); inferenceProcessor.mutateDocument(response, document); @@ -193,6 +196,7 @@ public void testMutateDocumentClassificationTopNClassesWithSpecificField() { classificationConfig, 0.6, 0.6)), + null, true); inferenceProcessor.mutateDocument(response, document); @@ -218,14 +222,16 @@ public void testMutateDocumentRegression() { IngestDocument document = new IngestDocument(source, ingestMetadata); InternalInferModelAction.Response response = new InternalInferModelAction.Response( - Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig)), true); + Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig)), + null, + true); inferenceProcessor.mutateDocument(response, document); assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7)); assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo("regression_model")); } - public void testMutateDocumentRegressionWithTopFetures() { + public void testMutateDocumentRegressionWithTopFeatures() { RegressionConfig regressionConfig = new RegressionConfig("foo", 2); RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", 2); InferenceProcessor inferenceProcessor = new InferenceProcessor(client, @@ 
-245,7 +251,9 @@ public void testMutateDocumentRegressionWithTopFetures() { featureInfluence.add(new RegressionFeatureImportance("feature_2", -42.0)); InternalInferModelAction.Response response = new InternalInferModelAction.Response( - Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), true); + Collections.singletonList(new RegressionInferenceResults(0.7, regressionConfig, featureInfluence)), + null, + true); inferenceProcessor.mutateDocument(response, document); assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7)); @@ -383,7 +391,9 @@ public void testHandleResponseLicenseChanged() { assertThat(inferenceProcessor.buildRequest(document).isPreviouslyLicensed(), is(false)); InternalInferModelAction.Response response = new InternalInferModelAction.Response( - Collections.singletonList(new RegressionInferenceResults(0.7, RegressionConfig.EMPTY_PARAMS)), true); + Collections.singletonList(new RegressionInferenceResults(0.7, RegressionConfig.EMPTY_PARAMS)), + null, + true); inferenceProcessor.handleResponse(response, document, (doc, ex) -> { assertThat(doc, is(not(nullValue()))); assertThat(ex, is(nullValue())); @@ -392,7 +402,9 @@ public void testHandleResponseLicenseChanged() { assertThat(inferenceProcessor.buildRequest(document).isPreviouslyLicensed(), is(true)); response = new InternalInferModelAction.Response( - Collections.singletonList(new RegressionInferenceResults(0.7, RegressionConfig.EMPTY_PARAMS)), false); + Collections.singletonList(new RegressionInferenceResults(0.7, RegressionConfig.EMPTY_PARAMS)), + null, + false); inferenceProcessor.handleResponse(response, document, (doc, ex) -> { assertThat(doc, is(not(nullValue()))); @@ -424,11 +436,37 @@ public void testMutateDocumentWithWarningResult() { IngestDocument document = new IngestDocument(source, ingestMetadata); InternalInferModelAction.Response response = new InternalInferModelAction.Response( - 
Collections.singletonList(new WarningInferenceResults("something broke")), true); + Collections.singletonList(new WarningInferenceResults("something broke")), null, true); inferenceProcessor.mutateDocument(response, document); assertThat(document.hasField(targetField), is(false)); assertThat(document.hasField("ml.warning"), is(true)); assertThat(document.hasField("ml.my_processor"), is(false)); } + + public void testMutateDocumentWithModelIdResult() { + String modelAlias = "special_model"; + String modelId = "regression-123"; + InferenceProcessor inferenceProcessor = new InferenceProcessor(client, + auditor, + "my_processor", + null, + "ml.my_processor", + modelAlias, + new RegressionConfigUpdate("foo", null), + Collections.emptyMap()); + + Map source = new HashMap<>(); + Map ingestMetadata = new HashMap<>(); + IngestDocument document = new IngestDocument(source, ingestMetadata); + + InternalInferModelAction.Response response = new InternalInferModelAction.Response( + Collections.singletonList(new RegressionInferenceResults(0.7, new RegressionConfig("foo"))), + modelId, + true); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7)); + assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo(modelId)); + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java index a0639aa530bf2..7ced965b9353e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingServiceTests.java @@ -22,6 +22,7 @@ import org.elasticsearch.common.breaker.CircuitBreaker; import 
org.elasticsearch.common.breaker.CircuitBreakingException; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.ByteSizeValue; @@ -44,6 +45,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -59,10 +61,12 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME; import static org.hamcrest.Matchers.equalTo; @@ -282,7 +286,6 @@ public boolean matches(final Object o) { verify(trainedModelProvider, times(5)).getTrainedModelForInference(eq(model3), any()); } - public void testWhenCacheEnabledButNotIngestNode() throws Exception { String model1 = "test-uncached-not-ingest-model-1"; withTrainedModel(model1, 1L); @@ -538,6 +541,101 @@ public void testReferenceCounting_ModelIsNotCached() throws ExecutionException, assertEquals(1, model.getReferenceCount()); } + public void testGetCachedModelViaModelAliases() throws Exception { + String model1 = "test-load-model-1"; + String model2 = "test-load-model-2"; + withTrainedModel(model1, 1L); + withTrainedModel(model2, 1L); + + ModelLoadingService modelLoadingService = new 
ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + trainedModelStatsService, + Settings.EMPTY, + "test-node", + circuitBreaker); + + modelLoadingService.clusterChanged(aliasChangeEvent( + true, + new String[]{"loaded_model"}, + true, + Arrays.asList(Tuple.tuple(model1, "loaded_model")) + )); + + String[] modelIds = new String[]{model1, "loaded_model"}; + for(int i = 0; i < 10; i++) { + String model = modelIds[i%2]; + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModelForPipeline(model, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model1), any()); + + assertTrue(modelLoadingService.isModelCached(model1)); + assertTrue(modelLoadingService.isModelCached("loaded_model")); + + // alias change only + modelLoadingService.clusterChanged(aliasChangeEvent( + true, + new String[]{"loaded_model"}, + false, + Arrays.asList(Tuple.tuple(model2, "loaded_model")) + )); + + modelIds = new String[]{model2, "loaded_model"}; + for(int i = 0; i < 10; i++) { + String model = modelIds[i%2]; + PlainActionFuture future = new PlainActionFuture<>(); + modelLoadingService.getModelForPipeline(model, future); + assertThat(future.get(), is(not(nullValue()))); + } + + verify(trainedModelProvider, times(1)).getTrainedModelForInference(eq(model2), any()); + assertTrue(modelLoadingService.isModelCached(model2)); + assertTrue(modelLoadingService.isModelCached("loaded_model")); + } + + public void testAliasesGetUpdatedEvenWhenNotIngestNode() throws IOException { + String model1 = "test-load-model-1"; + withTrainedModel(model1, 1L); + String model2 = "test-load-model-2"; + withTrainedModel(model2, 1L); + + ModelLoadingService modelLoadingService = new ModelLoadingService(trainedModelProvider, + auditor, + threadPool, + clusterService, + trainedModelStatsService, + Settings.EMPTY, + "test-node", + circuitBreaker); + + 
modelLoadingService.clusterChanged(aliasChangeEvent( + false, + new String[0], + false, + Arrays.asList(Tuple.tuple(model1, "loaded_model")) + )); + + assertThat(modelLoadingService.getModelId("loaded_model"), equalTo(model1)); + + modelLoadingService.clusterChanged(aliasChangeEvent( + false, + new String[0], + false, + Arrays.asList( + Tuple.tuple(model1, "loaded_model_again"), + Tuple.tuple(model1, "loaded_model_foo"), + Tuple.tuple(model2, "loaded_model") + ) + )); + assertThat(modelLoadingService.getModelId("loaded_model"), equalTo(model2)); + assertThat(modelLoadingService.getModelId("loaded_model_foo"), equalTo(model1)); + assertThat(modelLoadingService.getModelId("loaded_model_again"), equalTo(model1)); + } + @SuppressWarnings("unchecked") private void withTrainedModel(String modelId, long size) { InferenceDefinition definition = mock(InferenceDefinition.class); @@ -601,6 +699,21 @@ private static ClusterChangedEvent ingestChangedEvent(String... modelId) throws return ingestChangedEvent(true, modelId); } + private static ClusterChangedEvent aliasChangeEvent(boolean isIngestNode, + String[] modelId, + boolean ingestToo, + List> modelIdAndAliases) throws IOException { + ClusterChangedEvent event = mock(ClusterChangedEvent.class); + Set set = new HashSet<>(); + set.add(ModelAliasMetadata.NAME); + if (ingestToo) { + set.add(IngestMetadata.TYPE); + } + when(event.changedCustomMetadataSet()).thenReturn(set); + when(event.state()).thenReturn(withModelReferencesAndAliasChange(isIngestNode, modelId, modelIdAndAliases)); + return event; + } + private static ClusterChangedEvent ingestChangedEvent(boolean isIngestNode, String... 
modelId) throws IOException { ClusterChangedEvent event = mock(ClusterChangedEvent.class); when(event.changedCustomMetadataSet()).thenReturn(Collections.singleton(IngestMetadata.TYPE)); @@ -609,14 +722,17 @@ private static ClusterChangedEvent ingestChangedEvent(boolean isIngestNode, Stri } private static ClusterState buildClusterStateWithModelReferences(boolean isIngestNode, String... modelId) throws IOException { - Map configurations = new HashMap<>(modelId.length); - for (String id : modelId) { - configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id)); - } - IngestMetadata ingestMetadata = new IngestMetadata(configurations); + return builder(isIngestNode).metadata(addIngest(Metadata.builder(), modelId)).build(); + } + + private static ClusterState withModelReferencesAndAliasChange(boolean isIngestNode, + String[] modelId, + List> modelIdAndAliases) throws IOException { + return builder(isIngestNode).metadata(addAliases(addIngest(Metadata.builder(), modelId), modelIdAndAliases)).build(); + } + private static ClusterState.Builder builder(boolean isIngestNode) { return ClusterState.builder(new ClusterName("_name")) - .metadata(Metadata.builder().putCustom(IngestMetadata.TYPE, ingestMetadata)) .nodes(DiscoveryNodes.builder().add( new DiscoveryNode("node_name", "node_id", @@ -625,8 +741,23 @@ private static ClusterState buildClusterStateWithModelReferences(boolean isInges isIngestNode ? Collections.singleton(DiscoveryNodeRole.INGEST_ROLE) : Collections.emptySet(), Version.CURRENT)) .localNodeId("node_id") - .build()) - .build(); + .build() + ); + } + + private static Metadata.Builder addIngest(Metadata.Builder builder, String... 
modelId) throws IOException { + Map configurations = new HashMap<>(modelId.length); + for (String id : modelId) { + configurations.put("pipeline_with_model_" + id, newConfigurationWithInferenceProcessor(id)); + } + IngestMetadata ingestMetadata = new IngestMetadata(configurations); + return builder.putCustom(IngestMetadata.TYPE, ingestMetadata); + } + + private static Metadata.Builder addAliases(Metadata.Builder builder, List> modelIdAndAliases) { + ModelAliasMetadata modelAliasMetadata = new ModelAliasMetadata(modelIdAndAliases.stream() + .collect(Collectors.toMap(Tuple::v2, t -> new ModelAliasMetadata.ModelAliasEntry(t.v1())))); + return builder.putCustom(ModelAliasMetadata.NAME, modelAliasMetadata); } private static PipelineConfiguration newConfigurationWithInferenceProcessor(String modelId) throws IOException { diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index aa5c3623ab7b3..8d828bd7fd592 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -134,6 +134,7 @@ public class Constants { "cluster:admin/xpack/ml/filters/update", "cluster:admin/xpack/ml/inference/delete", "cluster:admin/xpack/ml/inference/put", + "cluster:admin/xpack/ml/inference/model_aliases/put", "cluster:admin/xpack/ml/job/close", "cluster:admin/xpack/ml/job/data/post", "cluster:admin/xpack/ml/job/delete", diff --git a/x-pack/plugin/src/test/java/org/elasticsearch/xpack/test/rest/AbstractXPackRestTest.java b/x-pack/plugin/src/test/java/org/elasticsearch/xpack/test/rest/AbstractXPackRestTest.java index bd3fd47586ea1..9cd81093d9e0d 100644 --- 
a/x-pack/plugin/src/test/java/org/elasticsearch/xpack/test/rest/AbstractXPackRestTest.java +++ b/x-pack/plugin/src/test/java/org/elasticsearch/xpack/test/rest/AbstractXPackRestTest.java @@ -18,12 +18,12 @@ import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.plugins.MetadataUpgrader; import org.elasticsearch.test.SecuritySettingsSourceField; -import org.elasticsearch.test.rest.ESRestTestCase; import org.elasticsearch.test.rest.yaml.ClientYamlTestCandidate; import org.elasticsearch.test.rest.yaml.ClientYamlTestResponse; import org.elasticsearch.test.rest.yaml.ESClientYamlSuiteTestCase; import org.elasticsearch.xpack.core.ml.MlConfigIndex; import org.elasticsearch.xpack.core.ml.MlMetaIndex; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex; import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndexFields; diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model_alias.json b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model_alias.json new file mode 100644 index 0000000000000..c07e96397fe29 --- /dev/null +++ b/x-pack/plugin/src/test/resources/rest-api-spec/api/ml.put_trained_model_alias.json @@ -0,0 +1,40 @@ +{ + "ml.put_trained_model_alias":{ + "documentation":{ + "url":"https://www.elastic.co/guide/en/elasticsearch/reference/current/put-trained-models-aliases.html", + "description":"Creates a new model alias (or reassigns an existing one) to refer to the trained model" + }, + "stability":"beta", + "visibility":"public", + "headers":{ + "accept": [ "application/json"], + "content_type": ["application/json"] + }, + "url":{ + "paths":[ + { + "path":"/_ml/trained_models/{model_id}/model_aliases/{model_alias}", + "methods":[ + "PUT" + ], + "parts":{ + "model_alias":{ + 
"type":"string", + "description":"The trained model alias to update" + }, + "model_id": { + "type": "string", + "description": "The trained model where the model alias should be assigned" + } + } + } + ] + }, + "params":{ + "reassign":{ + "type":"boolean", + "description":"If the model_alias already exists and points to a separate model_id, this parameter must be true. Defaults to false." + } + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/get_trained_model_stats.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/get_trained_model_stats.yml index 4acb694835f68..93f23e1f0f405 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/get_trained_model_stats.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/get_trained_model_stats.yml @@ -79,6 +79,12 @@ setup: } } } + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ml.put_trained_model_alias: + model_alias: "my-regression" + model_id: "a-unused-regression-model1" --- "Test get stats given missing trained model": @@ -175,3 +181,20 @@ setup: - match: { count: 1 } - match: { trained_model_stats.0.model_id: another-regression-model } - match: { trained_model_stats.0.pipeline_count: 0 } + + +# test with model alias + - do: + ml.get_trained_models_stats: + model_id: "my-regression" + + - match: { count: 1 } + - match: { trained_model_stats.0.model_id: a-unused-regression-model1 } + + - do: + ml.get_trained_models_stats: + model_id: "my-regression,another-regression-model" + + - match: { count: 2 } + - match: { trained_model_stats.0.model_id: a-unused-regression-model1 } + - match: { trained_model_stats.1.model_id: another-regression-model } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index de7ce694f9857..0994bdf33319c 100644 --- 
a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -561,30 +561,6 @@ setup: ml.delete_trained_model: model_id: "missing-trained-model" --- -"Test delete given used trained model": - - do: - ingest.put_pipeline: - id: "regression-model-pipeline" - body: > - { - "processors": [ - { - "inference" : { - "model_id" : "a-regression-model-0", - "inference_config": {"regression": {}}, - "target_field": "regression_field", - "field_map": {} - } - } - ] - } - - match: { acknowledged: true } - - - do: - catch: conflict - ml.delete_trained_model: - model_id: "a-regression-model-0" ---- "Test get pre-packaged trained models": - do: ml.get_trained_models: @@ -851,3 +827,89 @@ setup: model_id: "a-regression-model-1" include_model_definition: true decompress_definition: false +--- +"Test put model model aliases": + + - do: + ml.put_trained_model_alias: + model_alias: "regression-model" + model_id: "a-regression-model-1" + - do: + ml.get_trained_models: + model_id: "regression-model,a-classification-model" + + - match: { count: 2 } + - length: { trained_model_configs: 2 } + - match: { trained_model_configs.0.model_id: "a-classification-model" } + - match: { trained_model_configs.1.model_id: "a-regression-model-1" } + + - do: + ml.put_trained_model_alias: + model_alias: "regression-model" + model_id: "a-regression-model-0" + reassign: true + - do: + ml.get_trained_models: + model_id: "regression-model,a-classification-model" + + - match: { count: 2 } + - length: { trained_model_configs: 2 } + - match: { trained_model_configs.0.model_id: "a-classification-model" } + - match: { trained_model_configs.1.model_id: "a-regression-model-0" } + + - do: + ml.put_trained_model_alias: + model_alias: "regression-model-again" + model_id: "a-regression-model-0" + - do: + ml.get_trained_models: + model_id: "a-regression-model-*" + size: 1 + + - length: { trained_model_configs: 1 } + - match: 
{ trained_model_configs.0.model_id: "a-regression-model-0" } + - match: { trained_model_configs.0.metadata.model_aliases.0: "regression-model" } + - match: { trained_model_configs.0.metadata.model_aliases.1: "regression-model-again" } +--- +"Test update model alias with model id referring to missing model": + - do: + catch: missing + ml.put_trained_model_alias: + model_alias: "regression-model" + model_id: "missing-model" +--- +"Test update model alias with bad alias": + - do: + catch: /must start with alphanumeric and cannot end with numbers/ + ml.put_trained_model_alias: + model_alias: "regression-model-123123" + model_id: "regression-model-123123" + - do: + catch: bad_request + ml.put_trained_model_alias: + model_alias: "z-classification-model" + model_id: "z-classification-model" +--- +"Test update model alias where alias exists but old model id is different inference type": + - do: + ml.put_trained_model_alias: + model_alias: "regression-model" + model_id: "a-regression-model-0" + - do: + catch: bad_request + ml.put_trained_model_alias: + model_alias: "regression-model" + model_id: "a-classification-model" + reassign: true +--- +"Test update model alias where alias exists but reassign is false": + - do: + ml.put_trained_model_alias: + model_alias: "regression-model" + model_id: "a-regression-model-0" + - do: + catch: bad_request + ml.put_trained_model_alias: + model_alias: "regression-model" + model_id: "a-regression-model-1" + reassign: false From 578a11cbac49d0671b9555782909d7b77d4a8b1a Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 18 Feb 2021 13:11:38 -0500 Subject: [PATCH 2/3] fixing backport --- .../java/org/elasticsearch/xpack/ml/MachineLearning.java | 7 ------- 1 file changed, 7 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 718ebd68b052a..5a2b9cc4d03c0 
100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -1124,13 +1124,6 @@ public List getNamedXContent() { namedXContent.addAll(new MlDataFrameAnalysisNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers()); namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers()); - namedXContent.add( - new NamedXContentRegistry.Entry( - Metadata.Custom.class, - new ParseField(ModelAliasMetadata.NAME), - ModelAliasMetadata::fromXContent - ) - ); return namedXContent; } From 8f1c2c73dac2f12e3dba75ba77b4e69a7d11b59f Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 18 Feb 2021 14:40:29 -0500 Subject: [PATCH 3/3] fixing metadata class impl --- .../xpack/core/ml/inference/ModelAliasMetadata.java | 3 ++- .../main/java/org/elasticsearch/xpack/ml/MachineLearning.java | 3 --- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/ModelAliasMetadata.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/ModelAliasMetadata.java index 917277be853a5..03f88197d2795 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/ModelAliasMetadata.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/ModelAliasMetadata.java @@ -21,6 +21,7 @@ import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.XPackPlugin; import java.io.IOException; import java.util.Collections; @@ -32,7 +33,7 @@ /** * Custom {@link Metadata} implementation for storing a map of model aliases that point to model IDs */ 
-public class ModelAliasMetadata implements Metadata.Custom { +public class ModelAliasMetadata implements XPackPlugin.XPackMetadataCustom { public static final String NAME = "trained_model_alias"; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 5a2b9cc4d03c0..468ad6b114b2e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -19,12 +19,10 @@ import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.IndexTemplateMetadata; -import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeRole; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.ParseField; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.inject.Module; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -230,7 +228,6 @@ import org.elasticsearch.xpack.ml.dataframe.process.NativeMemoryUsageEstimationProcessFactory; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.MemoryUsageEstimationResult; -import org.elasticsearch.xpack.core.ml.inference.ModelAliasMetadata; import org.elasticsearch.xpack.ml.inference.TrainedModelStatsService; import org.elasticsearch.xpack.ml.inference.aggs.InferencePipelineAggregationBuilder; import org.elasticsearch.xpack.ml.inference.aggs.InternalInferenceAggregation;