diff --git a/server/src/main/java/org/elasticsearch/action/search/MultiSearchRequest.java b/server/src/main/java/org/elasticsearch/action/search/MultiSearchRequest.java index e6e11e5eeb5c4..616c939fa9bfb 100644 --- a/server/src/main/java/org/elasticsearch/action/search/MultiSearchRequest.java +++ b/server/src/main/java/org/elasticsearch/action/search/MultiSearchRequest.java @@ -33,6 +33,8 @@ import org.elasticsearch.common.xcontent.XContent; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.tasks.TaskId; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -62,6 +64,11 @@ public class MultiSearchRequest extends ActionRequest implements CompositeIndice public MultiSearchRequest() {} + @Override + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { + return new SearchTask(id, type, action, "multi_search", parentTaskId, headers); + } + /** * Add a search request to execute. Note, the order is important, the search response will be returned in the * same order as the search requests. diff --git a/server/src/main/java/org/elasticsearch/client/node/NodeClient.java b/server/src/main/java/org/elasticsearch/client/node/NodeClient.java index 091e0cdf63b89..c561552a41658 100644 --- a/server/src/main/java/org/elasticsearch/client/node/NodeClient.java +++ b/server/src/main/java/org/elasticsearch/client/node/NodeClient.java @@ -81,6 +81,8 @@ void doExecute(ActionType action, Request request, ActionListener Task executeLocally(ActionType action, Request request, ActionListen /** * Execute an {@link ActionType} locally, returning that {@link Task} used to track it, and linking an {@link TaskListener}. * Prefer this method if you need access to the task when listening for the response. + * + * @return The {@link Task} to track the action or null if an exception occurs before creating the task. */ public < Request extends ActionRequest, Response extends ActionResponse diff --git a/server/src/main/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandler.java b/server/src/main/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandler.java index 5864551854fca..1e46b16e993e7 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandler.java +++ b/server/src/main/java/org/elasticsearch/rest/action/search/HttpChannelTaskHandler.java @@ -77,8 +77,10 @@ public void onFailure(Exception e) { } } }); - closeListener.registerTask(taskHolder, new TaskId(client.getLocalNodeId(), task.getId())); - closeListener.maybeRegisterChannel(httpChannel); + if (task != null) { + closeListener.registerTask(taskHolder, new TaskId(client.getLocalNodeId(), task.getId())); + closeListener.maybeRegisterChannel(httpChannel); + } } public int getNumChannels() { diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java index 14fda1fdb85de..080f18d9ce9c0 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskManager.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskManager.java @@ -133,7 +133,13 @@ public Task register(String type, String action, TaskAwareRequest request) { public Task registerAndExecute(String type, TransportAction action, Request request, BiConsumer onResponse, BiConsumer onFailure) { - Task task = register(type, action.actionName, request); + final Task task; + try { + task = register(type, action.actionName, request); + } catch (Exception e) { + onFailure.accept(null, e); + return null; + } // NOTE: ActionListener cannot infer Response, see https://bugs.openjdk.java.net/browse/JDK-8203195 action.execute(task, request, new ActionListener() { @Override diff --git a/server/src/main/java/org/elasticsearch/transport/TransportService.java b/server/src/main/java/org/elasticsearch/transport/TransportService.java index ddd7a0d4cab19..bfb0082491b4a 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportService.java @@ -613,7 +613,14 @@ public void sendChildRequest(final Transport.Conne final TransportRequest request, final Task parentTask, final TransportRequestOptions options, final TransportResponseHandler handler) { - request.setParentTask(localNode.getId(), parentTask.getId()); + if (parentTask.getParentTaskId() != null && parentTask.getParentTaskId().isSet()) { + // the parent task is already a child of another task so we associate the child request with the + // grand-parent in order to be able to cancel the root task efficiently (e.g. cancelling _msearch + // request should cancel all sub-tasks). + request.setParentTask(parentTask.getParentTaskId()); + } else { + request.setParentTask(localNode.getId(), parentTask.getId()); + } try { sendRequest(connection, action, request, options, handler); } catch (TaskCancelledException ex) { diff --git a/server/src/test/java/org/elasticsearch/action/search/MultiSearchRequestTests.java b/server/src/test/java/org/elasticsearch/action/search/MultiSearchRequestTests.java index 35f60546bb023..c82131364572d 100644 --- a/server/src/test/java/org/elasticsearch/action/search/MultiSearchRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/MultiSearchRequestTests.java @@ -34,6 +34,8 @@ import org.elasticsearch.rest.action.search.RestMultiSearchAction; import org.elasticsearch.search.Scroll; import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.StreamsUtils; import org.elasticsearch.test.rest.FakeRestRequest; @@ -260,6 +262,13 @@ public void testEqualsAndHashcode() { checkEqualsAndHashCode(createMultiSearchRequest(), MultiSearchRequestTests::copyRequest, MultiSearchRequestTests::mutate); } + public void testTaskIsCancellable() { + MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); + Task task = multiSearchRequest.createTask(1, "type", "action", null, Collections.emptyMap()); + assertThat(task, instanceOf(CancellableTask.class)); + assertTrue(((CancellableTask) task).shouldCancelChildrenOnCancellation()); + } + private static MultiSearchRequest mutate(MultiSearchRequest searchRequest) throws IOException { MultiSearchRequest mutation = copyRequest(searchRequest); List> mutators = new ArrayList<>(); diff --git a/server/src/test/java/org/elasticsearch/search/SearchCancellationIT.java b/server/src/test/java/org/elasticsearch/search/SearchCancellationIT.java index 6da30714c7509..321b60892bca9 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchCancellationIT.java +++ b/server/src/test/java/org/elasticsearch/search/SearchCancellationIT.java @@ -24,8 +24,12 @@ import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse; import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.search.MultiSearchAction; +import org.elasticsearch.action.search.MultiSearchRequestBuilder; +import org.elasticsearch.action.search.MultiSearchResponse; import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchPhaseExecutionException; +import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchScrollAction; import org.elasticsearch.action.support.WriteRequest; @@ -37,6 +41,7 @@ import org.elasticsearch.script.MockScriptPlugin; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptType; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.lookup.LeafFieldsLookup; import org.elasticsearch.tasks.TaskInfo; import org.elasticsearch.test.ESIntegTestCase; @@ -52,10 +57,13 @@ import static org.elasticsearch.index.query.QueryBuilders.scriptQuery; import static org.elasticsearch.search.SearchCancellationIT.ScriptedBlockPlugin.SCRIPT_NAME; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures; +import static org.hamcrest.Matchers.either; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; @ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE) public class SearchCancellationIT extends ESIntegTestCase { @@ -76,6 +84,9 @@ protected Settings nodeSettings(int nodeOrdinal) { } private void indexTestData() { + assertAcked(client().admin().indices().prepareCreate("test") + .setSettings(Settings.builder().put("index.number_of_shards", randomIntBetween(1, 10))) + .get()); for (int i = 0; i < 5; i++) { // Make sure we have a few segments BulkRequestBuilder bulkRequestBuilder = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); @@ -245,6 +256,72 @@ public void testCancellationOfScrollSearchesOnFollowupRequests() throws Exceptio client().prepareClearScroll().addScrollId(scrollId).get(); } + public void testMultiSearchQueryPhaseCancellation() throws Exception { + List plugins = initBlockFactory(); + indexTestData(); + + logger.info("Executing msearch"); + MultiSearchRequestBuilder multiSearchRequestBuilder = client().prepareMultiSearch(); + int numSearches = randomIntBetween(1, 5); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(scriptQuery(new Script( + ScriptType.INLINE, "mockscript", SCRIPT_NAME, Collections.emptyMap()))); + for (int i = 0; i < numSearches; i++) { + multiSearchRequestBuilder.add(new SearchRequest("test").source(sourceBuilder)); + } + + ActionFuture multiSearchResponseFuture = multiSearchRequestBuilder.execute(); + + awaitForBlock(plugins); + cancelSearch(MultiSearchAction.NAME); + disableBlocks(plugins); + logger.info("Segments {}", Strings.toString(client().admin().indices().prepareSegments("test").get())); + + MultiSearchResponse multiSearchResponse = multiSearchResponseFuture.actionGet(); + for (MultiSearchResponse.Item item : multiSearchResponse) { + if (item.isFailure()) { + logger.info("All shards failed with", item.getFailure()); + assertThat(item.getFailure(), instanceOf(SearchPhaseExecutionException.class)); + } else { + SearchResponse response = item.getResponse(); + logger.info("Search response {}", response); + assertNotEquals("At least one shard should have failed", 0, response.getFailedShards()); + } + } + } + + public void testMultiSearchFetchPhaseCancellation() throws Exception { + List plugins = initBlockFactory(); + indexTestData(); + + logger.info("Executing msearch"); + MultiSearchRequestBuilder multiSearchRequestBuilder = client().prepareMultiSearch(); + int numSearches = 100;//randomIntBetween(1, 5); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder() + .scriptField("test_field", new Script(ScriptType.INLINE, "mockscript", SCRIPT_NAME, Collections.emptyMap())); + for (int i = 0; i < numSearches; i++) { + multiSearchRequestBuilder.add(new SearchRequest("test").source(sourceBuilder)); + } + + ActionFuture multiSearchResponseFuture = multiSearchRequestBuilder.execute(); + + awaitForBlock(plugins); + cancelSearch(MultiSearchAction.NAME); + disableBlocks(plugins); + logger.info("Segments {}", Strings.toString(client().admin().indices().prepareSegments("test").get())); + + MultiSearchResponse multiSearchResponse = multiSearchResponseFuture.actionGet(); + for (MultiSearchResponse.Item item : multiSearchResponse) { + if (item.isFailure()) { + logger.info("All shards failed with", item.getFailure()); + assertThat(item.getFailure(), either(instanceOf(SearchPhaseExecutionException.class)) + .or(instanceOf(IllegalStateException.class))); + } else { + SearchResponse response = item.getResponse(); + logger.info("Search response {}", response); + assertNotEquals("At least one shard should have failed", 0, response.getFailedShards()); + } + } + } public static class ScriptedBlockPlugin extends MockScriptPlugin { static final String SCRIPT_NAME = "search_block";