Skip to content

Commit b4f1851

Browse files
authored
Add support for task cancellation to TransportMasterNodeAction (#72157)
1 parent 351a824 commit b4f1851

File tree

2 files changed

+148
-3
lines changed

2 files changed

+148
-3
lines changed

server/src/main/java/org/elasticsearch/action/support/master/TransportMasterNodeAction.java

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,15 @@
3131
import org.elasticsearch.common.unit.TimeValue;
3232
import org.elasticsearch.discovery.MasterNotDiscoveredException;
3333
import org.elasticsearch.node.NodeClosedException;
34+
import org.elasticsearch.tasks.CancellableTask;
3435
import org.elasticsearch.tasks.Task;
3536
import org.elasticsearch.threadpool.ThreadPool;
3637
import org.elasticsearch.transport.ConnectTransportException;
3738
import org.elasticsearch.transport.RemoteTransportException;
3839
import org.elasticsearch.transport.TransportException;
3940
import org.elasticsearch.transport.TransportService;
4041

42+
import java.util.concurrent.CancellationException;
4143
import java.util.function.Predicate;
4244

4345
/**
@@ -82,6 +84,15 @@ protected TransportMasterNodeAction(String actionName, boolean canTripCircuitBre
8284
protected abstract void masterOperation(Task task, Request request, ClusterState state,
8385
ActionListener<Response> listener) throws Exception;
8486

87+
private void executeMasterOperation(Task task, Request request, ClusterState state,
88+
ActionListener<Response> listener) throws Exception {
89+
if (task instanceof CancellableTask && ((CancellableTask) task).isCancelled()) {
90+
throw new CancellationException("Task was cancelled");
91+
}
92+
93+
masterOperation(task, request, state, listener);
94+
}
95+
8596
protected boolean localExecute(Request request) {
8697
return false;
8798
}
@@ -114,6 +125,10 @@ class AsyncSingleAction {
114125
}
115126

116127
protected void doStart(ClusterState clusterState) {
128+
if (isTaskCancelled()) {
129+
listener.onFailure(new CancellationException("Task was cancelled"));
130+
return;
131+
}
117132
try {
118133
final DiscoveryNodes nodes = clusterState.nodes();
119134
if (nodes.isLocalNodeElectedMaster() || localExecute(request)) {
@@ -148,7 +163,7 @@ protected void doStart(ClusterState clusterState) {
148163
}
149164
});
150165
threadPool.executor(executor)
151-
.execute(ActionRunnable.wrap(delegate, l -> masterOperation(task, request, clusterState, l)));
166+
.execute(ActionRunnable.wrap(delegate, l -> executeMasterOperation(task, request, clusterState, l)));
152167
}
153168
} else {
154169
if (nodes.getMasterNode() == null) {
@@ -218,7 +233,11 @@ public void onTimeout(TimeValue timeout) {
218233
actionName, timeout), failure);
219234
listener.onFailure(new MasterNotDiscoveredException(failure));
220235
}
221-
}, statePredicate);
236+
}, clusterState -> isTaskCancelled() || statePredicate.test(clusterState));
237+
}
238+
239+
private boolean isTaskCancelled() {
240+
return task instanceof CancellableTask && ((CancellableTask) task).isCancelled();
222241
}
223242
}
224243
}

server/src/test/java/org/elasticsearch/action/support/master/TransportMasterNodeActionTests.java

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,17 @@
3131
import org.elasticsearch.cluster.service.ClusterService;
3232
import org.elasticsearch.common.io.stream.StreamInput;
3333
import org.elasticsearch.common.io.stream.StreamOutput;
34+
import org.elasticsearch.common.settings.Settings;
3435
import org.elasticsearch.common.unit.TimeValue;
36+
import org.elasticsearch.common.util.concurrent.EsThreadPoolExecutor;
3537
import org.elasticsearch.discovery.MasterNotDiscoveredException;
3638
import org.elasticsearch.indices.TestIndexNameExpressionResolver;
3739
import org.elasticsearch.node.NodeClosedException;
3840
import org.elasticsearch.rest.RestStatus;
41+
import org.elasticsearch.tasks.CancellableTask;
3942
import org.elasticsearch.tasks.Task;
43+
import org.elasticsearch.tasks.TaskId;
44+
import org.elasticsearch.tasks.TaskManager;
4045
import org.elasticsearch.test.ESTestCase;
4146
import org.elasticsearch.test.transport.CapturingTransport;
4247
import org.elasticsearch.threadpool.TestThreadPool;
@@ -51,8 +56,12 @@
5156
import java.io.IOException;
5257
import java.util.Collections;
5358
import java.util.HashSet;
59+
import java.util.Map;
5460
import java.util.Objects;
5561
import java.util.Set;
62+
import java.util.concurrent.CancellationException;
63+
import java.util.concurrent.CountDownLatch;
64+
import java.util.concurrent.CyclicBarrier;
5665
import java.util.concurrent.ExecutionException;
5766
import java.util.concurrent.TimeUnit;
5867

@@ -126,6 +135,11 @@ public static class Request extends MasterNodeRequest<Request> {
126135
public ActionRequestValidationException validate() {
127136
return null;
128137
}
138+
139+
@Override
140+
public Task createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
141+
return new CancellableTask(id, type, action, "", parentTaskId, headers);
142+
}
129143
}
130144

131145
class Response extends ActionResponse {
@@ -160,12 +174,18 @@ public void writeTo(StreamOutput out) throws IOException {
160174
class Action extends TransportMasterNodeAction<Request, Response> {
161175
Action(String actionName, TransportService transportService, ClusterService clusterService,
162176
ThreadPool threadPool) {
177+
this(actionName, transportService, clusterService, threadPool, ThreadPool.Names.SAME);
178+
}
179+
180+
Action(String actionName, TransportService transportService, ClusterService clusterService,
181+
ThreadPool threadPool, String executor) {
163182
super(actionName, transportService, clusterService, threadPool,
164183
new ActionFilters(new HashSet<>()), Request::new,
165184
TestIndexNameExpressionResolver.newInstance(), Response::new,
166-
ThreadPool.Names.SAME);
185+
executor);
167186
}
168187

188+
169189
@Override
170190
protected void doExecute(Task task, final Request request, ActionListener<Response> listener) {
171191
// remove unneeded threading by wrapping listener with SAME to prevent super.doExecute from wrapping it with LISTENER
@@ -460,4 +480,110 @@ protected void masterOperation(Task task, Request request, ClusterState state,
460480
assertTrue(listener.isDone());
461481
assertThat(listener.get(), equalTo(response));
462482
}
483+
484+
public void testTaskCancellation() {
485+
ClusterBlock block = new ClusterBlock(1,
486+
"",
487+
true,
488+
true,
489+
false,
490+
randomFrom(RestStatus.values()),
491+
ClusterBlockLevel.ALL
492+
);
493+
ClusterState stateWithBlock = ClusterState.builder(ClusterStateCreationUtils.state(localNode, localNode, allNodes))
494+
.blocks(ClusterBlocks.builder().addGlobalBlock(block)).build();
495+
496+
// Update the cluster state with a block so the request waits until it's unblocked
497+
setState(clusterService, stateWithBlock);
498+
499+
TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
500+
501+
Request request = new Request();
502+
final CancellableTask task = (CancellableTask) taskManager.register("type", "internal:testAction", request);
503+
504+
boolean cancelBeforeStart = randomBoolean();
505+
if (cancelBeforeStart) {
506+
taskManager.cancel(task, "", () -> {});
507+
assertThat(task.isCancelled(), equalTo(true));
508+
}
509+
510+
PlainActionFuture<Response> listener = new PlainActionFuture<>();
511+
ActionTestUtils.execute(new Action("internal:testAction", transportService, clusterService, threadPool) {
512+
@Override
513+
protected ClusterBlockException checkBlock(Request request, ClusterState state) {
514+
Set<ClusterBlock> blocks = state.blocks().global();
515+
return blocks.isEmpty() ? null : new ClusterBlockException(blocks);
516+
}
517+
}, task, request, listener);
518+
519+
final int genericThreads = threadPool.info(ThreadPool.Names.GENERIC).getMax();
520+
final EsThreadPoolExecutor executor = (EsThreadPoolExecutor) threadPool.executor(ThreadPool.Names.GENERIC);
521+
final CyclicBarrier barrier = new CyclicBarrier(genericThreads + 1);
522+
final CountDownLatch latch = new CountDownLatch(1);
523+
524+
if (cancelBeforeStart == false) {
525+
assertThat(listener.isDone(), equalTo(false));
526+
527+
taskManager.cancel(task, "", () -> {});
528+
assertThat(task.isCancelled(), equalTo(true));
529+
530+
// Now that the task is cancelled, let the request to be executed
531+
final ClusterState.Builder newStateBuilder = ClusterState.builder(stateWithBlock);
532+
533+
// Either unblock the cluster state or just do an unrelated cluster state change that will check
534+
// if the task has been cancelled
535+
if (randomBoolean()) {
536+
newStateBuilder.blocks(ClusterBlocks.EMPTY_CLUSTER_BLOCK);
537+
} else {
538+
newStateBuilder.incrementVersion();
539+
}
540+
setState(clusterService, newStateBuilder.build());
541+
}
542+
expectThrows(CancellationException.class, listener::actionGet);
543+
}
544+
545+
public void testTaskCancellationOnceActionItIsDispatchedToMaster() throws Exception {
546+
TaskManager taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet());
547+
548+
Request request = new Request();
549+
final CancellableTask task = (CancellableTask) taskManager.register("type", "internal:testAction", request);
550+
551+
// Block all the threads of the executor in which the master operation will be dispatched to
552+
// ensure that the master operation won't be executed until the threads are released
553+
final String executorName = ThreadPool.Names.GENERIC;
554+
final Runnable releaseBlockedThreads = blockAllThreads(executorName);
555+
556+
PlainActionFuture<Response> listener = new PlainActionFuture<>();
557+
ActionTestUtils.execute(new Action("internal:testAction", transportService, clusterService, threadPool, executorName),
558+
task,
559+
request,
560+
listener
561+
);
562+
563+
taskManager.cancel(task, "", () -> {});
564+
assertThat(task.isCancelled(), equalTo(true));
565+
566+
releaseBlockedThreads.run();
567+
568+
expectThrows(CancellationException.class, listener::actionGet);
569+
}
570+
571+
private Runnable blockAllThreads(String executorName) throws Exception {
572+
final int numberOfThreads = threadPool.info(executorName).getMax();
573+
final EsThreadPoolExecutor executor = (EsThreadPoolExecutor) threadPool.executor(executorName);
574+
final CyclicBarrier barrier = new CyclicBarrier(numberOfThreads + 1);
575+
final CountDownLatch latch = new CountDownLatch(1);
576+
for (int i = 0; i < numberOfThreads; i++) {
577+
executor.submit(() -> {
578+
try {
579+
barrier.await();
580+
latch.await();
581+
} catch (Exception e) {
582+
throw new AssertionError(e);
583+
}
584+
});
585+
}
586+
barrier.await();
587+
return latch::countDown;
588+
}
463589
}

0 commit comments

Comments
 (0)