2828import org .elasticsearch .action .support .WriteRequest ;
2929import org .elasticsearch .client .Client ;
3030import org .elasticsearch .common .CheckedBiFunction ;
31- import org .elasticsearch .common .Nullable ;
3231import org .elasticsearch .common .Strings ;
3332import org .elasticsearch .common .bytes .BytesReference ;
3433import org .elasticsearch .common .collect .Tuple ;
7473import java .util .Collections ;
7574import java .util .Comparator ;
7675import java .util .HashSet ;
77- import java .util .LinkedHashSet ;
7876import java .util .List ;
7977import java .util .Map ;
8078import java .util .Set ;
79+ import java .util .TreeSet ;
8180
8281import static org .elasticsearch .xpack .core .ClientHelper .ML_ORIGIN ;
8382import static org .elasticsearch .xpack .core .ClientHelper .executeAsyncWithOrigin ;
@@ -382,19 +381,34 @@ public void deleteTrainedModel(String modelId, ActionListener<Boolean> listener)
382381
383382 public void expandIds (String idExpression ,
384383 boolean allowNoResources ,
385- @ Nullable PageParams pageParams ,
384+ PageParams pageParams ,
386385 Set <String > tags ,
387386 ActionListener <Tuple <Long , Set <String >>> idsListener ) {
388387 String [] tokens = Strings .tokenizeToStringArray (idExpression , "," );
388+ Set <String > matchedResourceIds = matchedResourceIds (tokens );
389+ Set <String > foundResourceIds ;
390+ if (tags .isEmpty ()) {
391+ foundResourceIds = matchedResourceIds ;
392+ } else {
393+ foundResourceIds = new HashSet <>();
394+ for (String resourceId : matchedResourceIds ) {
395+ // Does the model as a resource have all the tags?
396+ if (Sets .newHashSet (loadModelFromResource (resourceId , true ).getTags ()).containsAll (tags )) {
397+ foundResourceIds .add (resourceId );
398+ }
399+ }
400+ }
389401 SearchSourceBuilder sourceBuilder = new SearchSourceBuilder ()
390402 .sort (SortBuilders .fieldSort (TrainedModelConfig .MODEL_ID .getPreferredName ())
391403 // If there are no resources, there might be no mapping for the id field.
392404 // This makes sure we don't get an error if that happens.
393405 .unmappedType ("long" ))
394- .query (buildExpandIdsQuery (tokens , tags ));
395- if (pageParams != null ) {
396- sourceBuilder .from (pageParams .getFrom ()).size (pageParams .getSize ());
397- }
406+ .query (buildExpandIdsQuery (tokens , tags ))
407+ // We "buffer" the from and size to take into account models stored as resources.
408+ // This is so we handle the edge cases when the model that is stored as a resource is at the start/end of
409+ // a page.
410+ .from (Math .max (0 , pageParams .getFrom () - foundResourceIds .size ()))
411+ .size (Math .min (10_000 , pageParams .getSize () + foundResourceIds .size ()));
398412 sourceBuilder .trackTotalHits (true )
399413 // we only care about the item id's
400414 .fetchSource (TrainedModelConfig .MODEL_ID .getPreferredName (), null );
@@ -407,47 +421,65 @@ public void expandIds(String idExpression,
407421 indicesOptions .expandWildcardsClosed (),
408422 indicesOptions ))
409423 .source (sourceBuilder );
410- Set <String > foundResourceIds = new LinkedHashSet <>();
411- if (tags .isEmpty ()) {
412- foundResourceIds .addAll (matchedResourceIds (tokens ));
413- } else {
414- for (String resourceId : matchedResourceIds (tokens )) {
415- // Does the model as a resource have all the tags?
416- if (Sets .newHashSet (loadModelFromResource (resourceId , true ).getTags ()).containsAll (tags )) {
417- foundResourceIds .add (resourceId );
418- }
419- }
420- }
421424
422425 executeAsyncWithOrigin (client .threadPool ().getThreadContext (),
423426 ML_ORIGIN ,
424427 searchRequest ,
425428 ActionListener .<SearchResponse >wrap (
426429 response -> {
427430 long totalHitCount = response .getHits ().getTotalHits ().value + foundResourceIds .size ();
431+ Set <String > foundFromDocs = new HashSet <>();
428432 for (SearchHit hit : response .getHits ().getHits ()) {
429433 Map <String , Object > docSource = hit .getSourceAsMap ();
430434 if (docSource == null ) {
431435 continue ;
432436 }
433437 Object idValue = docSource .get (TrainedModelConfig .MODEL_ID .getPreferredName ());
434438 if (idValue instanceof String ) {
435- foundResourceIds .add (idValue .toString ());
439+ foundFromDocs .add (idValue .toString ());
436440 }
437441 }
442+ Set <String > allFoundIds = collectIds (pageParams , foundResourceIds , foundFromDocs );
438443 ExpandedIdsMatcher requiredMatches = new ExpandedIdsMatcher (tokens , allowNoResources );
439- requiredMatches .filterMatchedIds (foundResourceIds );
444+ requiredMatches .filterMatchedIds (allFoundIds );
440445 if (requiredMatches .hasUnmatchedIds ()) {
441446 idsListener .onFailure (ExceptionsHelper .missingTrainedModel (requiredMatches .unmatchedIdsString ()));
442447 } else {
443- idsListener .onResponse (Tuple .tuple (totalHitCount , foundResourceIds ));
448+
449+ idsListener .onResponse (Tuple .tuple (totalHitCount , allFoundIds ));
444450 }
445451 },
446452 idsListener ::onFailure
447453 ),
448454 client ::search );
449455 }
450456
457+ static Set <String > collectIds (PageParams pageParams , Set <String > foundFromResources , Set <String > foundFromDocs ) {
458+ // If there are no matching resource models, there was no buffering and the models from the docs
459+ // are paginated correctly.
460+ if (foundFromResources .isEmpty ()) {
461+ return foundFromDocs ;
462+ }
463+
464+ TreeSet <String > allFoundIds = new TreeSet <>(foundFromDocs );
465+ allFoundIds .addAll (foundFromResources );
466+
467+ if (pageParams .getFrom () > 0 ) {
468+ // not the first page so there will be extra results at the front to remove
469+ int numToTrimFromFront = Math .min (foundFromResources .size (), pageParams .getFrom ());
470+ for (int i = 0 ; i < numToTrimFromFront ; i ++) {
471+ allFoundIds .remove (allFoundIds .first ());
472+ }
473+ }
474+
475+ // trim down to size removing from the rear
476+ while (allFoundIds .size () > pageParams .getSize ()) {
477+ allFoundIds .remove (allFoundIds .last ());
478+ }
479+
480+ return allFoundIds ;
481+ }
482+
451483 static QueryBuilder buildExpandIdsQuery (String [] tokens , Collection <String > tags ) {
452484 BoolQueryBuilder boolQueryBuilder = QueryBuilders .boolQuery ()
453485 .filter (buildQueryIdExpressionQuery (tokens , TrainedModelConfig .MODEL_ID .getPreferredName ()));
@@ -518,7 +550,7 @@ private static QueryBuilder buildQueryIdExpressionQuery(String[] tokens, String
518550
519551 private Set <String > matchedResourceIds (String [] tokens ) {
520552 if (Strings .isAllOrWildcard (tokens )) {
521- return new HashSet <>( MODELS_STORED_AS_RESOURCE ) ;
553+ return MODELS_STORED_AS_RESOURCE ;
522554 }
523555
524556 Set <String > matchedModels = new HashSet <>();
@@ -536,7 +568,7 @@ private Set<String> matchedResourceIds(String[] tokens) {
536568 }
537569 }
538570 }
539- return matchedModels ;
571+ return Collections . unmodifiableSet ( matchedModels ) ;
540572 }
541573
542574 private static <T > T handleSearchItem (MultiSearchResponse .Item item ,
0 commit comments