 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchPhaseResult;
+import org.elasticsearch.search.SearchService;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.aggregations.InternalAggregation;
 import org.elasticsearch.search.aggregations.InternalAggregation.ReduceContext;
@@ -65,6 +66,7 @@
 import java.util.Map;
 import java.util.function.Function;
 import java.util.function.IntFunction;
+import java.util.stream.Collectors;
 
 public final class SearchPhaseController {
 
@@ -427,6 +429,15 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
             return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.getMaxScore(),
                 false, null, null, null, null, SortedTopDocs.EMPTY, null, numReducePhases, 0, 0, true);
         }
+        int total = queryResults.size();
+        queryResults = queryResults.stream()
+            .filter(res -> res.queryResult().isNull() == false)
+            .collect(Collectors.toList());
+        String errorMsg = "must have at least one non-empty search result, got 0 out of " + total;
+        assert queryResults.isEmpty() == false : errorMsg;
+        if (queryResults.isEmpty()) {
+            throw new IllegalStateException(errorMsg);
+        }
         final QuerySearchResult firstResult = queryResults.stream().findFirst().get().queryResult();
         final boolean hasSuggest = firstResult.suggest() != null;
         final boolean hasProfileResults = firstResult.hasProfileResults();
@@ -497,6 +508,18 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
             firstResult.sortValueFormats(), numReducePhases, size, from, false);
     }
 
+    /*
+     * Returns the size of the requested top documents (from + size)
+     */
+    static int getTopDocsSize(SearchRequest request) {
+        if (request.source() == null) {
+            return SearchService.DEFAULT_SIZE;
+        }
+        SearchSourceBuilder source = request.source();
+        return (source.size() == -1 ? SearchService.DEFAULT_SIZE : source.size()) +
+            (source.from() == -1 ? SearchService.DEFAULT_FROM : source.from());
+    }
+
     public static final class ReducedQueryPhase {
         // the sum of all hits across all reduces shards
         final TotalHits totalHits;
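For orientation: the new getTopDocsSize() helper resolves the requested top-N as from + size, falling back to the SearchService defaults when the request does not set them. Below is a minimal standalone sketch of the same arithmetic, assuming the usual defaults of from = 0 and size = 10; the class and method names are illustrative and not part of this commit:

    // Illustrative only: mirrors getTopDocsSize() with the defaults hard-coded
    // (SearchService.DEFAULT_FROM = 0 and SearchService.DEFAULT_SIZE = 10 are assumed here).
    final class TopDocsSizeSketch {
        // -1 means "not set", which is how SearchSourceBuilder reports an absent from/size
        static int resolve(int from, int size) {
            int resolvedSize = size == -1 ? 10 : size;
            int resolvedFrom = from == -1 ? 0 : from;
            return resolvedFrom + resolvedSize;
        }
    }

For example, resolve(20, 5) == 25 (an offset of 20 with a page size of 5 needs the top 25 documents), while resolve(-1, -1) == 10 (a request with no explicit from/size keeps the default 10).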
@@ -576,6 +599,7 @@ static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<Sear
         private final SearchProgressListener progressListener;
         private int numReducePhases = 0;
         private final TopDocsStats topDocsStats;
+        private final int topNSize;
         private final boolean performFinalReduce;
 
         /**
@@ -589,7 +613,7 @@ static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<Sear
          */
         private QueryPhaseResultConsumer(SearchProgressListener progressListener, SearchPhaseController controller,
                                          int expectedResultSize, int bufferSize, boolean hasTopDocs, boolean hasAggs,
-                                         int trackTotalHitsUpTo, boolean performFinalReduce) {
+                                         int trackTotalHitsUpTo, int topNSize, boolean performFinalReduce) {
             super(expectedResultSize);
             if (expectedResultSize != 1 && bufferSize < 2) {
                 throw new IllegalArgumentException("buffer size must be >= 2 if there is more than one expected result");
@@ -610,6 +634,7 @@ private QueryPhaseResultConsumer(SearchProgressListener progressListener, Search
             this.hasAggs = hasAggs;
             this.bufferSize = bufferSize;
             this.topDocsStats = new TopDocsStats(trackTotalHitsUpTo);
+            this.topNSize = topNSize;
             this.performFinalReduce = performFinalReduce;
         }
 
@@ -622,36 +647,38 @@ public void consumeResult(SearchPhaseResult result) {
         }
 
         private synchronized void consumeInternal(QuerySearchResult querySearchResult) {
-            if (index == bufferSize) {
+            if (querySearchResult.isNull() == false) {
+                if (index == bufferSize) {
+                    if (hasAggs) {
+                        ReduceContext reduceContext = controller.reduceContextFunction.apply(false);
+                        InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(Arrays.asList(aggsBuffer), reduceContext);
+                        Arrays.fill(aggsBuffer, null);
+                        aggsBuffer[0] = reducedAggs;
+                    }
+                    if (hasTopDocs) {
+                        TopDocs reducedTopDocs = mergeTopDocs(Arrays.asList(topDocsBuffer),
+                            // we have to merge here in the same way we collect on a shard
+                            topNSize, 0);
+                        Arrays.fill(topDocsBuffer, null);
+                        topDocsBuffer[0] = reducedTopDocs;
+                    }
+                    numReducePhases++;
+                    index = 1;
+                    if (hasAggs) {
+                        progressListener.notifyPartialReduce(progressListener.searchShards(processedShards),
+                            topDocsStats.getTotalHits(), aggsBuffer[0], numReducePhases);
+                    }
+                }
+                final int i = index++;
                 if (hasAggs) {
-                    ReduceContext reduceContext = controller.reduceContextFunction.apply(false);
-                    InternalAggregations reducedAggs = InternalAggregations.topLevelReduce(Arrays.asList(aggsBuffer), reduceContext);
-                    Arrays.fill(aggsBuffer, null);
-                    aggsBuffer[0] = reducedAggs;
+                    aggsBuffer[i] = (InternalAggregations) querySearchResult.consumeAggs();
                 }
                 if (hasTopDocs) {
-                    TopDocs reducedTopDocs = mergeTopDocs(Arrays.asList(topDocsBuffer),
-                        // we have to merge here in the same way we collect on a shard
-                        querySearchResult.from() + querySearchResult.size(), 0);
-                    Arrays.fill(topDocsBuffer, null);
-                    topDocsBuffer[0] = reducedTopDocs;
+                    final TopDocsAndMaxScore topDocs = querySearchResult.consumeTopDocs(); // can't be null
+                    topDocsStats.add(topDocs, querySearchResult.searchTimedOut(), querySearchResult.terminatedEarly());
+                    setShardIndex(topDocs.topDocs, querySearchResult.getShardIndex());
+                    topDocsBuffer[i] = topDocs.topDocs;
                 }
-                numReducePhases++;
-                index = 1;
-                if (hasAggs) {
-                    progressListener.notifyPartialReduce(progressListener.searchShards(processedShards),
-                        topDocsStats.getTotalHits(), aggsBuffer[0], numReducePhases);
-                }
-            }
-            final int i = index++;
-            if (hasAggs) {
-                aggsBuffer[i] = (InternalAggregations) querySearchResult.consumeAggs();
-            }
-            if (hasTopDocs) {
-                final TopDocsAndMaxScore topDocs = querySearchResult.consumeTopDocs(); // can't be null
-                topDocsStats.add(topDocs, querySearchResult.searchTimedOut(), querySearchResult.terminatedEarly());
-                setShardIndex(topDocs.topDocs, querySearchResult.getShardIndex());
-                topDocsBuffer[i] = topDocs.topDocs;
-            }
             }
             processedShards[querySearchResult.getShardIndex()] = querySearchResult.getSearchShardTarget();
         }
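The hunk above is the core of the change: shard results flagged isNull() are no longer buffered at all, and a full buffer is now merged down to topNSize (the from + size of the original request) rather than to a size read off the individual shard result. Below is a minimal, self-contained sketch of this bounded-buffer partial-reduce pattern, using plain int arrays in place of TopDocs and aggregations; every name in it is illustrative and not the Elasticsearch API:

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    // Sketch of the QueryPhaseResultConsumer buffering strategy: collect per-shard
    // results into a fixed-size buffer; once it is full, merge everything into slot 0,
    // keep only the top-N entries, and continue filling from index 1.
    final class BatchedTopNConsumer {
        private final int[][] buffer;
        private final int topNSize;
        private int index = 0;
        private int numReducePhases = 0;

        BatchedTopNConsumer(int bufferSize, int topNSize) {
            this.buffer = new int[bufferSize][];
            this.topNSize = topNSize;
        }

        // consume one shard's scores; empty (null) results are skipped up front,
        // analogous to the querySearchResult.isNull() check above
        synchronized void consume(int[] shardScores) {
            if (shardScores == null) {
                return;
            }
            if (index == buffer.length) {          // buffer full: partial reduce
                int[] reduced = mergeTopN(Arrays.asList(buffer), topNSize);
                Arrays.fill(buffer, null);
                buffer[0] = reduced;               // the merged result occupies slot 0
                index = 1;
                numReducePhases++;
            }
            buffer[index++] = shardScores;
        }

        // final reduce over whatever is still buffered
        synchronized int[] reduce() {
            return mergeTopN(Arrays.asList(buffer).subList(0, index), topNSize);
        }

        private static int[] mergeTopN(List<int[]> batches, int topN) {
            List<Integer> all = new ArrayList<>();
            for (int[] batch : batches) {
                if (batch != null) {
                    for (int score : batch) {
                        all.add(score);
                    }
                }
            }
            all.sort((a, b) -> Integer.compare(b, a));   // best score first
            int[] top = new int[Math.min(topN, all.size())];
            for (int i = 0; i < top.length; i++) {
                top[i] = all.get(i);
            }
            return top;
        }
    }

With bufferSize = 3 and topNSize = 2, consuming the shard results [5, 1], [9, 3], [7, 2] and [8] triggers one partial merge when the fourth result arrives, and reduce() returns [9, 8]; the buffer never holds more than three entries no matter how many shards respond, which is the point of the batched reduce.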
@@ -706,9 +733,10 @@ ArraySearchPhaseResults<SearchPhaseResult> newSearchPhaseResults(SearchProgressL
         if (isScrollRequest == false && (hasAggs || hasTopDocs)) {
             // no incremental reduce if scroll is used - we only hit a single shard or sometimes more...
             if (request.getBatchedReduceSize() < numShards) {
+                int topNSize = getTopDocsSize(request);
                 // only use this if there are aggs and if there are more shards than we should reduce at once
                 return new QueryPhaseResultConsumer(listener, this, numShards, request.getBatchedReduceSize(), hasTopDocs, hasAggs,
-                    trackTotalHitsUpTo, request.isFinalReduce());
+                    trackTotalHitsUpTo, topNSize, request.isFinalReduce());
             }
         }
         return new ArraySearchPhaseResults<SearchPhaseResult>(numShards) {
@@ -731,7 +759,7 @@ ReducedQueryPhase reduce() {
 
     static final class TopDocsStats {
         final int trackTotalHitsUpTo;
-        private long totalHits;
+        long totalHits;
         private TotalHits.Relation totalHitsRelation;
         long fetchHits;
         private float maxScore = Float.NEGATIVE_INFINITY;