13 changes: 12 additions & 1 deletion docs/reference/ml/df-analytics/apis/put-trained-models.asciidoc
@@ -24,7 +24,7 @@ WARNING: Models created in version 7.8.0 are not backwards compatible
[[ml-put-trained-models-prereq]]
== {api-prereq-title}

Requires the `manage_ml` cluster privilege. This privilege is included in the
Requires the `manage_ml` cluster privilege. This privilege is included in the
`machine_learning_admin` built-in role.


@@ -42,6 +42,17 @@ created by {dfanalytics}.
(Required, string)
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]

[[ml-put-trained-models-query-params]]
== {api-query-parms-title}

`defer_definition_decompression`::
(Optional, boolean)
If set to `true` and a `compressed_definition` is provided, the request defers
definition decompression and skips relevant validations.
Deferring decompression is useful when the caller already has a good JVM heap size estimate
for the model and knows that the model is valid and unlikely to fail during inference.
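
For reference, a minimal sketch of exercising the new query parameter from the Java low-level REST client; the host, model id, payload values, and heap estimate are placeholders for illustration and are not taken from this change.

```java
import org.apache.http.HttpHost;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.RestClient;

public class DeferDecompressionExample {
    public static void main(String[] args) throws Exception {
        try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build()) {
            Request request = new Request("PUT",
                "/_ml/trained_models/my-regression-model?defer_definition_decompression=true");
            // When decompression is deferred, the server does not inflate the definition to
            // estimate heap usage, so the caller supplies its own estimate up front.
            request.setJsonEntity("{"
                + "\"input\": {\"field_names\": [\"feature1\"]},"
                + "\"inference_config\": {\"regression\": {}},"
                + "\"estimated_heap_memory_usage_bytes\": 1048576,"
                + "\"compressed_definition\": \"<base64-compressed-definition>\""
                + "}");
            client.performRequest(request);
        }
    }
}
```

If `estimated_heap_memory_usage_bytes` were omitted while a `compressed_definition` is supplied, the new validation in `PutTrainedModelAction.Request#validate` rejects the request.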


[role="child_attributes"]
[[ml-put-trained-models-request-body]]
== {api-request-body-title}
@@ -26,6 +26,14 @@
}
]
},
"params":{
"defer_definition_decompression": {
"required": false,
"type": "boolean",
"description": "If set to `true` and a `compressed_definition` is provided, the request defers definition decompression and skips relevant validations.",
"default": false
}
},
"body":{
"description":"The trained model configuration",
"required":true
@@ -6,6 +6,7 @@
*/
package org.elasticsearch.xpack.core.ml.action;

import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
@@ -22,9 +23,12 @@
import java.io.IOException;
import java.util.Objects;

import static org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig.ESTIMATED_HEAP_MEMORY_USAGE_BYTES;


public class PutTrainedModelAction extends ActionType<PutTrainedModelAction.Response> {

public static final String DEFER_DEFINITION_DECOMPRESSION = "defer_definition_decompression";
public static final PutTrainedModelAction INSTANCE = new PutTrainedModelAction();
public static final String NAME = "cluster:admin/xpack/ml/inference/put";
private PutTrainedModelAction() {
@@ -33,7 +37,7 @@ private PutTrainedModelAction() {

public static class Request extends AcknowledgedRequest<Request> {

public static Request parseRequest(String modelId, XContentParser parser) {
public static Request parseRequest(String modelId, boolean deferDefinitionValidation, XContentParser parser) {
TrainedModelConfig.Builder builder = TrainedModelConfig.STRICT_PARSER.apply(parser, null);

if (builder.getModelId() == null) {
@@ -47,18 +51,25 @@ public static Request parseRequest(String modelId, XContentParser parser) {
}
// Validations are done against the builder so we can build the full config object.
// This allows us to not worry about serializing a builder class between nodes.
return new Request(builder.validate(true).build());
return new Request(builder.validate(true).build(), deferDefinitionValidation);
}

private final TrainedModelConfig config;
private final boolean deferDefinitionDecompression;

public Request(TrainedModelConfig config) {
public Request(TrainedModelConfig config, boolean deferDefinitionDecompression) {
this.config = config;
this.deferDefinitionDecompression = deferDefinitionDecompression;
}

public Request(StreamInput in) throws IOException {
super(in);
this.config = new TrainedModelConfig(in);
if (in.getVersion().onOrAfter(Version.V_7_16_0)) {
this.deferDefinitionDecompression = in.readBoolean();
} else {
this.deferDefinitionDecompression = false;
}
}

public TrainedModelConfig getTrainedModelConfig() {
@@ -67,26 +78,44 @@ public TrainedModelConfig getTrainedModelConfig() {

@Override
public ActionRequestValidationException validate() {
if (deferDefinitionDecompression
&& config.getEstimatedHeapMemory() == 0
&& config.getCompressedDefinitionIfSet() != null) {
ActionRequestValidationException validationException = new ActionRequestValidationException();
validationException.addValidationError(
"when ["
+ DEFER_DEFINITION_DECOMPRESSION
+ "] is true and a compressed definition is provided, " + ESTIMATED_HEAP_MEMORY_USAGE_BYTES + " must be set"
);
return validationException;
}
return null;
}

public boolean isDeferDefinitionDecompression() {
return deferDefinitionDecompression;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
config.writeTo(out);
if (out.getVersion().onOrAfter(Version.V_7_16_0)) {
out.writeBoolean(deferDefinitionDecompression);
}
}
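
The version checks in the stream constructor and `writeTo` above keep the request wire-compatible with pre-7.16.0 nodes. Below is a self-contained illustration of that pattern using plain `java.io` streams and invented class and version constants; it sketches the technique only and is not Elasticsearch's `StreamInput`/`StreamOutput` API.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

public class VersionGatedFieldSketch {

    // Invented stand-in for Version.V_7_16_0; not Elasticsearch's Version class.
    static final int V_7_16_0 = 7_16_00;

    static class Request {
        final boolean deferDefinitionDecompression;

        Request(boolean defer) {
            this.deferDefinitionDecompression = defer;
        }

        // Read side: peers older than 7.16.0 never wrote the flag, so default it to false.
        Request(DataInputStream in, int senderVersion) throws IOException {
            this.deferDefinitionDecompression = senderVersion >= V_7_16_0 && in.readBoolean();
        }

        // Write side: only send the flag to peers that know how to read it.
        void writeTo(DataOutputStream out, int receiverVersion) throws IOException {
            if (receiverVersion >= V_7_16_0) {
                out.writeBoolean(deferDefinitionDecompression);
            }
        }
    }

    public static void main(String[] args) throws IOException {
        int oldWireVersion = 7_15_02; // pretend the remote node is on 7.15.2
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        new Request(true).writeTo(new DataOutputStream(bytes), oldWireVersion);
        Request roundTripped = new Request(
            new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())), oldWireVersion);
        // Prints "false": the flag is dropped on the old wire format and defaults on read.
        System.out.println(roundTripped.deferDefinitionDecompression);
    }
}
```

An older peer never sees the flag on the wire, and a request read from an older peer falls back to the default of `false`.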

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request request = (Request) o;
return Objects.equals(config, request.config);
return Objects.equals(config, request.config) && deferDefinitionDecompression == request.deferDefinitionDecompression;
}

@Override
public int hashCode() {
return Objects.hash(config);
return Objects.hash(config, deferDefinitionDecompression);
}

@Override
@@ -247,6 +247,13 @@ public String getCompressedDefinition() throws IOException {
return definition.getCompressedString();
}

public String getCompressedDefinitionIfSet() {
if (definition == null) {
return null;
}
return definition.getCompressedDefinitionIfSet();
}

public void clearCompressed() {
definition.compressedString = null;
}
@@ -622,6 +629,7 @@ public Builder validate() {

/**
* Runs validations against the builder.
* @param forCreation indicates if we should validate for model creation or for a model read from storage
* @return The current builder object if validations are successful
* @throws ActionRequestValidationException when there are validation failures.
*/
@@ -682,12 +690,6 @@ public Builder validate(boolean forCreation) {
validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);
validationException = checkIllegalSetting(createTime, CREATE_TIME.getPreferredName(), validationException);
validationException = checkIllegalSetting(estimatedHeapMemory,
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
validationException);
validationException = checkIllegalSetting(estimatedOperations,
ESTIMATED_OPERATIONS.getPreferredName(),
validationException);
validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
if (metadata != null) {
validationException = checkIllegalSetting(
@@ -783,6 +785,10 @@ public String getCompressedString() throws IOException {
return compressedString;
}

public String getCompressedDefinitionIfSet() {
return compressedString;
}

private void ensureParsedDefinitionUnsafe(NamedXContentRegistry xContentRegistry) throws IOException {
if (parsedDefinition == null) {
parsedDefinition = InferenceToXContentCompressor.inflateUnsafe(compressedString,
@@ -20,9 +20,12 @@ public class PutTrainedModelActionRequestTests extends AbstractWireSerializingTe
@Override
protected Request createTestInstance() {
String modelId = randomAlphaOfLength(10);
return new Request(TrainedModelConfigTests.createTestInstance(modelId)
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.build());
return new Request(
TrainedModelConfigTests.createTestInstance(modelId)
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
.build(),
randomBoolean()
);
}

@Override
2 changes: 2 additions & 0 deletions x-pack/plugin/ml/qa/ml-with-security/build.gradle
@@ -158,6 +158,8 @@ tasks.named("yamlRestTest").configure {
'ml/inference_crud/Test update model alias where alias exists but reassign is false',
'ml/inference_crud/Test delete model alias with missing alias',
'ml/inference_crud/Test delete model alias where alias points to different model',
'ml/inference_crud/Test put with defer_definition_decompression with invalid compression definition and no memory estimate',
'ml/inference_crud/Test put with defer_definition_decompression with invalid definition and no memory estimate',
'ml/inference_processor/Test create processor with missing mandatory fields',
'ml/inference_stats_crud/Test get stats given missing trained model',
'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',
@@ -826,7 +826,7 @@ private void putInferenceModel(String modelId) {
.setInput(new TrainedModelInput(Collections.singletonList("feature1")))
.setInferenceConfig(RegressionConfig.EMPTY_PARAMS)
.build();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config, false)).actionGet();
}

private static OperationMode randomInvalidLicenseType() {
@@ -127,7 +127,7 @@ public void testFeatureTrackingInferenceModelPipeline() throws Exception {
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding, false)))
.setTrainedModel(buildClassification(true)))
.build();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config, false)).actionGet();

String pipelineId = "pipeline-inference-model-tracked";
putTrainedModelIngestPipeline(pipelineId, modelId);
@@ -87,7 +87,8 @@ public void testRemoveUnusedStats() throws Exception {
.build())
)
.validate(true)
.build())).actionGet();
.build(),
false)).actionGet();

indexStatDocument(new DataCounts("analytics-with-stats", 1, 1, 1),
DataCounts.documentId("analytics-with-stats"));
@@ -72,6 +72,8 @@ public TransportPutTrainedModelAction(TransportService transportService, Cluster
protected void masterOperation(Request request,
ClusterState state,
ActionListener<Response> listener) {

final TrainedModelConfig config = request.getTrainedModelConfig();
// 7.8.0 introduced splitting the model definition across multiple documents.
// This means that new models will not be usable on nodes that cannot handle multiple definition documents
if (state.nodes().getMinNodeVersion().before(Version.V_7_8_0)) {
@@ -82,8 +84,10 @@ protected void masterOperation(Request request,
return;
}
try {
request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry);
request.getTrainedModelConfig().getModelDefinition().getTrainedModel().validate();
if (request.isDeferDefinitionDecompression() == false) {
config.ensureParsedDefinition(xContentRegistry);
config.getModelDefinition().getTrainedModel().validate();
}
} catch (IOException ex) {
listener.onFailure(ExceptionsHelper.badRequestException("Failed to parse definition for [{}]",
ex,
@@ -95,43 +99,47 @@
request.getTrainedModelConfig().getModelId()));
return;
}
if (request.getTrainedModelConfig()
.getInferenceConfig()
.isTargetTypeSupported(request.getTrainedModelConfig()
.getModelDefinition()
.getTrainedModel()
.targetType()) == false) {
listener.onFailure(ExceptionsHelper.badRequestException(
"Model [{}] inference config type [{}] does not support definition target type [{}]",
request.getTrainedModelConfig().getModelId(),
request.getTrainedModelConfig().getInferenceConfig().getName(),
request.getTrainedModelConfig()

// NOTE: hasModelDefinition is false when decompression was deferred and the definition was never parsed.
// If a fully parsed definition was provided inline, continue with the definition-dependent checks.
boolean hasModelDefinition = config.getModelDefinition() != null;
if (hasModelDefinition) {
if (config.getInferenceConfig()
.isTargetTypeSupported(config
.getModelDefinition()
.getTrainedModel()
.targetType()));
return;
}

Version minCompatibilityVersion = request.getTrainedModelConfig()
.getModelDefinition()
.getTrainedModel()
.getMinimalCompatibilityVersion();
if (state.nodes().getMinNodeVersion().before(minCompatibilityVersion)) {
listener.onFailure(ExceptionsHelper.badRequestException(
"Definition for [{}] requires that all nodes are at least version [{}]",
request.getTrainedModelConfig().getModelId(),
minCompatibilityVersion.toString()));
return;
.targetType()) == false) {
listener.onFailure(ExceptionsHelper.badRequestException(
"Model [{}] inference config type [{}] does not support definition target type [{}]",
config.getModelId(),
config.getInferenceConfig().getName(),
config.getModelDefinition().getTrainedModel().targetType()));
return;
}

Version minCompatibilityVersion = config
.getModelDefinition()
.getTrainedModel()
.getMinimalCompatibilityVersion();
if (state.nodes().getMinNodeVersion().before(minCompatibilityVersion)) {
listener.onFailure(ExceptionsHelper.badRequestException(
"Definition for [{}] requires that all nodes are at least version [{}]",
request.getTrainedModelConfig().getModelId(),
minCompatibilityVersion.toString()));
return;
}
}

TrainedModelConfig trainedModelConfig = new TrainedModelConfig.Builder(request.getTrainedModelConfig())
TrainedModelConfig.Builder trainedModelConfigBuilder = new TrainedModelConfig.Builder(config)
.setVersion(Version.CURRENT)
.setCreateTime(Instant.now())
.setCreatedBy("api_user")
.setLicenseLevel(License.OperationMode.PLATINUM.description())
.setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed())
.setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations())
.build();
.setLicenseLevel(License.OperationMode.PLATINUM.description());
if (hasModelDefinition) {
trainedModelConfigBuilder.setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed())
.setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations());
}
TrainedModelConfig trainedModelConfig = trainedModelConfigBuilder.build();

if (ModelAliasMetadata.fromState(state).getModelId(trainedModelConfig.getModelId()) != null) {
listener.onFailure(ExceptionsHelper.badRequestException(
"requested model_id [{}] is the same as an existing model_alias. Model model_aliases and ids must be unique",
@@ -72,7 +72,6 @@
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition;
@@ -134,20 +133,20 @@ public void storeTrainedModel(TrainedModelConfig trainedModelConfig,
return;
}

String definition;
try {
trainedModelConfig.ensureParsedDefinition(xContentRegistry);
definition = trainedModelConfig.getCompressedDefinition();
} catch (IOException ex) {
listener.onFailure(ExceptionsHelper.serverError(
"Unexpected serialization error when parsing model definition for model [" + trainedModelConfig.getModelId() + "]",
ex));
"Unexpected IOException while serializing definition for storage for model [{}]",
ex,
trainedModelConfig.getModelId()));
return;
}

TrainedModelDefinition definition = trainedModelConfig.getModelDefinition();
if (definition == null) {
listener.onFailure(ExceptionsHelper.badRequestException("Unable to store [{}]. [{}] is required",
trainedModelConfig.getModelId(),
TrainedModelConfig.DEFINITION.getPreferredName()));
TrainedModelConfig.COMPRESSED_DEFINITION.getPreferredName()));
return;
}

@@ -40,9 +40,9 @@ public String getName() {
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
String id = restRequest.param(TrainedModelConfig.MODEL_ID.getPreferredName());
XContentParser parser = restRequest.contentParser();
PutTrainedModelAction.Request putRequest = PutTrainedModelAction.Request.parseRequest(id, parser);
boolean deferDefinitionDecompression = restRequest.paramAsBoolean(PutTrainedModelAction.DEFER_DEFINITION_DECOMPRESSION, false);
PutTrainedModelAction.Request putRequest = PutTrainedModelAction.Request.parseRequest(id, deferDefinitionDecompression, parser);
putRequest.timeout(restRequest.paramAsTime("timeout", putRequest.timeout()));

return channel -> client.execute(PutTrainedModelAction.INSTANCE, putRequest, new RestToXContentListener<>(channel));
}
}