Skip to content

Commit 44acc99

Browse files
authored
[7.x] [ML] adding new defer_definition_decompression parameter to put trained model API (#77189) (#77256)
* [ML] adding new defer_definition_decompression parameter to put trained model API (#77189) This new boolean parameter allows users to put a compressed model without it having to be inflated on the master node during the put request. This is useful for system/module setup: the model is validated and fully parsed later, when it is loaded on a node for usage.
1 parent 75a8d66 commit 44acc99

File tree

13 files changed

+201
-59
lines changed

13 files changed

+201
-59
lines changed

docs/reference/ml/df-analytics/apis/put-trained-models.asciidoc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ WARNING: Models created in version 7.8.0 are not backwards compatible
2424
[[ml-put-trained-models-prereq]]
2525
== {api-prereq-title}
2626

27-
Requires the `manage_ml` cluster privilege. This privilege is included in the
27+
Requires the `manage_ml` cluster privilege. This privilege is included in the
2828
`machine_learning_admin` built-in role.
2929

3030

@@ -42,6 +42,17 @@ created by {dfanalytics}.
4242
(Required, string)
4343
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=model-id]
4444

45+
[[ml-put-trained-models-query-params]]
46+
== {api-query-parms-title}
47+
48+
`defer_definition_decompression`::
49+
(Optional, boolean)
50+
If set to `true` and a `compressed_definition` is provided, the request defers
51+
definition decompression and skips relevant validations.
52+
This deferral is useful for systems or users that know a good JVM heap size estimate for their
53+
model and know that their model is valid and likely won't fail during inference.
54+
55+
4556
[role="child_attributes"]
4657
[[ml-put-trained-models-request-body]]
4758
== {api-request-body-title}

rest-api-spec/src/main/resources/rest-api-spec/api/ml.put_trained_model.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@
2626
}
2727
]
2828
},
29+
"params":{
30+
"defer_definition_decompression": {
31+
"required": false,
32+
"type": "boolean",
33+
"description": "If set to `true` and a `compressed_definition` is provided, the request defers definition decompression and skips relevant validations.",
34+
"default": false
35+
}
36+
},
2937
"body":{
3038
"description":"The trained model configuration",
3139
"required":true

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelAction.java

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
*/
77
package org.elasticsearch.xpack.core.ml.action;
88

9+
import org.elasticsearch.Version;
910
import org.elasticsearch.action.ActionRequestValidationException;
1011
import org.elasticsearch.action.ActionResponse;
1112
import org.elasticsearch.action.ActionType;
@@ -22,9 +23,12 @@
2223
import java.io.IOException;
2324
import java.util.Objects;
2425

26+
import static org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig.ESTIMATED_HEAP_MEMORY_USAGE_BYTES;
27+
2528

2629
public class PutTrainedModelAction extends ActionType<PutTrainedModelAction.Response> {
2730

31+
public static final String DEFER_DEFINITION_DECOMPRESSION = "defer_definition_decompression";
2832
public static final PutTrainedModelAction INSTANCE = new PutTrainedModelAction();
2933
public static final String NAME = "cluster:admin/xpack/ml/inference/put";
3034
private PutTrainedModelAction() {
@@ -33,7 +37,7 @@ private PutTrainedModelAction() {
3337

3438
public static class Request extends AcknowledgedRequest<Request> {
3539

36-
public static Request parseRequest(String modelId, XContentParser parser) {
40+
public static Request parseRequest(String modelId, boolean deferDefinitionValidation, XContentParser parser) {
3741
TrainedModelConfig.Builder builder = TrainedModelConfig.STRICT_PARSER.apply(parser, null);
3842

3943
if (builder.getModelId() == null) {
@@ -47,18 +51,25 @@ public static Request parseRequest(String modelId, XContentParser parser) {
4751
}
4852
// Validations are done against the builder so we can build the full config object.
4953
// This allows us to not worry about serializing a builder class between nodes.
50-
return new Request(builder.validate(true).build());
54+
return new Request(builder.validate(true).build(), deferDefinitionValidation);
5155
}
5256

5357
private final TrainedModelConfig config;
58+
private final boolean deferDefinitionDecompression;
5459

55-
public Request(TrainedModelConfig config) {
60+
public Request(TrainedModelConfig config, boolean deferDefinitionDecompression) {
5661
this.config = config;
62+
this.deferDefinitionDecompression = deferDefinitionDecompression;
5763
}
5864

5965
public Request(StreamInput in) throws IOException {
6066
super(in);
6167
this.config = new TrainedModelConfig(in);
68+
if (in.getVersion().onOrAfter(Version.V_7_16_0)) {
69+
this.deferDefinitionDecompression = in.readBoolean();
70+
} else {
71+
this.deferDefinitionDecompression = false;
72+
}
6273
}
6374

6475
public TrainedModelConfig getTrainedModelConfig() {
@@ -67,26 +78,44 @@ public TrainedModelConfig getTrainedModelConfig() {
6778

6879
@Override
6980
public ActionRequestValidationException validate() {
81+
if (deferDefinitionDecompression
82+
&& config.getEstimatedHeapMemory() == 0
83+
&& config.getCompressedDefinitionIfSet() != null) {
84+
ActionRequestValidationException validationException = new ActionRequestValidationException();
85+
validationException.addValidationError(
86+
"when ["
87+
+ DEFER_DEFINITION_DECOMPRESSION
88+
+ "] is true and a compressed definition is provided, " + ESTIMATED_HEAP_MEMORY_USAGE_BYTES + " must be set"
89+
);
90+
return validationException;
91+
}
7092
return null;
7193
}
7294

95+
public boolean isDeferDefinitionDecompression() {
96+
return deferDefinitionDecompression;
97+
}
98+
7399
@Override
74100
public void writeTo(StreamOutput out) throws IOException {
75101
super.writeTo(out);
76102
config.writeTo(out);
103+
if (out.getVersion().onOrAfter(Version.V_7_16_0)) {
104+
out.writeBoolean(deferDefinitionDecompression);
105+
}
77106
}
78107

79108
@Override
80109
public boolean equals(Object o) {
81110
if (this == o) return true;
82111
if (o == null || getClass() != o.getClass()) return false;
83112
Request request = (Request) o;
84-
return Objects.equals(config, request.config);
113+
return Objects.equals(config, request.config) && deferDefinitionDecompression == request.deferDefinitionDecompression;
85114
}
86115

87116
@Override
88117
public int hashCode() {
89-
return Objects.hash(config);
118+
return Objects.hash(config, deferDefinitionDecompression);
90119
}
91120

92121
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,13 @@ public String getCompressedDefinition() throws IOException {
247247
return definition.getCompressedString();
248248
}
249249

250+
public String getCompressedDefinitionIfSet() {
251+
if (definition == null) {
252+
return null;
253+
}
254+
return definition.getCompressedDefinitionIfSet();
255+
}
256+
250257
public void clearCompressed() {
251258
definition.compressedString = null;
252259
}
@@ -622,6 +629,7 @@ public Builder validate() {
622629

623630
/**
624631
* Runs validations against the builder.
632+
* @param forCreation indicates if we should validate for model creation or for a model read from storage
625633
* @return The current builder object if validations are successful
626634
* @throws ActionRequestValidationException when there are validation failures.
627635
*/
@@ -682,12 +690,6 @@ public Builder validate(boolean forCreation) {
682690
validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
683691
validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);
684692
validationException = checkIllegalSetting(createTime, CREATE_TIME.getPreferredName(), validationException);
685-
validationException = checkIllegalSetting(estimatedHeapMemory,
686-
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
687-
validationException);
688-
validationException = checkIllegalSetting(estimatedOperations,
689-
ESTIMATED_OPERATIONS.getPreferredName(),
690-
validationException);
691693
validationException = checkIllegalSetting(licenseLevel, LICENSE_LEVEL.getPreferredName(), validationException);
692694
if (metadata != null) {
693695
validationException = checkIllegalSetting(
@@ -783,6 +785,10 @@ public String getCompressedString() throws IOException {
783785
return compressedString;
784786
}
785787

788+
public String getCompressedDefinitionIfSet() {
789+
return compressedString;
790+
}
791+
786792
private void ensureParsedDefinitionUnsafe(NamedXContentRegistry xContentRegistry) throws IOException {
787793
if (parsedDefinition == null) {
788794
parsedDefinition = InferenceToXContentCompressor.inflateUnsafe(compressedString,

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/PutTrainedModelActionRequestTests.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@ public class PutTrainedModelActionRequestTests extends AbstractWireSerializingTe
2020
@Override
2121
protected Request createTestInstance() {
2222
String modelId = randomAlphaOfLength(10);
23-
return new Request(TrainedModelConfigTests.createTestInstance(modelId)
24-
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
25-
.build());
23+
return new Request(
24+
TrainedModelConfigTests.createTestInstance(modelId)
25+
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder())
26+
.build(),
27+
randomBoolean()
28+
);
2629
}
2730

2831
@Override

x-pack/plugin/ml/qa/ml-with-security/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ tasks.named("yamlRestTest").configure {
158158
'ml/inference_crud/Test update model alias where alias exists but reassign is false',
159159
'ml/inference_crud/Test delete model alias with missing alias',
160160
'ml/inference_crud/Test delete model alias where alias points to different model',
161+
'ml/inference_crud/Test put with defer_definition_decompression with invalid compression definition and no memory estimate',
162+
'ml/inference_crud/Test put with defer_definition_decompression with invalid definition and no memory estimate',
161163
'ml/inference_processor/Test create processor with missing mandatory fields',
162164
'ml/inference_stats_crud/Test get stats given missing trained model',
163165
'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',

x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/license/MachineLearningLicensingIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ private void putInferenceModel(String modelId) {
826826
.setInput(new TrainedModelInput(Collections.singletonList("feature1")))
827827
.setInferenceConfig(RegressionConfig.EMPTY_PARAMS)
828828
.build();
829-
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
829+
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config, false)).actionGet();
830830
}
831831

832832
private static OperationMode randomInvalidLicenseType() {

x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TestFeatureLicenseTrackingIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ public void testFeatureTrackingInferenceModelPipeline() throws Exception {
127127
.setPreProcessors(Arrays.asList(new OneHotEncoding("other.categorical", oneHotEncoding, false)))
128128
.setTrainedModel(buildClassification(true)))
129129
.build();
130-
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config)).actionGet();
130+
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(config, false)).actionGet();
131131

132132
String pipelineId = "pipeline-inference-model-tracked";
133133
putTrainedModelIngestPipeline(pipelineId, modelId);

x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/UnusedStatsRemoverIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ public void testRemoveUnusedStats() throws Exception {
8787
.build())
8888
)
8989
.validate(true)
90-
.build())).actionGet();
90+
.build(),
91+
false)).actionGet();
9192

9293
indexStatDocument(new DataCounts("analytics-with-stats", 1, 1, 1),
9394
DataCounts.documentId("analytics-with-stats"));

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportPutTrainedModelAction.java

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ public TransportPutTrainedModelAction(TransportService transportService, Cluster
7272
protected void masterOperation(Request request,
7373
ClusterState state,
7474
ActionListener<Response> listener) {
75+
76+
final TrainedModelConfig config = request.getTrainedModelConfig();
7577
// 7.8.0 introduced splitting the model definition across multiple documents.
7678
// This means that new models will not be usable on nodes that cannot handle multiple definition documents
7779
if (state.nodes().getMinNodeVersion().before(Version.V_7_8_0)) {
@@ -82,8 +84,10 @@ protected void masterOperation(Request request,
8284
return;
8385
}
8486
try {
85-
request.getTrainedModelConfig().ensureParsedDefinition(xContentRegistry);
86-
request.getTrainedModelConfig().getModelDefinition().getTrainedModel().validate();
87+
if (request.isDeferDefinitionDecompression() == false) {
88+
config.ensureParsedDefinition(xContentRegistry);
89+
config.getModelDefinition().getTrainedModel().validate();
90+
}
8791
} catch (IOException ex) {
8892
listener.onFailure(ExceptionsHelper.badRequestException("Failed to parse definition for [{}]",
8993
ex,
@@ -95,43 +99,47 @@ protected void masterOperation(Request request,
9599
request.getTrainedModelConfig().getModelId()));
96100
return;
97101
}
98-
if (request.getTrainedModelConfig()
99-
.getInferenceConfig()
100-
.isTargetTypeSupported(request.getTrainedModelConfig()
101-
.getModelDefinition()
102-
.getTrainedModel()
103-
.targetType()) == false) {
104-
listener.onFailure(ExceptionsHelper.badRequestException(
105-
"Model [{}] inference config type [{}] does not support definition target type [{}]",
106-
request.getTrainedModelConfig().getModelId(),
107-
request.getTrainedModelConfig().getInferenceConfig().getName(),
108-
request.getTrainedModelConfig()
102+
103+
// NOTE: hasModelDefinition is false if we don't parse it. But, if the fully parsed model was already provided, continue
104+
boolean hasModelDefinition = config.getModelDefinition() != null;
105+
if (hasModelDefinition) {
106+
if (config.getInferenceConfig()
107+
.isTargetTypeSupported(config
109108
.getModelDefinition()
110109
.getTrainedModel()
111-
.targetType()));
112-
return;
113-
}
114-
115-
Version minCompatibilityVersion = request.getTrainedModelConfig()
116-
.getModelDefinition()
117-
.getTrainedModel()
118-
.getMinimalCompatibilityVersion();
119-
if (state.nodes().getMinNodeVersion().before(minCompatibilityVersion)) {
120-
listener.onFailure(ExceptionsHelper.badRequestException(
121-
"Definition for [{}] requires that all nodes are at least version [{}]",
122-
request.getTrainedModelConfig().getModelId(),
123-
minCompatibilityVersion.toString()));
124-
return;
110+
.targetType()) == false) {
111+
listener.onFailure(ExceptionsHelper.badRequestException(
112+
"Model [{}] inference config type [{}] does not support definition target type [{}]",
113+
config.getModelId(),
114+
config.getInferenceConfig().getName(),
115+
config.getModelDefinition().getTrainedModel().targetType()));
116+
return;
117+
}
118+
119+
Version minCompatibilityVersion = config
120+
.getModelDefinition()
121+
.getTrainedModel()
122+
.getMinimalCompatibilityVersion();
123+
if (state.nodes().getMinNodeVersion().before(minCompatibilityVersion)) {
124+
listener.onFailure(ExceptionsHelper.badRequestException(
125+
"Definition for [{}] requires that all nodes are at least version [{}]",
126+
request.getTrainedModelConfig().getModelId(),
127+
minCompatibilityVersion.toString()));
128+
return;
129+
}
125130
}
126131

127-
TrainedModelConfig trainedModelConfig = new TrainedModelConfig.Builder(request.getTrainedModelConfig())
132+
TrainedModelConfig.Builder trainedModelConfigBuilder = new TrainedModelConfig.Builder(config)
128133
.setVersion(Version.CURRENT)
129134
.setCreateTime(Instant.now())
130135
.setCreatedBy("api_user")
131-
.setLicenseLevel(License.OperationMode.PLATINUM.description())
132-
.setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed())
133-
.setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations())
134-
.build();
136+
.setLicenseLevel(License.OperationMode.PLATINUM.description());
137+
if (hasModelDefinition) {
138+
trainedModelConfigBuilder.setEstimatedHeapMemory(request.getTrainedModelConfig().getModelDefinition().ramBytesUsed())
139+
.setEstimatedOperations(request.getTrainedModelConfig().getModelDefinition().getTrainedModel().estimatedNumOperations());
140+
}
141+
TrainedModelConfig trainedModelConfig = trainedModelConfigBuilder.build();
142+
135143
if (ModelAliasMetadata.fromState(state).getModelId(trainedModelConfig.getModelId()) != null) {
136144
listener.onFailure(ExceptionsHelper.badRequestException(
137145
"requested model_id [{}] is the same as an existing model_alias. Model model_aliases and ids must be unique",

0 commit comments

Comments
 (0)