Skip to content

Commit cd79777

Browse files
authored
[ML][Inference] add new flag for optionally including model definition (#48718)
* [ML][Inference] add new flag for optionally including model definition * adjusting after definition and config split * revert unnecessary changes to AbstractTransportGetResourcesAction * fixing TrainedModelDefinitionTests * fixing yaml tests from previous code changes * fixing integration test * making tests an assertBusy for verification
1 parent f531a9d commit cd79777

File tree

13 files changed

+391
-200
lines changed

13 files changed

+391
-200
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsAction.java

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,20 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.action;
77

8-
import org.elasticsearch.action.ActionRequestBuilder;
98
import org.elasticsearch.action.ActionType;
10-
import org.elasticsearch.client.ElasticsearchClient;
119
import org.elasticsearch.common.ParseField;
1210
import org.elasticsearch.common.io.stream.StreamInput;
11+
import org.elasticsearch.common.io.stream.StreamOutput;
1312
import org.elasticsearch.xpack.core.action.AbstractGetResourcesRequest;
1413
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
1514
import org.elasticsearch.xpack.core.action.util.QueryPage;
1615
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
1716

1817
import java.io.IOException;
18+
import java.util.Collections;
19+
import java.util.List;
20+
import java.util.Objects;
21+
1922

2023
public class GetTrainedModelsAction extends ActionType<GetTrainedModelsAction.Response> {
2124

@@ -28,26 +31,53 @@ private GetTrainedModelsAction() {
2831

2932
public static class Request extends AbstractGetResourcesRequest {
3033

34+
public static final ParseField INCLUDE_MODEL_DEFINITION = new ParseField("include_model_definition");
3135
public static final ParseField ALLOW_NO_MATCH = new ParseField("allow_no_match");
3236

33-
public Request() {
34-
setAllowNoResources(true);
35-
}
37+
private final boolean includeModelDefinition;
3638

37-
public Request(String id) {
39+
public Request(String id, boolean includeModelDefinition) {
3840
setResourceId(id);
3941
setAllowNoResources(true);
42+
this.includeModelDefinition = includeModelDefinition;
4043
}
4144

4245
public Request(StreamInput in) throws IOException {
4346
super(in);
47+
this.includeModelDefinition = in.readBoolean();
4448
}
4549

4650
@Override
4751
public String getResourceIdField() {
4852
return TrainedModelConfig.MODEL_ID.getPreferredName();
4953
}
5054

55+
public boolean isIncludeModelDefinition() {
56+
return includeModelDefinition;
57+
}
58+
59+
@Override
60+
public void writeTo(StreamOutput out) throws IOException {
61+
super.writeTo(out);
62+
out.writeBoolean(includeModelDefinition);
63+
}
64+
65+
@Override
66+
public int hashCode() {
67+
return Objects.hash(super.hashCode(), includeModelDefinition);
68+
}
69+
70+
@Override
71+
public boolean equals(Object obj) {
72+
if (obj == this) {
73+
return true;
74+
}
75+
if (obj == null || getClass() != obj.getClass()) {
76+
return false;
77+
}
78+
Request other = (Request) obj;
79+
return super.equals(obj) && this.includeModelDefinition == other.includeModelDefinition;
80+
}
5181
}
5282

5383
public static class Response extends AbstractGetResourcesResponse<TrainedModelConfig> {
@@ -66,12 +96,33 @@ public Response(QueryPage<TrainedModelConfig> trainedModels) {
6696
protected Reader<TrainedModelConfig> getReader() {
6797
return TrainedModelConfig::new;
6898
}
69-
}
7099

71-
public static class RequestBuilder extends ActionRequestBuilder<Request, Response> {
100+
public static Builder builder() {
101+
return new Builder();
102+
}
103+
104+
public static class Builder {
72105

73-
public RequestBuilder(ElasticsearchClient client) {
74-
super(client, INSTANCE, new Request());
106+
private long totalCount;
107+
private List<TrainedModelConfig> configs = Collections.emptyList();
108+
109+
private Builder() {
110+
}
111+
112+
public Builder setTotalCount(long totalCount) {
113+
this.totalCount = totalCount;
114+
return this;
115+
}
116+
117+
public Builder setModels(List<TrainedModelConfig> configs) {
118+
this.configs = configs;
119+
return this;
120+
}
121+
122+
public Response build() {
123+
return new Response(new QueryPage<>(configs, totalCount, RESULTS_FIELD));
124+
}
75125
}
76126
}
127+
77128
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,13 @@ public final class Messages {
8383
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
8484
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
8585
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
86+
public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}";
8687
public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION =
8788
"Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]";
8889
public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]";
90+
public static final String INFERENCE_FAILED_TO_DESERIALIZE = "Could not deserialize trained model [{0}]";
91+
public static final String INFERENCE_TO_MANY_DEFINITIONS_REQUESTED =
92+
"Getting model definition is not supported when getting more than one model";
8993

9094
public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again";
9195
public static final String JOB_AUDIT_CREATED = "Job created";

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsRequestTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public class GetTrainedModelsRequestTests extends AbstractWireSerializingTestCas
1414

1515
@Override
1616
protected Request createTestInstance() {
17-
Request request = new Request(randomAlphaOfLength(20));
17+
Request request = new Request(randomAlphaOfLength(20), randomBoolean());
1818
request.setPageParams(new PageParams(randomIntBetween(0, 100), randomIntBetween(0, 100)));
1919
return request;
2020
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.inference;
77

8+
import org.elasticsearch.common.bytes.BytesReference;
89
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
910
import org.elasticsearch.common.io.stream.Writeable;
1011
import org.elasticsearch.common.settings.Settings;
1112
import org.elasticsearch.common.xcontent.DeprecationHandler;
1213
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
1314
import org.elasticsearch.common.xcontent.ToXContent;
1415
import org.elasticsearch.common.xcontent.XContentFactory;
16+
import org.elasticsearch.common.xcontent.XContentHelper;
17+
import org.elasticsearch.common.xcontent.XContentParseException;
1518
import org.elasticsearch.common.xcontent.XContentParser;
1619
import org.elasticsearch.common.xcontent.XContentType;
1720
import org.elasticsearch.search.SearchModule;
@@ -22,7 +25,6 @@
2225
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
2326
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
2427
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests;
25-
import org.junit.Before;
2628

2729
import java.io.IOException;
2830
import java.util.ArrayList;
@@ -33,27 +35,21 @@
3335
import java.util.stream.Stream;
3436

3537
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
38+
import static org.hamcrest.Matchers.containsString;
3639
import static org.hamcrest.Matchers.equalTo;
3740
import static org.hamcrest.Matchers.greaterThan;
3841

3942

4043
public class TrainedModelDefinitionTests extends AbstractSerializingTestCase<TrainedModelDefinition> {
4144

42-
private boolean lenient;
43-
44-
@Before
45-
public void chooseStrictOrLenient() {
46-
lenient = randomBoolean();
47-
}
48-
4945
@Override
5046
protected TrainedModelDefinition doParseInstance(XContentParser parser) throws IOException {
51-
return TrainedModelDefinition.fromXContent(parser, lenient).build();
47+
return TrainedModelDefinition.fromXContent(parser, true).build();
5248
}
5349

5450
@Override
5551
protected boolean supportsUnknownFields() {
56-
return lenient;
52+
return true;
5753
}
5854

5955
@Override
@@ -63,7 +59,7 @@ protected Predicate<String> getRandomFieldsExcludeFilter() {
6359

6460
@Override
6561
protected ToXContent.Params getToXContentParams() {
66-
return lenient ? ToXContent.EMPTY_PARAMS : new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"));
62+
return new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true"));
6763
}
6864

6965
@Override
@@ -286,9 +282,27 @@ public void testTreeSchemaDeserialization() throws IOException {
286282
assertThat(definition.getTrainedModel().getClass(), equalTo(Tree.class));
287283
}
288284

285+
public void testStrictParser() throws IOException {
286+
TrainedModelDefinition.Builder builder = createRandomBuilder("asdf");
287+
BytesReference reference = XContentHelper.toXContent(builder.build(),
288+
XContentType.JSON,
289+
new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")),
290+
false);
291+
292+
XContentParser parser = XContentHelper.createParser(xContentRegistry(),
293+
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
294+
reference,
295+
XContentType.JSON);
296+
297+
XContentParseException exception = expectThrows(XContentParseException.class,
298+
() -> TrainedModelDefinition.fromXContent(parser, false));
299+
300+
assertThat(exception.getMessage(), containsString("[trained_model_definition] unknown field [doc_type]"));
301+
}
302+
289303
@Override
290304
protected TrainedModelDefinition createTestInstance() {
291-
return createRandomBuilder(null).build();
305+
return createRandomBuilder(randomAlphaOfLength(10)).build();
292306
}
293307

294308
@Override

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
2626
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
2727
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
28+
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
2829
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
2930
import org.elasticsearch.xpack.ml.MachineLearning;
3031
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModelTests;
@@ -63,6 +64,11 @@ public void testGetTrainedModels() throws IOException {
6364
model1.setJsonEntity(buildRegressionModel(modelId));
6465
assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
6566

67+
Request modelDefinition1 = new Request("PUT",
68+
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId));
69+
modelDefinition1.setJsonEntity(buildRegressionModelDefinition(modelId));
70+
assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
71+
6672
Request model2 = new Request("PUT",
6773
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId2);
6874
model2.setJsonEntity(buildRegressionModel(modelId2));
@@ -85,8 +91,26 @@ public void testGetTrainedModels() throws IOException {
8591
response = EntityUtils.toString(getModel.getEntity());
8692
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
8793
assertThat(response, containsString("\"model_id\":\"test_regression_model-2\""));
94+
assertThat(response, not(containsString("\"definition\"")));
8895
assertThat(response, containsString("\"count\":2"));
8996

97+
getModel = client().performRequest(new Request("GET",
98+
MachineLearning.BASE_PATH + "inference/test_regression_model?human=true&include_model_definition=true"));
99+
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
100+
101+
response = EntityUtils.toString(getModel.getEntity());
102+
assertThat(response, containsString("\"model_id\":\"test_regression_model\""));
103+
assertThat(response, containsString("\"heap_memory_estimation_bytes\""));
104+
assertThat(response, containsString("\"heap_memory_estimation\""));
105+
assertThat(response, containsString("\"definition\""));
106+
assertThat(response, containsString("\"count\":1"));
107+
108+
ResponseException responseException = expectThrows(ResponseException.class, () ->
109+
client().performRequest(new Request("GET",
110+
MachineLearning.BASE_PATH + "inference/test_regression*?human=true&include_model_definition=true")));
111+
assertThat(EntityUtils.toString(responseException.getResponse().getEntity()),
112+
containsString(Messages.INFERENCE_TO_MANY_DEFINITIONS_REQUESTED));
113+
90114
getModel = client().performRequest(new Request("GET",
91115
MachineLearning.BASE_PATH + "inference/test_regression_model,test_regression_model-2"));
92116
assertThat(getModel.getStatusLine().getStatusCode(), equalTo(200));
@@ -131,6 +155,11 @@ public void testDeleteTrainedModels() throws IOException {
131155
model1.setJsonEntity(buildRegressionModel(modelId));
132156
assertThat(client().performRequest(model1).getStatusLine().getStatusCode(), equalTo(201));
133157

158+
Request modelDefinition1 = new Request("PUT",
159+
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId));
160+
modelDefinition1.setJsonEntity(buildRegressionModelDefinition(modelId));
161+
assertThat(client().performRequest(modelDefinition1).getStatusLine().getStatusCode(), equalTo(201));
162+
134163
adminClient().performRequest(new Request("POST", InferenceIndexConstants.LATEST_INDEX_NAME + "/_refresh"));
135164

136165
Response delModel = client().performRequest(new Request("DELETE",
@@ -141,6 +170,18 @@ public void testDeleteTrainedModels() throws IOException {
141170
ResponseException responseException = expectThrows(ResponseException.class,
142171
() -> client().performRequest(new Request("DELETE", MachineLearning.BASE_PATH + "inference/" + modelId)));
143172
assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404));
173+
174+
responseException = expectThrows(ResponseException.class,
175+
() -> client().performRequest(
176+
new Request("GET",
177+
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition.docId(modelId))));
178+
assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404));
179+
180+
responseException = expectThrows(ResponseException.class,
181+
() -> client().performRequest(
182+
new Request("GET",
183+
InferenceIndexConstants.LATEST_INDEX_NAME + "/_doc/" + modelId)));
184+
assertThat(responseException.getResponse().getStatusLine().getStatusCode(), equalTo(404));
144185
}
145186

146187
private static String buildRegressionModel(String modelId) throws IOException {
@@ -149,9 +190,6 @@ private static String buildRegressionModel(String modelId) throws IOException {
149190
.setModelId(modelId)
150191
.setInput(new TrainedModelInput(Arrays.asList("col1", "col2", "col3")))
151192
.setCreatedBy("ml_test")
152-
.setDefinition(new TrainedModelDefinition.Builder()
153-
.setPreProcessors(Collections.emptyList())
154-
.setTrainedModel(LocalModelTests.buildRegression()))
155193
.setVersion(Version.CURRENT)
156194
.setCreateTime(Instant.now())
157195
.build()
@@ -160,6 +198,18 @@ private static String buildRegressionModel(String modelId) throws IOException {
160198
}
161199
}
162200

201+
private static String buildRegressionModelDefinition(String modelId) throws IOException {
202+
try(XContentBuilder builder = XContentFactory.jsonBuilder()) {
203+
new TrainedModelDefinition.Builder()
204+
.setPreProcessors(Collections.emptyList())
205+
.setTrainedModel(LocalModelTests.buildRegression())
206+
.setModelId(modelId)
207+
.build()
208+
.toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")));
209+
return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON);
210+
}
211+
}
212+
163213

164214
@After
165215
public void clearMlState() throws Exception {

0 commit comments

Comments
 (0)