Skip to content

Commit e372854

Browse files
authored
[ML][Inference] Fix model pagination with models as resources (#51573) (#51736)
This adds logic to handle paging problems when the ID pattern + tags reference models stored as resources. Most of the complexity comes from the issue where a model stored as a resource could be at the start, or the end of a page or when we are on the last page.
1 parent dfc9f23 commit e372854

File tree

3 files changed

+232
-26
lines changed

3 files changed

+232
-26
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import org.elasticsearch.action.support.WriteRequest;
2929
import org.elasticsearch.client.Client;
3030
import org.elasticsearch.common.CheckedBiFunction;
31-
import org.elasticsearch.common.Nullable;
3231
import org.elasticsearch.common.Strings;
3332
import org.elasticsearch.common.bytes.BytesReference;
3433
import org.elasticsearch.common.collect.Tuple;
@@ -74,10 +73,10 @@
7473
import java.util.Collections;
7574
import java.util.Comparator;
7675
import java.util.HashSet;
77-
import java.util.LinkedHashSet;
7876
import java.util.List;
7977
import java.util.Map;
8078
import java.util.Set;
79+
import java.util.TreeSet;
8180

8281
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
8382
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
@@ -382,19 +381,34 @@ public void deleteTrainedModel(String modelId, ActionListener<Boolean> listener)
382381

383382
public void expandIds(String idExpression,
384383
boolean allowNoResources,
385-
@Nullable PageParams pageParams,
384+
PageParams pageParams,
386385
Set<String> tags,
387386
ActionListener<Tuple<Long, Set<String>>> idsListener) {
388387
String[] tokens = Strings.tokenizeToStringArray(idExpression, ",");
388+
Set<String> matchedResourceIds = matchedResourceIds(tokens);
389+
Set<String> foundResourceIds;
390+
if (tags.isEmpty()) {
391+
foundResourceIds = matchedResourceIds;
392+
} else {
393+
foundResourceIds = new HashSet<>();
394+
for(String resourceId : matchedResourceIds) {
395+
// Does the model as a resource have all the tags?
396+
if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
397+
foundResourceIds.add(resourceId);
398+
}
399+
}
400+
}
389401
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
390402
.sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName())
391403
// If there are no resources, there might be no mapping for the id field.
392404
// This makes sure we don't get an error if that happens.
393405
.unmappedType("long"))
394-
.query(buildExpandIdsQuery(tokens, tags));
395-
if (pageParams != null) {
396-
sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize());
397-
}
406+
.query(buildExpandIdsQuery(tokens, tags))
407+
// We "buffer" the from and size to take into account models stored as resources.
408+
// This is so we handle the edge cases when the model that is stored as a resource is at the start/end of
409+
// a page.
410+
.from(Math.max(0, pageParams.getFrom() - foundResourceIds.size()))
411+
.size(Math.min(10_000, pageParams.getSize() + foundResourceIds.size()));
398412
sourceBuilder.trackTotalHits(true)
399413
// we only care about the item id's
400414
.fetchSource(TrainedModelConfig.MODEL_ID.getPreferredName(), null);
@@ -407,47 +421,65 @@ public void expandIds(String idExpression,
407421
indicesOptions.expandWildcardsClosed(),
408422
indicesOptions))
409423
.source(sourceBuilder);
410-
Set<String> foundResourceIds = new LinkedHashSet<>();
411-
if (tags.isEmpty()) {
412-
foundResourceIds.addAll(matchedResourceIds(tokens));
413-
} else {
414-
for(String resourceId : matchedResourceIds(tokens)) {
415-
// Does the model as a resource have all the tags?
416-
if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) {
417-
foundResourceIds.add(resourceId);
418-
}
419-
}
420-
}
421424

