Skip to content

Commit 03df107

Browse files
authored
[ML][Inference] Adds validations for model PUT (#51376) (#51410)
Adds validations making sure that: `input.field_names` is not empty; `ensemble.trained_models` is not empty; and `tree.feature_names` is not empty. Closes #51354.
1 parent 2cbef17 commit 03df107

File tree

8 files changed

+133
-1
lines changed

8 files changed

+133
-1
lines changed

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
import java.io.IOException;
3434
import java.util.ArrayList;
35+
import java.util.Arrays;
3536
import java.util.Collections;
3637
import java.util.List;
3738
import java.util.function.Predicate;
@@ -70,7 +71,7 @@ public static TrainedModelDefinition.Builder createRandomBuilder(TargetType targ
7071
TargetMeanEncodingTests.createRandom()))
7172
.limit(numberOfProcessors)
7273
.collect(Collectors.toList()))
73-
.setTrainedModel(randomFrom(TreeTests.buildRandomTree(Collections.emptyList(), 6, targetType),
74+
.setTrainedModel(randomFrom(TreeTests.buildRandomTree(Arrays.asList("foo", "bar"), 6, targetType),
7475
EnsembleTests.createRandom(targetType)));
7576
}
7677

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,9 @@ public Builder validate(boolean forCreation) {
535535
break;
536536
}
537537
}
538+
if (input != null && input.getFieldNames().isEmpty()) {
539+
validationException = addValidationError("[input.field_names] must not be empty", validationException);
540+
}
538541
if (forCreation) {
539542
validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
540543
validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,10 @@ public int hashCode() {
250250

251251
@Override
252252
public void validate() {
253+
if (this.models.isEmpty()) {
254+
throw ExceptionsHelper.badRequestException("[{}] must not be empty", TRAINED_MODELS.getPreferredName());
255+
}
256+
253257
if (outputAggregator.compatibleWith(targetType) == false) {
254258
throw ExceptionsHelper.badRequestException(
255259
"aggregate_output [{}] is not compatible with target_type [{}]",

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,9 @@ public static Builder builder() {
253253

254254
@Override
255255
public void validate() {
256+
if (featureNames.isEmpty()) {
257+
throw ExceptionsHelper.badRequestException("[{}] must not be empty for tree model", FEATURE_NAMES.getPreferredName());
258+
}
256259
checkTargetType();
257260
detectMissingNodes();
258261
detectCycle();

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.io.IOException;
2727
import java.util.ArrayList;
2828
import java.util.Arrays;
29+
import java.util.Collections;
2930
import java.util.HashMap;
3031
import java.util.List;
3132
import java.util.Map;
@@ -202,6 +203,14 @@ public void testEnsembleWithTargetTypeAndLabelsMismatch() {
202203
assertThat(ex.getMessage(), equalTo(msg));
203204
}
204205

206+
public void testEnsembleWithEmptyModels() {
207+
List<String> featureNames = Arrays.asList("foo", "bar");
208+
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
209+
Ensemble.builder().setTrainedModels(Collections.emptyList()).setFeatureNames(featureNames).build().validate();
210+
});
211+
assertThat(ex.getMessage(), equalTo("[trained_models] must not be empty"));
212+
}
213+
205214
public void testClassificationProbability() {
206215
List<String> featureNames = Arrays.asList("foo", "bar");
207216
Tree tree1 = Tree.builder()

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,21 @@ public void testTreeWithTargetTypeAndLabelsMismatch() {
339339
assertThat(ex.getMessage(), equalTo(msg));
340340
}
341341

342+
public void testTreeWithEmptyFeatureNames() {
343+
String msg = "[feature_names] must not be empty for tree model";
344+
ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> {
345+
Tree.builder()
346+
.setRoot(TreeNode.builder(0)
347+
.setLeftChild(1)
348+
.setSplitFeature(1)
349+
.setThreshold(randomDouble()))
350+
.setFeatureNames(Collections.emptyList())
351+
.build()
352+
.validate();
353+
});
354+
assertThat(ex.getMessage(), equalTo(msg));
355+
}
356+
342357
public void testOperationsEstimations() {
343358
Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5);
344359
assertThat(tree.estimatedNumOperations(), equalTo(7L));

x-pack/plugin/ml/qa/ml-with-security/build.gradle

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ integTest.runner {
136136
'ml/inference_crud/Test delete with missing model',
137137
'ml/inference_crud/Test get given missing trained model',
138138
'ml/inference_crud/Test get given expression without matches and allow_no_match is false',
139+
'ml/inference_crud/Test put ensemble with empty models',
140+
'ml/inference_crud/Test put ensemble with tree where tree has empty feature-names',
141+
'ml/inference_crud/Test put model with empty input.field_names',
139142
'ml/inference_stats_crud/Test get stats given missing trained model',
140143
'ml/inference_stats_crud/Test get stats given expression without matches and allow_no_match is false',
141144
'ml/jobs_crud/Test cannot create job with existing categorizer state document',

x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,97 @@ setup:
171171
allow_no_match: false
172172
- match: { count: 1 }
173173
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
174+
---
175+
"Test put ensemble with empty models":
176+
- do:
177+
catch: /\[trained_models\] must not be empty/
178+
ml.put_trained_model:
179+
model_id: "missing_model_ensembles"
180+
body: >
181+
{
182+
"input": {
183+
"field_names": "fieldy_mc_fieldname"
184+
},
185+
"definition": {
186+
"trained_model": {
187+
"ensemble": {
188+
"feature_names": [],
189+
"trained_models": []
190+
}
191+
}
192+
}
193+
}
194+
---
195+
"Test put ensemble with tree where tree has empty feature-names":
196+
- do:
197+
catch: /\[feature_names\] must not be empty/
198+
ml.put_trained_model:
199+
model_id: "ensemble_tree_missing_feature_names"
200+
body: >
201+
{
202+
"input": {
203+
"field_names": "fieldy_mc_fieldname"
204+
},
205+
"definition": {
206+
"trained_model": {
207+
"ensemble": {
208+
"feature_names": [],
209+
"trained_models": [
210+
{
211+
"tree": {
212+
"feature_names": [],
213+
"tree_structure": [
214+
{
215+
"node_index": 0,
216+
"split_feature": 0,
217+
"split_gain": 12.0,
218+
"threshold": 10.0,
219+
"decision_type": "lte",
220+
"default_left": true,
221+
"left_child": 1,
222+
"right_child": 2
223+
}]
224+
}
225+
}
226+
]
227+
}
228+
}
229+
}
230+
}
231+
---
232+
"Test put model with empty input.field_names":
233+
- do:
234+
catch: /\[input\.field_names\] must not be empty/
235+
ml.put_trained_model:
236+
model_id: "missing_model_ensembles"
237+
body: >
238+
{
239+
"input": {
240+
"field_names": []
241+
},
242+
"definition": {
243+
"trained_model": {
244+
"ensemble": {
245+
"feature_names": [],
246+
"trained_models": [
247+
{
248+
"tree": {
249+
"feature_names": [],
250+
"tree_structure": [
251+
{
252+
"node_index": 0,
253+
"split_feature": 0,
254+
"split_gain": 12.0,
255+
"threshold": 10.0,
256+
"decision_type": "lte",
257+
"default_left": true,
258+
"left_child": 1,
259+
"right_child": 2
260+
}]
261+
}
262+
}
263+
]
264+
}
265+
}
266+
}
267+
}

0 commit comments

Comments
 (0)