Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
import org.elasticsearch.common.xcontent.XContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
Expand Down Expand Up @@ -62,6 +64,11 @@ public class MultiSearchRequest extends ActionRequest implements CompositeIndice

public MultiSearchRequest() {}

@Override
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new SearchTask(id, type, action, "multi_search", parentTaskId, headers);
}

/**
* Add a search request to execute. Note, the order is important, the search response will be returned in the
* same order as the search requests.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ void doExecute(ActionType<Response> action, Request request, ActionListener<Resp
* Execute an {@link ActionType} locally, returning that {@link Task} used to track it, and linking an {@link ActionListener}.
* Prefer this method if you don't need access to the task when listening for the response. This is the method used to
* implement the {@link Client} interface.
*
* @return The {@link Task} to track the action or null if an exception occurs before creating the task.
*/
public < Request extends ActionRequest,
Response extends ActionResponse
Expand All @@ -92,6 +94,8 @@ > Task executeLocally(ActionType<Response> action, Request request, ActionListen
/**
* Execute an {@link ActionType} locally, returning that {@link Task} used to track it, and linking an {@link TaskListener}.
* Prefer this method if you need access to the task when listening for the response.
*
* @return The {@link Task} to track the action or null if an exception occurs before creating the task.
*/
public < Request extends ActionRequest,
Response extends ActionResponse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,10 @@ public void onFailure(Exception e) {
}
}
});
closeListener.registerTask(taskHolder, new TaskId(client.getLocalNodeId(), task.getId()));
closeListener.maybeRegisterChannel(httpChannel);
if (task != null) {
closeListener.registerTask(taskHolder, new TaskId(client.getLocalNodeId(), task.getId()));
closeListener.maybeRegisterChannel(httpChannel);
}
}

