From b8c0d2e9d19621c48daadbba1f5f8015b686dca8 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 28 Jan 2020 16:30:53 -0500 Subject: [PATCH 1/4] [ML][Inference] Fix model pagination with models as resources --- .../persistence/TrainedModelProvider.java | 80 ++++++++---- .../TrainedModelProviderTests.java | 61 +++++++++ .../rest-api-spec/test/ml/inference_crud.yml | 118 +++++++++++++++++- 3 files changed, 235 insertions(+), 24 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index a549e60e0bd4f..956727cce39d9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -28,7 +28,6 @@ import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.Client; import org.elasticsearch.common.CheckedBiFunction; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; @@ -73,10 +72,10 @@ import java.util.Collections; import java.util.Comparator; import java.util.HashSet; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.TreeSet; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -381,19 +380,32 @@ public void deleteTrainedModel(String modelId, ActionListener listener) public void expandIds(String idExpression, boolean allowNoResources, - @Nullable PageParams pageParams, + PageParams pageParams, Set tags, ActionListener>> idsListener) { String[] tokens = Strings.tokenizeToStringArray(idExpression, ","); + Set foundResourceIds = new HashSet<>(); + if (tags.isEmpty()) { + foundResourceIds.addAll(matchedResourceIds(tokens)); + } else { + for(String resourceId : matchedResourceIds(tokens)) { + // Does the model as a resource have all the tags? + if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) { + foundResourceIds.add(resourceId); + } + } + } SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() .sort(SortBuilders.fieldSort(TrainedModelConfig.MODEL_ID.getPreferredName()) // If there are no resources, there might be no mapping for the id field. // This makes sure we don't get an error if that happens. .unmappedType("long")) - .query(buildExpandIdsQuery(tokens, tags)); - if (pageParams != null) { - sourceBuilder.from(pageParams.getFrom()).size(pageParams.getSize()); - } + .query(buildExpandIdsQuery(tokens, tags)) + // We "buffer" the from and size to take into account models stored as resources. + // This is so we handle the edge cases when the model that is stored as a resource is at the start/end of + // a page. + .from(Math.max(0, pageParams.getFrom() - foundResourceIds.size())) + .size(Math.min(10_000, pageParams.getSize() + foundResourceIds.size())); sourceBuilder.trackTotalHits(true) // we only care about the item id's .fetchSource(TrainedModelConfig.MODEL_ID.getPreferredName(), null); @@ -406,17 +418,6 @@ public void expandIds(String idExpression, indicesOptions.expandWildcardsClosed(), indicesOptions)) .source(sourceBuilder); - Set foundResourceIds = new LinkedHashSet<>(); - if (tags.isEmpty()) { - foundResourceIds.addAll(matchedResourceIds(tokens)); - } else { - for(String resourceId : matchedResourceIds(tokens)) { - // Does the model as a resource have all the tags? - if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) { - foundResourceIds.add(resourceId); - } - } - } executeAsyncWithOrigin(client.threadPool().getThreadContext(), ML_ORIGIN, @@ -424,6 +425,7 @@ public void expandIds(String idExpression, ActionListener.wrap( response -> { long totalHitCount = response.getHits().getTotalHits().value + foundResourceIds.size(); + Set foundFromDocs = new HashSet<>(); for (SearchHit hit : response.getHits().getHits()) { Map docSource = hit.getSourceAsMap(); if (docSource == null) { @@ -431,15 +433,17 @@ public void expandIds(String idExpression, } Object idValue = docSource.get(TrainedModelConfig.MODEL_ID.getPreferredName()); if (idValue instanceof String) { - foundResourceIds.add(idValue.toString()); + foundFromDocs.add(idValue.toString()); } } + Set allFoundIds = collectIds(pageParams, foundResourceIds, foundFromDocs, totalHitCount); ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources); - requiredMatches.filterMatchedIds(foundResourceIds); + requiredMatches.filterMatchedIds(allFoundIds); if (requiredMatches.hasUnmatchedIds()) { idsListener.onFailure(ExceptionsHelper.missingTrainedModel(requiredMatches.unmatchedIdsString())); } else { - idsListener.onResponse(Tuple.tuple(totalHitCount, foundResourceIds)); + + idsListener.onResponse(Tuple.tuple(totalHitCount, allFoundIds)); } }, idsListener::onFailure @@ -447,6 +451,40 @@ public void expandIds(String idExpression, client::search); } + static Set collectIds(PageParams pageParams, Set foundFromResources, Set foundFromDocs, long totalMatchedIds) { + TreeSet allFoundIds = new TreeSet<>(foundFromDocs); + allFoundIds.addAll(foundFromResources); + int from = pageParams.getFrom(); + int bufferedFrom = Math.min(foundFromResources.size(), from); + + // If size = 10_000 but there aren't that many total IDs, reduce the size here to make following logic simpler + int sizeLimit = (int)Math.min(pageParams.getSize(), totalMatchedIds - from); + + // Last page this means that if we "buffered" the from pagination due to resources we should clear that out + // We only clear from the front as that would include buffered IDs that fall on the previous page + if (from + sizeLimit >= totalMatchedIds) { + while (bufferedFrom > 0 || allFoundIds.size() > sizeLimit) { + allFoundIds.remove(allFoundIds.first()); + bufferedFrom--; + } + } + + // Systematically remove items while we are above the limit + while (allFoundIds.size() > sizeLimit) { + // If we are still over limit, and have buffered items, that means the first ids belong on the previous page + if (bufferedFrom > 0) { + allFoundIds.remove(allFoundIds.first()); + bufferedFrom--; + } + else { + // If we have removed all items belonging on the previous page, but are still over sized, this means we should + // remove items that belong on the next page. + allFoundIds.remove(allFoundIds.last()); + } + } + return allFoundIds; + } + static QueryBuilder buildExpandIdsQuery(String[] tokens, Collection tags) { BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery() .filter(buildQueryIdExpressionQuery(tokens, TrainedModelConfig.MODEL_ID.getPreferredName())); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java index 1f90313899dc3..eca771838d514 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java @@ -14,12 +14,16 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.action.util.PageParams; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.TreeSet; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -86,6 +90,63 @@ public void testExpandIdsQuery() { }); } + public void testExpandIdsPagination() { + //NOTE: these test assume that the query pagination results are "buffered" + + assertThat(TrainedModelProvider.collectIds(new PageParams(0, 3), + Collections.emptySet(), + new HashSet<>(Arrays.asList("a", "b", "c")), + 5), + equalTo(new TreeSet<>(Arrays.asList("a", "b", "c")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(0, 3), + Collections.singleton("a"), + new HashSet<>(Arrays.asList("b", "c", "d")), + 5), + equalTo(new TreeSet<>(Arrays.asList("a", "b", "c")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(1, 3), + Collections.singleton("a"), + new HashSet<>(Arrays.asList("b", "c", "d")), + 5), + equalTo(new TreeSet<>(Arrays.asList("b", "c", "d")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(1, 1), + Collections.singleton("c"), + new HashSet<>(Arrays.asList("a", "b")), 5), + equalTo(new TreeSet<>(Arrays.asList("b")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(1, 1), + Collections.singleton("b"), + new HashSet<>(Arrays.asList("a", "c")), 5), + equalTo(new TreeSet<>(Arrays.asList("b")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(4, 1), + Collections.singleton("d"), + new HashSet<>(Arrays.asList("c", "e")), 5), + equalTo(new TreeSet<>(Arrays.asList("e")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(4, 100), + Collections.singleton("d"), + new HashSet<>(Arrays.asList("c", "e")), 5), + equalTo(new TreeSet<>(Arrays.asList("e")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(1, 2), + new HashSet<>(Arrays.asList("a", "b")), + new HashSet<>(Arrays.asList("c", "d", "e")), 5), + equalTo(new TreeSet<>(Arrays.asList("b", "c")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(1, 3), + new HashSet<>(Arrays.asList("a", "b")), + new HashSet<>(Arrays.asList("c", "d", "e")), 5), + equalTo(new TreeSet<>(Arrays.asList("b", "c", "d")))); + + assertThat(TrainedModelProvider.collectIds(new PageParams(2, 3), + new HashSet<>(Arrays.asList("a", "b")), + new HashSet<>(Arrays.asList("c", "d", "e")), 5), + equalTo(new TreeSet<>(Arrays.asList("c", "d", "e")))); + } + public void testGetModelThatExistsAsResourceButIsMissing() { TrainedModelProvider trainedModelProvider = new TrainedModelProvider(mock(Client.class), xContentRegistry()); ElasticsearchException ex = expectThrows(ElasticsearchException.class, diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index d7cbc9825b7ad..01846afdb20f5 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -72,6 +72,56 @@ setup: } } } + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ml.put_trained_model: + model_id: yyy-classification-model + body: > + { + "description": "empty model for tests", + "input": {"field_names": ["field1", "field2"]}, + "tags": ["classification", "tag3"], + "definition": { + "preprocessors": [], + "trained_model": { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "leaf_value": 1} + ], + "target_type": "classification", + "classification_labels": ["no", "yes"] + } + } + } + } + + - do: + headers: + Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser + ml.put_trained_model: + model_id: zzz-classification-model + body: > + { + "description": "empty model for tests", + "input": {"field_names": ["field1", "field2"]}, + "tags": ["classification", "tag3"], + "definition": { + "preprocessors": [], + "trained_model": { + "tree": { + "feature_names": ["field1", "field2"], + "tree_structure": [ + {"node_index": 0, "leaf_value": 1} + ], + "target_type": "classification", + "classification_labels": ["no", "yes"] + } + } + } + } --- "Test get given missing trained model": @@ -102,10 +152,13 @@ setup: - do: ml.get_trained_models: model_id: "*" - - match: { count: 4 } + - match: { count: 6 } - match: { trained_model_configs.0.model_id: "a-classification-model" } - match: { trained_model_configs.1.model_id: "a-regression-model-0" } - match: { trained_model_configs.2.model_id: "a-regression-model-1" } + - match: { trained_model_configs.3.model_id: "lang_ident_model_1" } + - match: { trained_model_configs.4.model_id: "yyy-classification-model" } + - match: { trained_model_configs.5.model_id: "zzz-classification-model" } - do: ml.get_trained_models: @@ -119,7 +172,7 @@ setup: model_id: "*" from: 0 size: 2 - - match: { count: 4 } + - match: { count: 6 } - match: { trained_model_configs.0.model_id: "a-classification-model" } - match: { trained_model_configs.1.model_id: "a-regression-model-0" } @@ -128,8 +181,67 @@ setup: model_id: "*" from: 1 size: 1 - - match: { count: 4 } + - match: { count: 6 } + - length: { trained_model_configs: 1 } - match: { trained_model_configs.0.model_id: "a-regression-model-0" } + + - do: + ml.get_trained_models: + model_id: "*" + from: 2 + size: 2 + - match: { count: 6 } + - length: { trained_model_configs: 2 } + - match: { trained_model_configs.0.model_id: "a-regression-model-1" } + - match: { trained_model_configs.1.model_id: "lang_ident_model_1" } + + - do: + ml.get_trained_models: + model_id: "*" + from: 3 + size: 1 + - match: { count: 6 } + - length: { trained_model_configs: 1 } + - match: { trained_model_configs.0.model_id: "lang_ident_model_1" } + + - do: + ml.get_trained_models: + model_id: "*" + from: 3 + size: 2 + - match: { count: 6 } + - length: { trained_model_configs: 2 } + - match: { trained_model_configs.0.model_id: "lang_ident_model_1" } + - match: { trained_model_configs.1.model_id: "yyy-classification-model" } + + - do: + ml.get_trained_models: + model_id: "*" + from: 4 + size: 2 + - match: { count: 6 } + - length: { trained_model_configs: 2 } + - match: { trained_model_configs.0.model_id: "yyy-classification-model" } + - match: { trained_model_configs.1.model_id: "zzz-classification-model" } + + - do: + ml.get_trained_models: + model_id: "*" + from: 5 + size: 1 + - match: { count: 6 } + - length: { trained_model_configs: 1 } + - match: { trained_model_configs.0.model_id: "zzz-classification-model" } + + - do: + ml.get_trained_models: + model_id: "*" + from: 5 + size: 2 + - match: { count: 6 } + - length: { trained_model_configs: 1 } + - match: { trained_model_configs.0.model_id: "zzz-classification-model" } + --- "Test get models with tags": - do: From e7e8249ac9c6787ed8625501519fd627d5f0a619 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 29 Jan 2020 07:27:54 -0500 Subject: [PATCH 2/4] addressing pr comments --- .../xpack/ml/inference/persistence/TrainedModelProvider.java | 3 +-- .../ml/inference/persistence/TrainedModelProviderTests.java | 2 +- .../test/resources/rest-api-spec/test/ml/inference_crud.yml | 3 +++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 956727cce39d9..0140feb005f86 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -475,8 +475,7 @@ static Set collectIds(PageParams pageParams, Set foundFromResour if (bufferedFrom > 0) { allFoundIds.remove(allFoundIds.first()); bufferedFrom--; - } - else { + } else { // If we have removed all items belonging on the previous page, but are still over sized, this means we should // remove items that belong on the next page. allFoundIds.remove(allFoundIds.last()); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java index eca771838d514..282fed885f230 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java @@ -91,7 +91,7 @@ public void testExpandIdsQuery() { } public void testExpandIdsPagination() { - //NOTE: these test assume that the query pagination results are "buffered" + // NOTE: these tests assume that the query pagination results are "buffered" assertThat(TrainedModelProvider.collectIds(new PageParams(0, 3), Collections.emptySet(), diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index 01846afdb20f5..172cec2c47cf2 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -153,6 +153,7 @@ setup: ml.get_trained_models: model_id: "*" - match: { count: 6 } + - length: { trained_model_configs: 6 } - match: { trained_model_configs.0.model_id: "a-classification-model" } - match: { trained_model_configs.1.model_id: "a-regression-model-0" } - match: { trained_model_configs.2.model_id: "a-regression-model-1" } @@ -164,6 +165,7 @@ setup: ml.get_trained_models: model_id: "a-regression*" - match: { count: 2 } + - length: { trained_model_configs: 2 } - match: { trained_model_configs.0.model_id: "a-regression-model-0" } - match: { trained_model_configs.1.model_id: "a-regression-model-1" } @@ -173,6 +175,7 @@ setup: from: 0 size: 2 - match: { count: 6 } + - length: { trained_model_configs: 2 } - match: { trained_model_configs.0.model_id: "a-classification-model" } - match: { trained_model_configs.1.model_id: "a-regression-model-0" } From 5a7052f66fe30947a49d67e5e4e043f5e9cf8247 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 29 Jan 2020 14:37:58 -0500 Subject: [PATCH 3/4] addressing PR comments --- .../persistence/TrainedModelProvider.java | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 0140feb005f86..c411074c28c09 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -384,11 +384,13 @@ public void expandIds(String idExpression, Set tags, ActionListener>> idsListener) { String[] tokens = Strings.tokenizeToStringArray(idExpression, ","); - Set foundResourceIds = new HashSet<>(); + Set matchedResourceIds = matchedResourceIds(tokens); + Set foundResourceIds; if (tags.isEmpty()) { - foundResourceIds.addAll(matchedResourceIds(tokens)); + foundResourceIds = matchedResourceIds; } else { - for(String resourceId : matchedResourceIds(tokens)) { + foundResourceIds = new HashSet<>(); + for(String resourceId : matchedResourceIds) { // Does the model as a resource have all the tags? if (Sets.newHashSet(loadModelFromResource(resourceId, true).getTags()).containsAll(tags)) { foundResourceIds.add(resourceId); @@ -452,6 +454,12 @@ public void expandIds(String idExpression, } static Set collectIds(PageParams pageParams, Set foundFromResources, Set foundFromDocs, long totalMatchedIds) { + // If there are no matching resource models, there was no buffering and the models from the docs + // are paginated correctly. + if (foundFromResources.isEmpty()) { + return foundFromDocs; + } + TreeSet allFoundIds = new TreeSet<>(foundFromDocs); allFoundIds.addAll(foundFromResources); int from = pageParams.getFrom(); @@ -554,7 +562,7 @@ private static QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String private Set matchedResourceIds(String[] tokens) { if (Strings.isAllOrWildcard(tokens)) { - return new HashSet<>(MODELS_STORED_AS_RESOURCE); + return MODELS_STORED_AS_RESOURCE; } Set matchedModels = new HashSet<>(); @@ -572,7 +580,7 @@ private Set matchedResourceIds(String[] tokens) { } } } - return matchedModels; + return Collections.unmodifiableSet(matchedModels); } private static T handleSearchItem(MultiSearchResponse.Item item, From 87f28c9053b4d0e0d1f84bb67a125880f488983d Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Thu, 30 Jan 2020 11:12:22 -0500 Subject: [PATCH 4/4] Simplifying paging logic --- .../persistence/TrainedModelProvider.java | 33 ++++++------------- .../TrainedModelProviderTests.java | 29 +++++----------- .../rest-api-spec/test/ml/inference_crud.yml | 25 ++++++++++---- 3 files changed, 36 insertions(+), 51 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index c411074c28c09..ad1225e516844 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -438,7 +438,7 @@ public void expandIds(String idExpression, foundFromDocs.add(idValue.toString()); } } - Set allFoundIds = collectIds(pageParams, foundResourceIds, foundFromDocs, totalHitCount); + Set allFoundIds = collectIds(pageParams, foundResourceIds, foundFromDocs); ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher(tokens, allowNoResources); requiredMatches.filterMatchedIds(allFoundIds); if (requiredMatches.hasUnmatchedIds()) { @@ -453,7 +453,7 @@ public void expandIds(String idExpression, client::search); } - static Set collectIds(PageParams pageParams, Set foundFromResources, Set foundFromDocs, long totalMatchedIds) { + static Set collectIds(PageParams pageParams, Set foundFromResources, Set foundFromDocs) { // If there are no matching resource models, there was no buffering and the models from the docs // are paginated correctly. if (foundFromResources.isEmpty()) { @@ -462,33 +462,20 @@ static Set collectIds(PageParams pageParams, Set foundFromResour TreeSet allFoundIds = new TreeSet<>(foundFromDocs); allFoundIds.addAll(foundFromResources); - int from = pageParams.getFrom(); - int bufferedFrom = Math.min(foundFromResources.size(), from); - // If size = 10_000 but there aren't that many total IDs, reduce the size here to make following logic simpler - int sizeLimit = (int)Math.min(pageParams.getSize(), totalMatchedIds - from); - - // Last page this means that if we "buffered" the from pagination due to resources we should clear that out - // We only clear from the front as that would include buffered IDs that fall on the previous page - if (from + sizeLimit >= totalMatchedIds) { - while (bufferedFrom > 0 || allFoundIds.size() > sizeLimit) { + if (pageParams.getFrom() > 0) { + // not the first page so there will be extra results at the front to remove + int numToTrimFromFront = Math.min(foundFromResources.size(), pageParams.getFrom()); + for (int i = 0; i < numToTrimFromFront; i++) { allFoundIds.remove(allFoundIds.first()); - bufferedFrom--; } } - // Systematically remove items while we are above the limit - while (allFoundIds.size() > sizeLimit) { - // If we are still over limit, and have buffered items, that means the first ids belong on the previous page - if (bufferedFrom > 0) { - allFoundIds.remove(allFoundIds.first()); - bufferedFrom--; - } else { - // If we have removed all items belonging on the previous page, but are still over sized, this means we should - // remove items that belong on the next page. - allFoundIds.remove(allFoundIds.last()); - } + // trim down to size removing from the rear + while (allFoundIds.size() > pageParams.getSize()) { + allFoundIds.remove(allFoundIds.last()); } + return allFoundIds; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java index 282fed885f230..aee4c43f22769 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java @@ -95,55 +95,42 @@ public void testExpandIdsPagination() { assertThat(TrainedModelProvider.collectIds(new PageParams(0, 3), Collections.emptySet(), - new HashSet<>(Arrays.asList("a", "b", "c")), - 5), + new HashSet<>(Arrays.asList("a", "b", "c"))), equalTo(new TreeSet<>(Arrays.asList("a", "b", "c")))); assertThat(TrainedModelProvider.collectIds(new PageParams(0, 3), Collections.singleton("a"), - new HashSet<>(Arrays.asList("b", "c", "d")), - 5), + new HashSet<>(Arrays.asList("b", "c", "d"))), equalTo(new TreeSet<>(Arrays.asList("a", "b", "c")))); assertThat(TrainedModelProvider.collectIds(new PageParams(1, 3), Collections.singleton("a"), - new HashSet<>(Arrays.asList("b", "c", "d")), - 5), + new HashSet<>(Arrays.asList("b", "c", "d"))), equalTo(new TreeSet<>(Arrays.asList("b", "c", "d")))); assertThat(TrainedModelProvider.collectIds(new PageParams(1, 1), Collections.singleton("c"), - new HashSet<>(Arrays.asList("a", "b")), 5), + new HashSet<>(Arrays.asList("a", "b"))), equalTo(new TreeSet<>(Arrays.asList("b")))); assertThat(TrainedModelProvider.collectIds(new PageParams(1, 1), Collections.singleton("b"), - new HashSet<>(Arrays.asList("a", "c")), 5), + new HashSet<>(Arrays.asList("a", "c"))), equalTo(new TreeSet<>(Arrays.asList("b")))); - assertThat(TrainedModelProvider.collectIds(new PageParams(4, 1), - Collections.singleton("d"), - new HashSet<>(Arrays.asList("c", "e")), 5), - equalTo(new TreeSet<>(Arrays.asList("e")))); - - assertThat(TrainedModelProvider.collectIds(new PageParams(4, 100), - Collections.singleton("d"), - new HashSet<>(Arrays.asList("c", "e")), 5), - equalTo(new TreeSet<>(Arrays.asList("e")))); - assertThat(TrainedModelProvider.collectIds(new PageParams(1, 2), new HashSet<>(Arrays.asList("a", "b")), - new HashSet<>(Arrays.asList("c", "d", "e")), 5), + new HashSet<>(Arrays.asList("c", "d", "e"))), equalTo(new TreeSet<>(Arrays.asList("b", "c")))); assertThat(TrainedModelProvider.collectIds(new PageParams(1, 3), new HashSet<>(Arrays.asList("a", "b")), - new HashSet<>(Arrays.asList("c", "d", "e")), 5), + new HashSet<>(Arrays.asList("c", "d", "e"))), equalTo(new TreeSet<>(Arrays.asList("b", "c", "d")))); assertThat(TrainedModelProvider.collectIds(new PageParams(2, 3), new HashSet<>(Arrays.asList("a", "b")), - new HashSet<>(Arrays.asList("c", "d", "e")), 5), + new HashSet<>(Arrays.asList("c", "d", "e"))), equalTo(new TreeSet<>(Arrays.asList("c", "d", "e")))); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml index 172cec2c47cf2..0c9fbb350bb0f 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_crud.yml @@ -229,19 +229,30 @@ setup: - do: ml.get_trained_models: - model_id: "*" - from: 5 + model_id: "a-*,lang*,zzz*" + allow_no_match: true + from: 3 size: 1 - - match: { count: 6 } + - match: { count: 5 } + - length: { trained_model_configs: 1 } + - match: { trained_model_configs.0.model_id: "lang_ident_model_1" } + + - do: + ml.get_trained_models: + model_id: "a-*,lang*,zzz*" + allow_no_match: true + from: 4 + size: 1 + - match: { count: 5 } - length: { trained_model_configs: 1 } - match: { trained_model_configs.0.model_id: "zzz-classification-model" } - do: ml.get_trained_models: - model_id: "*" - from: 5 - size: 2 - - match: { count: 6 } + model_id: "a-*,lang*,zzz*" + from: 4 + size: 100 + - match: { count: 5 } - length: { trained_model_configs: 1 } - match: { trained_model_configs.0.model_id: "zzz-classification-model" }