Merged

Changes from all commits
@@ -164,7 +164,6 @@
 import static org.hamcrest.CoreMatchers.hasItem;
 import static org.hamcrest.CoreMatchers.hasItems;
 import static org.hamcrest.CoreMatchers.not;
-import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.closeTo;
 import static org.hamcrest.Matchers.contains;
 import static org.hamcrest.Matchers.containsInAnyOrder;
@@ -1365,7 +1364,8 @@ public void testStartDataFrameAnalyticsConfig() throws Exception {
         String sourceIndex = "start-test-source-index";
         String destIndex = "start-test-dest-index";
         createIndex(sourceIndex, defaultMappingForTest());
-        highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT);
+        highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000)
+            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT);
 
         // Verify that the destination index does not exist. Otherwise, analytics' reindexing step would fail.
         assertFalse(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT));
@@ -1391,12 +1391,6 @@ public void testStartDataFrameAnalyticsConfig() throws Exception {
             new StartDataFrameAnalyticsRequest(configId),
             machineLearningClient::startDataFrameAnalytics, machineLearningClient::startDataFrameAnalyticsAsync);
         assertTrue(startDataFrameAnalyticsResponse.isAcknowledged());
-        assertThat(
-            getAnalyticsState(configId),
-            anyOf(
-                equalTo(DataFrameAnalyticsState.STARTED),
-                equalTo(DataFrameAnalyticsState.REINDEXING),
-                equalTo(DataFrameAnalyticsState.ANALYZING)));
 
         // Wait for the analytics to stop.
         assertBusy(() -> assertThat(getAnalyticsState(configId), equalTo(DataFrameAnalyticsState.STOPPED)), 30, TimeUnit.SECONDS);
@@ -1409,7 +1403,8 @@ public void testStopDataFrameAnalyticsConfig() throws Exception {
         String sourceIndex = "stop-test-source-index";
         String destIndex = "stop-test-dest-index";
         createIndex(sourceIndex, mappingForClassification());
-        highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000), RequestOptions.DEFAULT);
+        highLevelClient().index(new IndexRequest(sourceIndex).source(XContentType.JSON, "total", 10000)
+            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE), RequestOptions.DEFAULT);
 
         // Verify that the destination index does not exist. Otherwise, analytics' reindexing step would fail.
         assertFalse(highLevelClient().indices().exists(new GetIndexRequest(destIndex), RequestOptions.DEFAULT));
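Both test changes above index a single document and then immediately start an analytics job that searches the source index. By default, a newly indexed document only becomes visible to search after the next refresh, so without RefreshPolicy.IMMEDIATE the job's first search could see an empty index. A minimal sketch of the idea, assuming a RestHighLevelClient; the index name, field, and method name are illustrative, not taken from the PR:

    import org.elasticsearch.action.index.IndexRequest;
    import org.elasticsearch.action.support.WriteRequest;
    import org.elasticsearch.client.RequestOptions;
    import org.elasticsearch.client.RestHighLevelClient;
    import org.elasticsearch.common.xcontent.XContentType;

    // Index one document and force a refresh so it is visible to searches
    // as soon as the call returns, instead of after the next refresh interval.
    void indexOneDocument(RestHighLevelClient client) throws java.io.IOException {
        IndexRequest indexRequest = new IndexRequest("start-test-source-index")
            .source(XContentType.JSON, "total", 10000)
            .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
        client.index(indexRequest, RequestOptions.DEFAULT);
    }

The remaining hunks below are in the extractor class that drives the scroll search.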
@@ -7,6 +7,7 @@
 
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
+import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.elasticsearch.action.search.ClearScrollAction;
 import org.elasticsearch.action.search.ClearScrollRequest;
 import org.elasticsearch.action.search.SearchAction;
@@ -20,7 +21,6 @@
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.sort.SortOrder;
 import org.elasticsearch.xpack.core.ClientHelper;
-import org.elasticsearch.xpack.core.ml.datafeed.extractor.ExtractorUtils;
 import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
 import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsFields;
 
@@ -34,6 +34,7 @@
 import java.util.Objects;
 import java.util.Optional;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
 import java.util.stream.Collectors;
 
 /**
@@ -91,9 +92,28 @@ public Optional<List<Row>> next() throws IOException {
 
     protected List<Row> initScroll() throws IOException {
         LOGGER.debug("[{}] Initializing scroll", context.jobId);
-        SearchResponse searchResponse = executeSearchRequest(buildSearchRequest());
-        LOGGER.debug("[{}] Search response was obtained", context.jobId);
-        return processSearchResponse(searchResponse);
+        return tryRequestWithSearchResponse(() -> executeSearchRequest(buildSearchRequest()));
     }
 
+    private List<Row> tryRequestWithSearchResponse(Supplier<SearchResponse> request) throws IOException {

[Inline review comment, Contributor Author] I think having this method is a nice way of handling the two different requests in a single place. If people like it, we can reuse this in ScrollDataExtractor to simplify it and rely on allow_partial_search_results to do its job of clearing the scroll context on error.

+        try {
+            // We've set allow_partial_search_results to false which means if something
+            // goes wrong the request will throw.
+            SearchResponse searchResponse = request.get();
+            LOGGER.debug("[{}] Search response was obtained", context.jobId);
+
+            // Request was successful so we can restore the flag to retry if a future failure occurs
+            searchHasShardFailure = false;
+
+            return processSearchResponse(searchResponse);
+        } catch (Exception e) {
+            if (searchHasShardFailure) {
+                throw e;
+            }
+            LOGGER.warn(new ParameterizedMessage("[{}] Search resulted in failure; retrying once", context.jobId), e);
+            markScrollAsErrored();
+            return initScroll();
+        }
+    }
+
     protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) {
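The inline comment above describes the design: both the initial search and each scroll continuation are funneled through tryRequestWithSearchResponse, which retries exactly once after a failure and otherwise rethrows. A standalone sketch of that retry-once shape, stripped of the Elasticsearch specifics; the names RetryOnce, hasRetried, and fetch are hypothetical, not part of the actual class:

    import java.util.function.Supplier;

    class RetryOnce<T> {
        private boolean hasRetried = false;

        // Run the request; on the first failure retry exactly once,
        // and if the retry itself fails, rethrow to the caller.
        T fetch(Supplier<T> request) {
            try {
                T response = request.get();
                hasRetried = false; // success: re-arm the single retry for future failures
                return response;
            } catch (RuntimeException e) {
                if (hasRetried) {
                    throw e;
                }
                hasRetried = true; // corresponds to markScrollAsErrored() in the diff
                return fetch(request);
            }
        }
    }

Note the flag is cleared only on success, mirroring searchHasShardFailure above: each new failure that follows a successful response gets its own single retry.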
@@ -103,6 +123,8 @@ protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) {
     private SearchRequestBuilder buildSearchRequest() {
         SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE)
             .setScroll(SCROLL_TIMEOUT)
+            // This ensures the search throws if there are failures and the scroll context gets cleared automatically
+            .setAllowPartialSearchResults(false)
             .addSort(DataFrameAnalyticsFields.ID, SortOrder.ASC)
             .setIndices(context.indices)
             .setSize(context.scrollSize)
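The setAllowPartialSearchResults(false) line above changes the failure semantics: by default a search that loses some shards can still return successfully with partial hits, which is why the old code had to inspect getFailedShards() (removed in the next hunk). With partial results disallowed, any shard failure surfaces as a thrown exception, and per the comment in the diff the scroll context is also cleared automatically. A hedged sketch of the same setting on a plain SearchRequest, assuming a transport-style client variable named client and an illustrative index name:

    import org.elasticsearch.action.search.SearchRequest;
    import org.elasticsearch.action.search.SearchResponse;

    SearchRequest searchRequest = new SearchRequest("my-source-index");
    // Fail the whole request on any shard failure instead of returning partial hits.
    searchRequest.allowPartialSearchResults(false);
    SearchResponse response = client.search(searchRequest).actionGet();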
@@ -117,14 +139,6 @@ private SearchRequestBuilder buildSearchRequest() {
     }
 
     private List<Row> processSearchResponse(SearchResponse searchResponse) throws IOException {
-
-        if (searchResponse.getFailedShards() > 0 && searchHasShardFailure == false) {
-            LOGGER.debug("[{}] Resetting scroll search after shard failure", context.jobId);
-            markScrollAsErrored();
-            return initScroll();
-        }
-
-        ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse);
         scrollId = searchResponse.getScrollId();
         if (searchResponse.getHits().getHits().length == 0) {
             hasNext = false;
@@ -143,7 +157,6 @@ private List<Row> processSearchResponse(SearchResponse searchResponse) throws IOException {
             rows.add(createRow(hit));
         }
         return rows;
-
     }
 
     private Row createRow(SearchHit hit) {
@@ -163,15 +176,13 @@ private Row createRow(SearchHit hit) {
 
     private List<Row> continueScroll() throws IOException {
         LOGGER.debug("[{}] Continuing scroll with id [{}]", context.jobId, scrollId);
-        SearchResponse searchResponse = executeSearchScrollRequest(scrollId);
-        LOGGER.debug("[{}] Search response was obtained", context.jobId);
-        return processSearchResponse(searchResponse);
+        return tryRequestWithSearchResponse(() -> executeSearchScrollRequest(scrollId));
     }
 
     private void markScrollAsErrored() {
         // This could be a transient error with the scroll Id.
         // Reinitialise the scroll and try again but only once.
-        resetScroll();
+        scrollId = null;
         searchHasShardFailure = true;
     }
 
@@ -183,11 +194,6 @@ protected SearchResponse executeSearchScrollRequest(String scrollId) {
                 .get());
     }
 
-    private void resetScroll() {
-        clearScroll(scrollId);
-        scrollId = null;
-    }
-
     private void clearScroll(String scrollId) {
         if (scrollId != null) {
             ClearScrollRequest request = new ClearScrollRequest();
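For completeness, the clearScroll helper shown truncated above releases server-side scroll resources. With resetScroll removed, the explicit clear should only be needed on normal completion, since per the comment in buildSearchRequest a failed search now cleans up its own scroll context. A sketch of clearing a scroll explicitly; the client and scrollId variables are assumed:

    import org.elasticsearch.action.search.ClearScrollAction;
    import org.elasticsearch.action.search.ClearScrollRequest;

    // Release the server-side scroll context once the extractor is done with it.
    ClearScrollRequest clearScrollRequest = new ClearScrollRequest();
    clearScrollRequest.addScrollId(scrollId);
    client.execute(ClearScrollAction.INSTANCE, clearScrollRequest).actionGet();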