public int getNumChannels() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,13 @@ public Task register(String type, String action, TaskAwareRequest request) {
public <Request extends ActionRequest, Response extends ActionResponse>
Task registerAndExecute(String type, TransportAction<Request, Response> action, Request request,
BiConsumer<Task, Response> onResponse, BiConsumer<Task, Exception> onFailure) {
Task task = register(type, action.actionName, request);
final Task task;
try {
task = register(type, action.actionName, request);
} catch (Exception e) {
onFailure.accept(null, e);
return null;
}
// NOTE: ActionListener cannot infer Response, see https://bugs.openjdk.java.net/browse/JDK-8203195
action.execute(task, request, new ActionListener<Response>() {
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,14 @@ public <T extends TransportResponse> void sendChildRequest(final Transport.Conne
final TransportRequest request, final Task parentTask,
final TransportRequestOptions options,
final TransportResponseHandler<T> handler) {
request.setParentTask(localNode.getId(), parentTask.getId());
if (parentTask.getParentTaskId() != null && parentTask.getParentTaskId().isSet()) {
// the parent task is already a child of another task so we associate the child request with the
// grand-parent in order to be able to cancel the root task efficiently (e.g. cancelling _msearch
// request should cancel all sub-tasks).
request.setParentTask(parentTask.getParentTaskId());
} else {
request.setParentTask(localNode.getId(), parentTask.getId());
}
try {
sendRequest(connection, action, request, options, handler);
} catch (TaskCancelledException ex) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import org.elasticsearch.rest.action.search.RestMultiSearchAction;
import org.elasticsearch.search.Scroll;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.StreamsUtils;
import org.elasticsearch.test.rest.FakeRestRequest;
Expand Down Expand Up @@ -260,6 +262,13 @@ public void testEqualsAndHashcode() {
checkEqualsAndHashCode(createMultiSearchRequest(), MultiSearchRequestTests::copyRequest, MultiSearchRequestTests::mutate);
}

public void testTaskIsCancellable() {
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
Task task = multiSearchRequest.createTask(1, "type", "action", null, Collections.emptyMap());
assertThat(task, instanceOf(CancellableTask.class));
assertTrue(((CancellableTask) task).shouldCancelChildrenOnCancellation());
}

private static MultiSearchRequest mutate(MultiSearchRequest searchRequest) throws IOException {
MultiSearchRequest mutation = copyRequest(searchRequest);
List<CheckedRunnable<IOException>> mutators = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksResponse;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.search.MultiSearchAction;
import org.elasticsearch.action.search.MultiSearchRequestBuilder;
import org.elasticsearch.action.search.MultiSearchResponse;
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchScrollAction;
import org.elasticsearch.action.support.WriteRequest;
Expand All @@ -37,6 +41,7 @@
import org.elasticsearch.script.MockScriptPlugin;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.lookup.LeafFieldsLookup;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.test.ESIntegTestCase;
Expand All @@ -52,10 +57,13 @@

import static org.elasticsearch.index.query.QueryBuilders.scriptQuery;
import static org.elasticsearch.search.SearchCancellationIT.ScriptedBlockPlugin.SCRIPT_NAME;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoFailures;
import static org.hamcrest.Matchers.either;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;

@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.SUITE)
public class SearchCancellationIT extends ESIntegTestCase {
Expand All @@ -76,6 +84,9 @@ protected Settings nodeSettings(int nodeOrdinal) {
}

private void indexTestData() {
assertAcked(client().admin().indices().prepareCreate("test")
.setSettings(Settings.builder().put("index.number_of_shards", randomIntBetween(1, 10)))
.get());
for (int i = 0; i < 5; i++) {
// Make sure we have a few segments
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
Expand Down Expand Up @@ -245,6 +256,72 @@ public void testCancellationOfScrollSearchesOnFollowupRequests() throws Exceptio
client().prepareClearScroll().addScrollId(scrollId).get();
}

public void testMultiSearchQueryPhaseCancellation() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory();
indexTestData();

logger.info("Executing msearch");
MultiSearchRequestBuilder multiSearchRequestBuilder = client().prepareMultiSearch();
int numSearches = randomIntBetween(1, 5);
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(scriptQuery(new Script(
ScriptType.INLINE, "mockscript", SCRIPT_NAME, Collections.emptyMap())));
for (int i = 0; i < numSearches; i++) {
multiSearchRequestBuilder.add(new SearchRequest("test").source(sourceBuilder));
}

ActionFuture<MultiSearchResponse> multiSearchResponseFuture = multiSearchRequestBuilder.execute();

awaitForBlock(plugins);
cancelSearch(MultiSearchAction.NAME);
disableBlocks(plugins);
logger.info("Segments {}", Strings.toString(client().admin().indices().prepareSegments("test").get()));

MultiSearchResponse multiSearchResponse = multiSearchResponseFuture.actionGet();
for (MultiSearchResponse.Item item : multiSearchResponse) {
if (item.isFailure()) {
logger.info("All shards failed with", item.getFailure());
assertThat(item.getFailure(), instanceOf(SearchPhaseExecutionException.class));
} else {
SearchResponse response = item.getResponse();
logger.info("Search response {}", response);
assertNotEquals("At least one shard should have failed", 0, response.getFailedShards());
}
}
}

public void testMultiSearchFetchPhaseCancellation() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory();
indexTestData();

logger.info("Executing msearch");
MultiSearchRequestBuilder multiSearchRequestBuilder = client().prepareMultiSearch();
int numSearches = 100;//randomIntBetween(1, 5);
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
.scriptField("test_field", new Script(ScriptType.INLINE, "mockscript", SCRIPT_NAME, Collections.emptyMap()));
for (int i = 0; i < numSearches; i++) {
multiSearchRequestBuilder.add(new SearchRequest("test").source(sourceBuilder));
}

ActionFuture<MultiSearchResponse> multiSearchResponseFuture = multiSearchRequestBuilder.execute();

awaitForBlock(plugins);
cancelSearch(MultiSearchAction.NAME);
disableBlocks(plugins);
logger.info("Segments {}", Strings.toString(client().admin().indices().prepareSegments("test").get()));

MultiSearchResponse multiSearchResponse = multiSearchResponseFuture.actionGet();
for (MultiSearchResponse.Item item : multiSearchResponse) {
if (item.isFailure()) {
logger.info("All shards failed with", item.getFailure());
assertThat(item.getFailure(), either(instanceOf(SearchPhaseExecutionException.class))
.or(instanceOf(IllegalStateException.class)));
} else {
SearchResponse response = item.getResponse();
logger.info("Search response {}", response);
assertNotEquals("At least one shard should have failed", 0, response.getFailedShards());
}
}
}

public static class ScriptedBlockPlugin extends MockScriptPlugin {
static final String SCRIPT_NAME = "search_block";
Expand Down