422425
executeAsyncWithOrigin(client.threadPool().getThreadContext(),
423426
ML_ORIGIN,
424427
searchRequest,
425428
ActionListener.<SearchResponse>wrap(
426429
response -> {
427430
long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size();
431+
Set<String> foundFromDocs = new HashSet<>();
428432
for (SearchHit hit : response.getHits().getHits()) {
429433
Map<String, Object> docSource = hit.getSourceAsMap();
430434
if (docSource == null) {
431435
continue;
432436
}
433437
Object idValue = docSource.get(TrainedModelConfig.MODEL_ID.getPreferredName());
434438
if (idValue instanceof String) {
435-
foundResourceIds.add(idValue.toString());
439+
foundFromDocs.add(idValue.toString());
436440
}
437441
}
442+
Set<String> allFoundIds = collectIds(pageParams, foundResourceIds, foundFromDocs);
438443
ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources);
439-
requiredMatches.filterMatchedIds(foundResourceIds);
444+
requiredMatches.filterMatchedIds(allFoundIds);
440445
if (requiredMatches.hasUnmatchedIds()) {
441446
idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString()));
442447
} else {
443-
idsListener.onResponse(Tuple.tuple(totalHitCount, foundResourceIds));
448+
449+
idsListener.onResponse(Tuple.tuple(totalHitCount, allFoundIds));
444450
}
445451
},
446452
idsListener::onFailure
447453
),
448454
client::search);
449455
}
450456

457+
static Set<String> collectIds(PageParams pageParams, Set<String> foundFromResources, Set<String> foundFromDocs) {
458+
// If there are no matching resource models, there was no buffering and the models from the docs
459+
// are paginated correctly.
460+
if (foundFromResources.isEmpty()) {
461+
return foundFromDocs;
462+
}
463+
464+
TreeSet<String> allFoundIds = new TreeSet<>(foundFromDocs);
465+
allFoundIds.addAll(foundFromResources);
466+
467+
if (pageParams.getFrom() > 0) {
468+
// not the first page so there will be extra results at the front to remove
469+
int numToTrimFromFront = Math.min(foundFromResources.size(), pageParams.getFrom());
470+
for (int i = 0; i < numToTrimFromFront; i++) {
471+
allFoundIds.remove(allFoundIds.first());
472+
}
473+
}
474+
475+
// trim down to size removing from the rear
476+
while (allFoundIds.size() > pageParams.getSize()) {
477+
allFoundIds.remove(allFoundIds.last());
478+
}
479+
480+
return allFoundIds;
481+
}
482+
451483
static QueryBuilder buildExpandIdsQuery(String[] tokens, Collection<String> tags) {
452484
BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery()
453485
.filter(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName()));
@@ -518,7 +550,7 @@ private static QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String
518550

519551
private Set<String> matchedResourceIds(String[] tokens) {
520552
if (Strings.isAllOrWildcard(tokens)) {
521-
return new HashSet<>(MODELS_STORED_AS_RESOURCE);
553+
return MODELS_STORED_AS_RESOURCE;
522554
}
523555

524556
Set<String> matchedModels = new HashSet<>();
@@ -536,7 +568,7 @@ private Set<String> matchedResourceIds(String[] tokens) {
536568
}
537569
}
538570
}
539-
return matchedModels;
571+
return Collections.unmodifiableSet(matchedModels);
540572
}
541573

542574
private static <T> T handleSearchItem(MultiSearchResponse.Item item,

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,16 @@
1414
import org.elasticsearch.index.query.QueryBuilder;
1515
import org.elasticsearch.index.query.TermQueryBuilder;
1616
import org.elasticsearch.test.ESTestCase;
17+
import org.elasticsearch.xpack.core.action.util.PageParams;
1718
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
1819
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
1920
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
2021
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
2122

2223
import java.util.Arrays;
24+
import java.util.Collections;
25+
import java.util.HashSet;
26+
import java.util.TreeSet;
2327

2428
import static org.hamcrest.Matchers.equalTo;
2529
import static org.hamcrest.Matchers.instanceOf;
@@ -86,6 +90,50 @@ public void testExpandIdsQuery() {
8690
});
8791
}
8892

