77
88import org .apache .logging .log4j .LogManager ;
99import org .apache .logging .log4j .Logger ;
10+ import org .apache .logging .log4j .message .ParameterizedMessage ;
1011import org .elasticsearch .action .search .ClearScrollAction ;
1112import org .elasticsearch .action .search .ClearScrollRequest ;
1213import org .elasticsearch .action .search .SearchAction ;
2021import org .elasticsearch .search .SearchHit ;
2122import org .elasticsearch .search .sort .SortOrder ;
2223import org .elasticsearch .xpack .core .ClientHelper ;
23- import org .elasticsearch .xpack .core .ml .datafeed .extractor .ExtractorUtils ;
2424import org .elasticsearch .xpack .ml .datafeed .extractor .fields .ExtractedField ;
2525import org .elasticsearch .xpack .ml .dataframe .DataFrameAnalyticsFields ;
2626
3434import java .util .Objects ;
3535import java .util .Optional ;
3636import java .util .concurrent .TimeUnit ;
37+ import java .util .function .Supplier ;
3738import java .util .stream .Collectors ;
3839
3940/**
@@ -91,9 +92,28 @@ public Optional<List<Row>> next() throws IOException {
9192
9293 protected List <Row > initScroll () throws IOException {
9394 LOGGER .debug ("[{}] Initializing scroll" , context .jobId );
94- SearchResponse searchResponse = executeSearchRequest (buildSearchRequest ());
95- LOGGER .debug ("[{}] Search response was obtained" , context .jobId );
96- return processSearchResponse (searchResponse );
95+ return tryRequestWithSearchResponse (() -> executeSearchRequest (buildSearchRequest ()));
96+ }
97+
98+ private List <Row > tryRequestWithSearchResponse (Supplier <SearchResponse > request ) throws IOException {
99+ try {
100+ // We've set allow_partial_search_results to false which means if something
101+ // goes wrong the request will throw.
102+ SearchResponse searchResponse = request .get ();
103+ LOGGER .debug ("[{}] Search response was obtained" , context .jobId );
104+
105+ // Request was successful so we can restore the flag to retry if a future failure occurs
106+ searchHasShardFailure = false ;
107+
108+ return processSearchResponse (searchResponse );
109+ } catch (Exception e ) {
110+ if (searchHasShardFailure ) {
111+ throw e ;
112+ }
113+ LOGGER .warn (new ParameterizedMessage ("[{}] Search resulted to failure; retrying once" , context .jobId ), e );
114+ markScrollAsErrored ();
115+ return initScroll ();
116+ }
97117 }
98118
99119 protected SearchResponse executeSearchRequest (SearchRequestBuilder searchRequestBuilder ) {
@@ -103,6 +123,8 @@ protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequest
103123 private SearchRequestBuilder buildSearchRequest () {
104124 SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder (client , SearchAction .INSTANCE )
105125 .setScroll (SCROLL_TIMEOUT )
126+ // This ensures the search throws if there are failures and the scroll context gets cleared automatically
127+ .setAllowPartialSearchResults (false )
106128 .addSort (DataFrameAnalyticsFields .ID , SortOrder .ASC )
107129 .setIndices (context .indices )
108130 .setSize (context .scrollSize )
@@ -117,14 +139,6 @@ private SearchRequestBuilder buildSearchRequest() {
117139 }
118140
119141 private List <Row > processSearchResponse (SearchResponse searchResponse ) throws IOException {
120-
121- if (searchResponse .getFailedShards () > 0 && searchHasShardFailure == false ) {
122- LOGGER .debug ("[{}] Resetting scroll search after shard failure" , context .jobId );
123- markScrollAsErrored ();
124- return initScroll ();
125- }
126-
127- ExtractorUtils .checkSearchWasSuccessful (context .jobId , searchResponse );
128142 scrollId = searchResponse .getScrollId ();
129143 if (searchResponse .getHits ().getHits ().length == 0 ) {
130144 hasNext = false ;
@@ -143,7 +157,6 @@ private List<Row> processSearchResponse(SearchResponse searchResponse) throws IO
143157 rows .add (createRow (hit ));
144158 }
145159 return rows ;
146-
147160 }
148161
149162 private Row createRow (SearchHit hit ) {
@@ -163,15 +176,13 @@ private Row createRow(SearchHit hit) {
163176
164177 private List <Row > continueScroll () throws IOException {
165178 LOGGER .debug ("[{}] Continuing scroll with id [{}]" , context .jobId , scrollId );
166- SearchResponse searchResponse = executeSearchScrollRequest (scrollId );
167- LOGGER .debug ("[{}] Search response was obtained" , context .jobId );
168- return processSearchResponse (searchResponse );
179+ return tryRequestWithSearchResponse (() -> executeSearchScrollRequest (scrollId ));
169180 }
170181
171182 private void markScrollAsErrored () {
172183 // This could be a transient error with the scroll Id.
173184 // Reinitialise the scroll and try again but only once.
174- resetScroll () ;
185+ scrollId = null ;
175186 searchHasShardFailure = true ;
176187 }
177188
@@ -183,11 +194,6 @@ protected SearchResponse executeSearchScrollRequest(String scrollId) {
183194 .get ());
184195 }
185196
186- private void resetScroll () {
187- clearScroll (scrollId );
188- scrollId = null ;
189- }
190-
191197 private void clearScroll (String scrollId ) {
192198 if (scrollId != null ) {
193199 ClearScrollRequest request = new ClearScrollRequest ();
0 commit comments