2525import org .elasticsearch .xpack .core .ml .inference .TrainedModelInput ;
2626import org .elasticsearch .xpack .core .ml .inference .persistence .InferenceIndexConstants ;
2727import org .elasticsearch .xpack .core .ml .integration .MlRestTestStateCleaner ;
28+ import org .elasticsearch .xpack .core .ml .job .messages .Messages ;
2829import org .elasticsearch .xpack .core .ml .utils .ToXContentParams ;
2930import org .elasticsearch .xpack .ml .MachineLearning ;
3031import org .elasticsearch .xpack .ml .inference .loadingservice .LocalModelTests ;
@@ -63,6 +64,11 @@ public void testGetTrainedModels() throws IOException {
6364 model1 .setJsonEntity (buildRegressionModel (modelId ));
6465 assertThat (client ().performRequest (model1 ).getStatusLine ().getStatusCode (), equalTo (201 ));
6566
67+ Request modelDefinition1 = new Request ("PUT" ,
68+ InferenceIndexConstants .LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition .docId (modelId ));
69+ modelDefinition1 .setJsonEntity (buildRegressionModelDefinition (modelId ));
70+ assertThat (client ().performRequest (modelDefinition1 ).getStatusLine ().getStatusCode (), equalTo (201 ));
71+
6672 Request model2 = new Request ("PUT" ,
6773 InferenceIndexConstants .LATEST_INDEX_NAME + "/_doc/" + modelId2 );
6874 model2 .setJsonEntity (buildRegressionModel (modelId2 ));
@@ -85,8 +91,26 @@ public void testGetTrainedModels() throws IOException {
8591 response = EntityUtils .toString (getModel .getEntity ());
8692 assertThat (response , containsString ("\" model_id\" :\" test_regression_model\" " ));
8793 assertThat (response , containsString ("\" model_id\" :\" test_regression_model-2\" " ));
94+ assertThat (response , not (containsString ("\" definition\" " )));
8895 assertThat (response , containsString ("\" count\" :2" ));
8996
97+ getModel = client ().performRequest (new Request ("GET" ,
98+ MachineLearning .BASE_PATH + "inference/test_regression_model?human=true&include_model_definition=true" ));
99+ assertThat (getModel .getStatusLine ().getStatusCode (), equalTo (200 ));
100+
101+ response = EntityUtils .toString (getModel .getEntity ());
102+ assertThat (response , containsString ("\" model_id\" :\" test_regression_model\" " ));
103+ assertThat (response , containsString ("\" heap_memory_estimation_bytes\" " ));
104+ assertThat (response , containsString ("\" heap_memory_estimation\" " ));
105+ assertThat (response , containsString ("\" definition\" " ));
106+ assertThat (response , containsString ("\" count\" :1" ));
107+
108+ ResponseException responseException = expectThrows (ResponseException .class , () ->
109+ client ().performRequest (new Request ("GET" ,
110+ MachineLearning .BASE_PATH + "inference/test_regression*?human=true&include_model_definition=true" )));
111+ assertThat (EntityUtils .toString (responseException .getResponse ().getEntity ()),
112+ containsString (Messages .INFERENCE_TO_MANY_DEFINITIONS_REQUESTED ));
113+
90114 getModel = client ().performRequest (new Request ("GET" ,
91115 MachineLearning .BASE_PATH + "inference/test_regression_model,test_regression_model-2" ));
92116 assertThat (getModel .getStatusLine ().getStatusCode (), equalTo (200 ));
@@ -131,6 +155,11 @@ public void testDeleteTrainedModels() throws IOException {
131155 model1 .setJsonEntity (buildRegressionModel (modelId ));
132156 assertThat (client ().performRequest (model1 ).getStatusLine ().getStatusCode (), equalTo (201 ));
133157
158+ Request modelDefinition1 = new Request ("PUT" ,
159+ InferenceIndexConstants .LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition .docId (modelId ));
160+ modelDefinition1 .setJsonEntity (buildRegressionModelDefinition (modelId ));
161+ assertThat (client ().performRequest (modelDefinition1 ).getStatusLine ().getStatusCode (), equalTo (201 ));
162+
134163 adminClient ().performRequest (new Request ("POST" , InferenceIndexConstants .LATEST_INDEX_NAME + "/_refresh" ));
135164
136165 Response delModel = client ().performRequest (new Request ("DELETE" ,
@@ -141,6 +170,18 @@ public void testDeleteTrainedModels() throws IOException {
141170 ResponseException responseException = expectThrows (ResponseException .class ,
142171 () -> client ().performRequest (new Request ("DELETE" , MachineLearning .BASE_PATH + "inference/" + modelId )));
143172 assertThat (responseException .getResponse ().getStatusLine ().getStatusCode (), equalTo (404 ));
173+
174+ responseException = expectThrows (ResponseException .class ,
175+ () -> client ().performRequest (
176+ new Request ("GET" ,
177+ InferenceIndexConstants .LATEST_INDEX_NAME + "/_doc/" + TrainedModelDefinition .docId (modelId ))));
178+ assertThat (responseException .getResponse ().getStatusLine ().getStatusCode (), equalTo (404 ));
179+
180+ responseException = expectThrows (ResponseException .class ,
181+ () -> client ().performRequest (
182+ new Request ("GET" ,
183+ InferenceIndexConstants .LATEST_INDEX_NAME + "/_doc/" + modelId )));
184+ assertThat (responseException .getResponse ().getStatusLine ().getStatusCode (), equalTo (404 ));
144185 }
145186
146187 private static String buildRegressionModel (String modelId ) throws IOException {
@@ -149,9 +190,6 @@ private static String buildRegressionModel(String modelId) throws IOException {
149190 .setModelId (modelId )
150191 .setInput (new TrainedModelInput (Arrays .asList ("col1" , "col2" , "col3" )))
151192 .setCreatedBy ("ml_test" )
152- .setDefinition (new TrainedModelDefinition .Builder ()
153- .setPreProcessors (Collections .emptyList ())
154- .setTrainedModel (LocalModelTests .buildRegression ()))
155193 .setVersion (Version .CURRENT )
156194 .setCreateTime (Instant .now ())
157195 .build ()
@@ -160,6 +198,18 @@ private static String buildRegressionModel(String modelId) throws IOException {
160198 }
161199 }
162200
201+ private static String buildRegressionModelDefinition (String modelId ) throws IOException {
202+ try (XContentBuilder builder = XContentFactory .jsonBuilder ()) {
203+ new TrainedModelDefinition .Builder ()
204+ .setPreProcessors (Collections .emptyList ())
205+ .setTrainedModel (LocalModelTests .buildRegression ())
206+ .setModelId (modelId )
207+ .build ()
208+ .toXContent (builder , new ToXContent .MapParams (Collections .singletonMap (ToXContentParams .FOR_INTERNAL_STORAGE , "true" )));
209+ return XContentHelper .convertToJson (BytesReference .bytes (builder ), false , XContentType .JSON );
210+ }
211+ }
212+
163213
164214 @ After
165215 public void clearMlState () throws Exception {
0 commit comments