93+
public void testExpandIdsPagination() {
94+
// NOTE: these tests assume that the query pagination results are "buffered"
95+
96+
assertThat(TrainedModelProvider.collectIds(new PageParams(0, 3),
97+
Collections.emptySet(),
98+
new HashSet<>(Arrays.asList("a", "b", "c"))),
99+
equalTo(new TreeSet<>(Arrays.asList("a", "b", "c"))));
100+
101+
assertThat(TrainedModelProvider.collectIds(new PageParams(0, 3),
102+
Collections.singleton("a"),
103+
new HashSet<>(Arrays.asList("b", "c", "d"))),
104+
equalTo(new TreeSet<>(Arrays.asList("a", "b", "c"))));
105+
106+
assertThat(TrainedModelProvider.collectIds(new PageParams(1, 3),
107+
Collections.singleton("a"),
108+
new HashSet<>(Arrays.asList("b", "c", "d"))),
109+
equalTo(new TreeSet<>(Arrays.asList("b", "c", "d"))));
110+
111+
assertThat(TrainedModelProvider.collectIds(new PageParams(1, 1),
112+
Collections.singleton("c"),
113+
new HashSet<>(Arrays.asList("a", "b"))),
114+
equalTo(new TreeSet<>(Arrays.asList("b"))));
115+
116+
assertThat(TrainedModelProvider.collectIds(new PageParams(1, 1),
117+
Collections.singleton("b"),
118+
new HashSet<>(Arrays.asList("a", "c"))),
119+
equalTo(new TreeSet<>(Arrays.asList("b"))));
120+
121+
assertThat(TrainedModelProvider.collectIds(new PageParams(1, 2),
122+
new HashSet<>(Arrays.asList("a", "b")),
123+
new HashSet<>(Arrays.asList("c", "d", "e"))),
124+
equalTo(new TreeSet<>(Arrays.asList("b", "c"))));
125+
126+
assertThat(TrainedModelProvider.collectIds(new PageParams(1, 3),
127+
new HashSet<>(Arrays.asList("a", "b")),
128+
new HashSet<>(Arrays.asList("c", "d", "e"))),
129+
equalTo(new TreeSet<>(Arrays.asList("b", "c", "d"))));
130+
131+
assertThat(TrainedModelProvider.collectIds(new PageParams(2, 3),
132+
new HashSet<>(Arrays.asList("a", "b")),
133+
new HashSet<>(Arrays.asList("c", "d", "e"))),
134+
equalTo(new TreeSet<>(Arrays.asList("c", "d", "e"))));
135+
}
136+
89137
public void testGetModelThatExistsAsResourceButIsMissing() {
90138
TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry());
91139
ElasticsearchException ex = expectThrows(ElasticsearchException.class,

x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml

Lines changed: 129 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,56 @@ setup:
7272
}
7373
}
7474
}
75+
76+
- do:
77+
headers:
78+
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
79+
ml.put_trained_model:
80+
model_id: yyy-classification-model
81+
body: >
82+
{
83+
"description": "empty model for tests",
84+
"input": {"field_names": ["field1", "field2"]},
85+
"tags": ["classification", "tag3"],
86+
"definition": {
87+
"preprocessors": [],
88+
"trained_model": {
89+
"tree": {
90+
"feature_names": ["field1", "field2"],
91+
"tree_structure": [
92+
{"node_index": 0, "leaf_value": 1}
93+
],
94+
"target_type": "classification",
95+
"classification_labels": ["no", "yes"]
96+
}
97+
}
98+
}
99+
}
100+
101+
- do:
102+
headers:
103+
Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser
104+
ml.put_trained_model:
105+
model_id: zzz-classification-model
106+
body: >
107+
{
108+
"description": "empty model for tests",
109+
"input": {"field_names": ["field1", "field2"]},
110+
"tags": ["classification", "tag3"],
111+
"definition": {
112+
"preprocessors": [],
113+
"trained_model": {
114+
"tree": {
115+
"feature_names": ["field1", "field2"],
116+
"tree_structure": [
117+
{"node_index": 0, "leaf_value": 1}
118+
],
119+
"target_type": "classification",
120+
"classification_labels": ["no", "yes"]
121+
}
122+
}
123+
}
124+
}
75125
---
76126
"Test get given missing trained model":
77127

