|
31 | 31 | import org.elasticsearch.cluster.service.ClusterService; |
32 | 32 | import org.elasticsearch.common.io.stream.StreamInput; |
33 | 33 | import org.elasticsearch.common.io.stream.StreamOutput; |
| 34 | +import org.elasticsearch.common.settings.Settings; |
34 | 35 | import org.elasticsearch.common.unit.TimeValue; |
| 36 | +import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor; |
35 | 37 | import org.elasticsearch.discovery.MasterNotDiscoveredException; |
36 | 38 | import org.elasticsearch.indices.TestIndexNameExpressionResolver; |
37 | 39 | import org.elasticsearch.node.NodeClosedException; |
38 | 40 | import org.elasticsearch.rest.RestStatus; |
| 41 | +import org.elasticsearch.tasks.CancellableTask; |
39 | 42 | import org.elasticsearch.tasks.Task; |
| 43 | +import org.elasticsearch.tasks.TaskId; |
| 44 | +import org.elasticsearch.tasks.TaskManager; |
40 | 45 | import org.elasticsearch.test.ESTestCase; |
41 | 46 | import org.elasticsearch.test.transport.CapturingTransport; |
42 | 47 | import org.elasticsearch.threadpool.TestThreadPool; |
|
51 | 56 | import java.io.IOException; |
52 | 57 | import java.util.Collections; |
53 | 58 | import java.util.HashSet; |
| 59 | +import java.util.Map; |
54 | 60 | import java.util.Objects; |
55 | 61 | import java.util.Set; |
| 62 | +import java.util.concurrent.CancellationException; |
| 63 | +import java.util.concurrent.CountDownLatch; |
| 64 | +import java.util.concurrent.CyclicBarrier; |
56 | 65 | import java.util.concurrent.ExecutionException; |
57 | 66 | import java.util.concurrent.TimeUnit; |
58 | 67 |
|
@@ -126,6 +135,11 @@ public static class Request extends MasterNodeRequest<Request> { |
126 | 135 | public ActionRequestValidationException validate() { |
127 | 136 | return null; |
128 | 137 | } |
| 138 | + |
| 139 | + @Override |
| 140 | + public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) { |
| 141 | + return new CancellableTask(id, type, action, "", parentTaskId, headers); |
| 142 | + } |
129 | 143 | } |
130 | 144 |
|
131 | 145 | class Response extends ActionResponse { |
@@ -160,12 +174,18 @@ public void writeTo(StreamOutput out) throws IOException { |
160 | 174 | class Action extends TransportMasterNodeAction<Request, Response> { |
161 | 175 | Action(String actionName, TransportService transportService, ClusterService clusterService, |
162 | 176 | ThreadPool threadPool) { |
| 177 | + this(actionName, transportService, clusterService, threadPool, ThreadPool.Names.SAME); |
| 178 | + } |
| 179 | + |
| 180 | + Action(String actionName, TransportService transportService, ClusterService clusterService, |
| 181 | + ThreadPool threadPool, String executor) { |
163 | 182 | super(actionName, transportService, clusterService, threadPool, |
164 | 183 | new ActionFilters(new HashSet<>()), Request::new, |
165 | 184 | TestIndexNameExpressionResolver.newInstance(), Response::new, |
166 | | - ThreadPool.Names.SAME); |
| 185 | + executor); |
167 | 186 | } |
168 | 187 |
|
| 188 | + |
169 | 189 | @Override |
170 | 190 | protected void doExecute(Task task, final Request request, ActionListener<Response> listener) { |
171 | 191 | // remove unneeded threading by wrapping listener with SAME to prevent super.doExecute from wrapping it with LISTENER |
@@ -460,4 +480,110 @@ protected void masterOperation(Task task, Request request, ClusterState state, |
460 | 480 | assertTrue(listener.isDone()); |
461 | 481 | assertThat(listener.get(), equalTo(response)); |
462 | 482 | } |
| 483 | + |
| 484 | + public void testTaskCancellation() { |
| 485 | + ClusterBlock block = new ClusterBlock(1, |
| 486 | + "", |
| 487 | + true, |
| 488 | + true, |
| 489 | + false, |
| 490 | + randomFrom(RestStatus.values()), |
| 491 | + ClusterBlockLevel.ALL |
| 492 | + ); |
| 493 | + ClusterState stateWithBlock = ClusterState.builder(ClusterStateCreationUtils.state(localNode, localNode, allNodes)) |
| 494 | + .blocks(ClusterBlocks.builder().addGlobalBlock(block)).build(); |
| 495 | + |
| 496 | + // Update the cluster state with a block so the request waits until it's unblocked |
| 497 | + setState(clusterService, stateWithBlock); |
| 498 | + |
| 499 | + TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet()); |
| 500 | + |
| 501 | + Request request = new Request(); |
| 502 | + final CancellableTask task = (CancellableTask) taskManager.register("type", "internal:testAction", request); |
| 503 | + |
| 504 | + boolean cancelBeforeStart = randomBoolean(); |
| 505 | + if (cancelBeforeStart) { |
| 506 | + taskManager.cancel(task, "", () -> {}); |
| 507 | + assertThat(task.isCancelled(), equalTo(true)); |
| 508 | + } |
| 509 | + |
| 510 | + PlainActionFuture<Response> listener = new PlainActionFuture<>(); |
| 511 | + ActionTestUtils.execute(new Action("internal:testAction", transportService, clusterService, threadPool) { |
| 512 | + @Override |
| 513 | + protected ClusterBlockException checkBlock(Request request, ClusterState state) { |
| 514 | + Set<ClusterBlock> blocks = state.blocks().global(); |
| 515 | + return blocks.isEmpty() ? null : new ClusterBlockException(blocks); |
| 516 | + } |
| 517 | + }, task, request, listener); |
| 518 | + |
| 519 | + final int genericThreads = threadPool.info(ThreadPool.Names.GENERIC).getMax(); |
| 520 | + final EsThreadPoolExecutor executor = (EsThreadPoolExecutor) threadPool.executor(ThreadPool.Names.GENERIC); |
| 521 | + final CyclicBarrier barrier = new CyclicBarrier(genericThreads + 1); |
| 522 | + final CountDownLatch latch = new CountDownLatch(1); |
| 523 | + |
| 524 | + if (cancelBeforeStart == false) { |
| 525 | + assertThat(listener.isDone(), equalTo(false)); |
| 526 | + |
| 527 | + taskManager.cancel(task, "", () -> {}); |
| 528 | + assertThat(task.isCancelled(), equalTo(true)); |
| 529 | + |
| 530 | + // Now that the task is cancelled, let the request to be executed |
| 531 | + final ClusterState.Builder newStateBuilder = ClusterState.builder(stateWithBlock); |
| 532 | + |
| 533 | + // Either unblock the cluster state or just do an unrelated cluster state change that will check |
| 534 | + // if the task has been cancelled |
| 535 | + if (randomBoolean()) { |
| 536 | + newStateBuilder.blocks(ClusterBlocks.EMPTY_CLUSTER_BLOCK); |
| 537 | + } else { |
| 538 | + newStateBuilder.incrementVersion(); |
| 539 | + } |
| 540 | + setState(clusterService, newStateBuilder.build()); |
| 541 | + } |
| 542 | + expectThrows(CancellationException.class, listener::actionGet); |
| 543 | + } |
| 544 | + |
| 545 | + public void testTaskCancellationOnceActionItIsDispatchedToMaster() throws Exception { |
| 546 | + TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet()); |
| 547 | + |
| 548 | + Request request = new Request(); |
| 549 | + final CancellableTask task = (CancellableTask) taskManager.register("type", "internal:testAction", request); |
| 550 | + |
| 551 | + // Block all the threads of the executor in which the master operation will be dispatched to |
| 552 | + // ensure that the master operation won't be executed until the threads are released |
| 553 | + final String executorName = ThreadPool.Names.GENERIC; |
| 554 | + final Runnable releaseBlockedThreads = blockAllThreads(executorName); |
| 555 | + |
| 556 | + PlainActionFuture<Response> listener = new PlainActionFuture<>(); |
| 557 | + ActionTestUtils.execute(new Action("internal:testAction", transportService, clusterService, threadPool, executorName), |
| 558 | + task, |
| 559 | + request, |
| 560 | + listener |
| 561 | + ); |
| 562 | + |
| 563 | + taskManager.cancel(task, "", () -> {}); |
| 564 | + assertThat(task.isCancelled(), equalTo(true)); |
| 565 | + |
| 566 | + releaseBlockedThreads.run(); |
| 567 | + |
| 568 | + expectThrows(CancellationException.class, listener::actionGet); |
| 569 | + } |
| 570 | + |
| 571 | + private Runnable blockAllThreads(String executorName) throws Exception { |
| 572 | + final int numberOfThreads = threadPool.info(executorName).getMax(); |
| 573 | + final EsThreadPoolExecutor executor = (EsThreadPoolExecutor) threadPool.executor(executorName); |
| 574 | + final CyclicBarrier barrier = new CyclicBarrier(numberOfThreads + 1); |
| 575 | + final CountDownLatch latch = new CountDownLatch(1); |
| 576 | + for (int i = 0; i < numberOfThreads; i++) { |
| 577 | + executor.submit(() -> { |
| 578 | + try { |
| 579 | + barrier.await(); |
| 580 | + latch.await(); |
| 581 | + } catch (Exception e) { |
| 582 | + throw new AssertionError(e); |
| 583 | + } |
| 584 | + }); |
| 585 | + } |
| 586 | + barrier.await(); |
| 587 | + return latch::countDown; |
| 588 | + } |
463 | 589 | } |
0 commit comments