Skip to content

Commit 521734d

Browse files
[FEATURE][ML] Ensure data extractor is not leaking scroll contexts (#42960)
1 parent 16a26a5 commit 521734d

File tree

3 files changed

+373
-31
lines changed

3 files changed

+373
-31
lines changed

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@
164164
import static org.hamcrest.CoreMatchers.hasItem;
165165
import static org.hamcrest.CoreMatchers.hasItems;
166166
import static org.hamcrest.CoreMatchers.not;
167-
import static org.hamcrest.Matchers.anyOf;
168167
import static org.hamcrest.Matchers.closeTo;
169168
import static org.hamcrest.Matchers.contains;
170169
import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -1365,7 +1364,8 @@ public void testStartDataFrameAnalyticsConfig() throws Exception {
13651364
String sourceIndex = "start-test-source-index";
13661365
String destIndex = "start-test-dest-index";
13671366
createIndex(sourceIndex, defaultMappingForTest());
1368-
highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT);
1367+
highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000)
1368+
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT);
13691369

13701370
// Verify that the destination index does not exist. Otherwise, analytics' reindexing step would fail.
13711371
assertFalse(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT));
@@ -1391,12 +1391,6 @@ public void testStartDataFrameAnalyticsConfig() throws Exception {
13911391
new StartDataFrameAnalyticsRequest(configId),
13921392
machineLearningClient::startDataFrameAnalytics, machineLearningClient::startDataFrameAnalyticsAsync);
13931393
assertTrue(startDataFrameAnalyticsResponse.isAcknowledged());
1394-
assertThat(
1395-
getAnalyticsState(configId),
1396-
anyOf(
1397-
equalTo(DataFrameAnalyticsState.STARTED),
1398-
equalTo(DataFrameAnalyticsState.REINDEXING),
1399-
equalTo(DataFrameAnalyticsState.ANALYZING)));
14001394

14011395
// Wait for the analytics to stop.
14021396
assertBusy(() -> assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)), 30, TimeUnit.SECONDS);
@@ -1409,7 +1403,8 @@ public void testStopDataFrameAnalyticsConfig() throws Exception {
14091403
String sourceIndex = "stop-test-source-index";
14101404
String destIndex = "stop-test-dest-index";
14111405
createIndex(sourceIndex, mappingForClassification());
1412-
highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT);
1406+
highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000)
1407+
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT);
14131408

14141409
// Verify that the destination index does not exist. Otherwise, analytics' reindexing step would fail.
14151410
assertFalse(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT));

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import org.apache.logging.log4j.LogManager;
99
import org.apache.logging.log4j.Logger;
10+
import org.apache.logging.log4j.message.ParameterizedMessage;
1011
import org.elasticsearch.action.search.ClearScrollAction;
1112
import org.elasticsearch.action.search.ClearScrollRequest;
1213
import org.elasticsearch.action.search.SearchAction;
@@ -20,7 +21,6 @@
2021
import org.elasticsearch.search.SearchHit;
2122
import org.elasticsearch.search.sort.SortOrder;
2223
import org.elasticsearch.xpack.core.ClientHelper;
23-
import org.elasticsearch.xpack.core.ml.datafeed.extractor.ExtractorUtils;
2424
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
2525
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsFields;
2626

@@ -34,6 +34,7 @@
3434
import java.util.Objects;
3535
import java.util.Optional;
3636
import java.util.concurrent.TimeUnit;
37+
import java.util.function.Supplier;
3738
import java.util.stream.Collectors;
3839

3940
/**
@@ -91,9 +92,28 @@ public Optional<List<Row>> next() throws IOException {
9192

9293
protected List<Row> initScroll() throws IOException {
9394
LOGGER.debug("[{}] Initializing scroll", context.jobId);
94-
SearchResponse searchResponse = executeSearchRequest(buildSearchRequest());
95-
LOGGER.debug("[{}] Search response was obtained", context.jobId);
96-
return processSearchResponse(searchResponse);
95+
return tryRequestWithSearchResponse(() -> executeSearchRequest(buildSearchRequest()));
96+
}
97+
98+
private List<Row> tryRequestWithSearchResponse(Supplier<SearchResponse> request) throws IOException {
99+
try {
100+
// We've set allow_partial_search_results to false which means if something
101+
// goes wrong the request will throw.
102+
SearchResponse searchResponse = request.get();
103+
LOGGER.debug("[{}] Search response was obtained", context.jobId);
104+
105+
// Request was successful so we can restore the flag to retry if a future failure occurs
106+
searchHasShardFailure = false;
107+
108+
return processSearchResponse(searchResponse);
109+
} catch (Exception e) {
110+
if (searchHasShardFailure) {
111+
throw e;
112+
}
113+
LOGGER.warn(new ParameterizedMessage("[{}] Search resulted to failure; retrying once", context.jobId), e);
114+
markScrollAsErrored();
115+
return initScroll();
116+
}
97117
}
98118

99119
protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) {
@@ -103,6 +123,8 @@ protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequest
103123
private SearchRequestBuilder buildSearchRequest() {
104124
SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE)
105125
.setScroll(SCROLL_TIMEOUT)
126+
// This ensures the search throws if there are failures and the scroll context gets cleared automatically
127+
.setAllowPartialSearchResults(false)
106128
.addSort(DataFrameAnalyticsFields.ID, SortOrder.ASC)
107129
.setIndices(context.indices)
108130
.setSize(context.scrollSize)
@@ -117,14 +139,6 @@ private SearchRequestBuilder buildSearchRequest() {
117139
}
118140

119141
private List<Row> processSearchResponse(SearchResponse searchResponse) throws IOException {
120-
121-
if (searchResponse.getFailedShards() > 0 && searchHasShardFailure == false) {
122-
LOGGER.debug("[{}] Resetting scroll search after shard failure", context.jobId);
123-
markScrollAsErrored();
124-
return initScroll();
125-
}
126-
127-
ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse);
128142
scrollId = searchResponse.getScrollId();
129143
if (searchResponse.getHits().getHits().length == 0) {
130144
hasNext = false;
@@ -143,7 +157,6 @@ private List<Row> processSearchResponse(SearchResponse searchResponse) throws IO
143157
rows.add(createRow(hit));
144158
}
145159
return rows;
146-
147160
}
148161

149162
private Row createRow(SearchHit hit) {
@@ -163,15 +176,13 @@ private Row createRow(SearchHit hit) {
163176

164177
private List<Row> continueScroll() throws IOException {
165178
LOGGER.debug("[{}] Continuing scroll with id [{}]", context.jobId, scrollId);
166-
SearchResponse searchResponse = executeSearchScrollRequest(scrollId);
167-
LOGGER.debug("[{}] Search response was obtained", context.jobId);
168-
return processSearchResponse(searchResponse);
179+
return tryRequestWithSearchResponse(() -> executeSearchScrollRequest(scrollId));
169180
}
170181

171182
private void markScrollAsErrored() {
172183
// This could be a transient error with the scroll Id.
173184
// Reinitialise the scroll and try again but only once.
174-
resetScroll();
185+
scrollId = null;
175186
searchHasShardFailure = true;
176187
}
177188

@@ -183,11 +194,6 @@ protected SearchResponse executeSearchScrollRequest(String scrollId) {
183194
.get());
184195
}
185196

186-
private void resetScroll() {
187-
clearScroll(scrollId);
188-
scrollId = null;
189-
}
190-
191197
private void clearScroll(String scrollId) {
192198
if (scrollId != null) {
193199
ClearScrollRequest request = new ClearScrollRequest();

0 commit comments

Comments
 (0)