@@ -102,15 +152,20 @@ setup:
102152
- do:
103153
ml.get_trained_models:
104154
model_id: "*"
105-
- match: { count: 4 }
155+
- match: { count: 6 }
156+
- length: { trained_model_configs: 6 }
106157
- match: { trained_model_configs.0.model_id: "a-classification-model" }
107158
- match: { trained_model_configs.1.model_id: "a-regression-model-0" }
108159
- match: { trained_model_configs.2.model_id: "a-regression-model-1" }
160+
- match: { trained_model_configs.3.model_id: "lang_ident_model_1" }
161+
- match: { trained_model_configs.4.model_id: "yyy-classification-model" }
162+
- match: { trained_model_configs.5.model_id: "zzz-classification-model" }
109163

110164
- do:
111165
ml.get_trained_models:
112166
model_id: "a-regression*"
113167
- match: { count: 2 }
168+
- length: { trained_model_configs: 2 }
114169
- match: { trained_model_configs.0.model_id: "a-regression-model-0" }
115170
- match: { trained_model_configs.1.model_id: "a-regression-model-1" }
116171

@@ -119,7 +174,8 @@ setup:
119174
model_id: "*"
120175
from: 0
121176
size: 2
122-
- match: { count: 4 }
177+
- match: { count: 6 }
178+
- length: { trained_model_configs: 2 }
123179
- match: { trained_model_configs.0.model_id: "a-classification-model" }
124180
- match: { trained_model_configs.1.model_id: "a-regression-model-0" }
125181

@@ -128,8 +184,78 @@ setup:
128184
model_id: "*"
129185
from: 1
130186
size: 1
131-
- match: { count: 4 }
187+
- match: { count: 6 }
188+
- length: { trained_model_configs: 1 }
132189
- match: { trained_model_configs.0.model_id: "a-regression-model-0" }
190+
191+
- do:
192+
ml.get_trained_models:
193+
model_id: "*"
194+
from: 2
195+
size: 2
196+
- match: { count: 6 }
197+
- length: { trained_model_configs: 2 }
198+
- match: { trained_model_configs.0.model_id: "a-regression-model-1" }
199+
- match: { trained_model_configs.1.model_id: "lang_ident_model_1" }
200+
201+
- do:
202+
ml.get_trained_models:
203+
model_id: "*"
204+
from: 3
205+
size: 1
206+
- match: { count: 6 }
207+
- length: { trained_model_configs: 1 }
208+
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
209+
210+
- do:
211+
ml.get_trained_models:
212+
model_id: "*"
213+
from: 3
214+
size: 2
215+
- match: { count: 6 }
216+
- length: { trained_model_configs: 2 }
217+
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
218+
- match: { trained_model_configs.1.model_id: "yyy-classification-model" }
219+
220+
- do:
221+
ml.get_trained_models:
222+
model_id: "*"
223+
from: 4
224+
size: 2
225+
- match: { count: 6 }
226+
- length: { trained_model_configs: 2 }
227+
- match: { trained_model_configs.0.model_id: "yyy-classification-model" }
228+
- match: { trained_model_configs.1.model_id: "zzz-classification-model" }
229+
230+
- do:
231+
ml.get_trained_models:
232+
model_id: "a-*,lang*,zzz*"
233+
allow_no_match: true
234+
from: 3
235+
size: 1
236+
- match: { count: 5 }
237+
- length: { trained_model_configs: 1 }
238+
- match: { trained_model_configs.0.model_id: "lang_ident_model_1" }
239+
240+
- do:
241+
ml.get_trained_models:
242+
model_id: "a-*,lang*,zzz*"
243+
allow_no_match: true
244+
from: 4
245+
size: 1
246+
- match: { count: 5 }
247+
- length: { trained_model_configs: 1 }
248+
- match: { trained_model_configs.0.model_id: "zzz-classification-model" }
249+
250+
- do:
251+
ml.get_trained_models:
252+
model_id: "a-*,lang*,zzz*"
253+
from: 4
254+
size: 100
255+
- match: { count: 5 }
256+
- length: { trained_model_configs: 1 }
257+
- match: { trained_model_configs.0.model_id: "zzz-classification-model" }
258+
133259
---
134260
"Test get models with tags":
135261
- do:

0 commit comments

Comments
 (0)