Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.RemoteClusterService;
import org.elasticsearch.transport.Transport;

import java.util.Map;
import java.util.function.Supplier;
Expand All @@ -54,6 +55,7 @@ public class NodeClient extends AbstractClient {
* {@link #executeLocally(ActionType, ActionRequest, TaskListener)}.
*/
private Supplier<String> localNodeId;
private Transport.Connection localConnection;
private RemoteClusterService remoteClusterService;
private NamedWriteableRegistry namedWriteableRegistry;

Expand All @@ -63,10 +65,12 @@ public NodeClient(Settings settings, ThreadPool threadPool) {

@SuppressWarnings("rawtypes")
public void initialize(Map<ActionType, TransportAction> actions, TaskManager taskManager, Supplier<String> localNodeId,
RemoteClusterService remoteClusterService, NamedWriteableRegistry namedWriteableRegistry) {
Transport.Connection localConnection, RemoteClusterService remoteClusterService,
NamedWriteableRegistry namedWriteableRegistry) {
this.actions = actions;
this.taskManager = taskManager;
this.localNodeId = localNodeId;
this.localConnection = localConnection;
this.remoteClusterService = remoteClusterService;
this.namedWriteableRegistry = namedWriteableRegistry;
}
Expand Down Expand Up @@ -101,7 +105,7 @@ void doExecute(ActionType<Response> action, Request request, ActionListener<Resp
public < Request extends ActionRequest,
Response extends ActionResponse
> Task executeLocally(ActionType<Response> action, Request request, ActionListener<Response> listener) {
return taskManager.registerAndExecute("transport", transportAction(action), request,
return taskManager.registerAndExecute("transport", transportAction(action), request, localConnection,
(t, r) -> {
try {
listener.onResponse(r);
Expand Down Expand Up @@ -129,7 +133,7 @@ > Task executeLocally(ActionType<Response> action, Request request, ActionListen
public < Request extends ActionRequest,
Response extends ActionResponse
> Task executeLocally(ActionType<Response> action, Request request, TaskListener<Response> listener) {
return taskManager.registerAndExecute("transport", transportAction(action), request,
return taskManager.registerAndExecute("transport", transportAction(action), request, localConnection,
listener::onResponse, listener::onFailure);
}

Expand Down
11 changes: 6 additions & 5 deletions server/src/main/java/org/elasticsearch/node/Node.java
Original file line number Diff line number Diff line change
Expand Up @@ -702,11 +702,12 @@ protected Node(final Environment initialEnvironment,
resourcesToClose.addAll(pluginLifecycleComponents);
resourcesToClose.add(injector.getInstance(PeerRecoverySourceService.class));
this.pluginLifecycleComponents = Collections.unmodifiableList(pluginLifecycleComponents);
client.initialize(injector.getInstance(new Key<Map<ActionType, TransportAction>>() {}), transportService.getTaskManager(),
() -> clusterService.localNode().getId(), transportService.getRemoteClusterService(),
namedWriteableRegistry

);
client.initialize(injector.getInstance(new Key<Map<ActionType, TransportAction>>() {}),
transportService.getTaskManager(),
() -> clusterService.localNode().getId(),
transportService.getLocalNodeConnection(),
transportService.getRemoteClusterService(),
namedWriteableRegistry);
this.namedWriteableRegistry = namedWriteableRegistry;

logger.debug("initializing HTTP handlers ...");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,31 @@
import org.elasticsearch.action.StepListener;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.EmptyTransportResponseHandler;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestDeduplicator;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.Objects;

public class TaskCancellationService {
public static final String BAN_PARENT_ACTION_NAME = "internal:admin/tasks/ban";
private static final Logger logger = LogManager.getLogger(TaskCancellationService.class);
private final TransportService transportService;
private final TaskManager taskManager;
private final TransportRequestDeduplicator<CancelRequest> deduplicator = new TransportRequestDeduplicator<>();

public TaskCancellationService(TransportService transportService) {
this.transportService = transportService;
Expand All @@ -61,35 +65,63 @@ private String localNodeId() {
return transportService.getLocalNode().getId();
}

void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
private static class CancelRequest {
final CancellableTask task;
final boolean waitForCompletion;

CancelRequest(CancellableTask task, boolean waitForCompletion) {
this.task = task;
this.waitForCompletion = waitForCompletion;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final CancelRequest that = (CancelRequest) o;
return waitForCompletion == that.waitForCompletion && Objects.equals(task, that.task);
}

@Override
public int hashCode() {
return Objects.hash(task, waitForCompletion);
}
}

void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> finalListener) {
deduplicator.executeOnce(new CancelRequest(task, waitForCompletion), finalListener,
(r, listener) -> doCancelTaskAndDescendants(task, reason, waitForCompletion, listener));
}

void doCancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
final TaskId taskId = task.taskInfo(localNodeId(), false).getTaskId();
if (task.shouldCancelChildrenOnCancellation()) {
logger.trace("cancelling task [{}] and its descendants", taskId);
StepListener<Void> completedListener = new StepListener<>();
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(completedListener.map(r -> null), 3);
Collection<DiscoveryNode> childrenNodes = taskManager.startBanOnChildrenNodes(task.getId(), () -> {
Collection<Transport.Connection> childConnections = taskManager.startBanOnChildTasks(task.getId(), () -> {
logger.trace("child tasks of parent [{}] are completed", taskId);
groupedListener.onResponse(null);
});
taskManager.cancel(task, reason, () -> {
logger.trace("task [{}] is cancelled", taskId);
groupedListener.onResponse(null);
});
StepListener<Void> banOnNodesListener = new StepListener<>();
setBanOnNodes(reason, waitForCompletion, task, childrenNodes, banOnNodesListener);
banOnNodesListener.whenComplete(groupedListener::onResponse, groupedListener::onFailure);
StepListener<Void> setBanListener = new StepListener<>();
setBanOnChildConnections(reason, waitForCompletion, task, childConnections, setBanListener);
setBanListener.whenComplete(groupedListener::onResponse, groupedListener::onFailure);
// If we start unbanning when the last child task completed and that child task executed with a specific user, then unban
// requests are denied because internal requests can't run with a user. We need to remove bans with the current thread context.
final Runnable removeBansRunnable = transportService.getThreadPool().getThreadContext()
.preserveContext(() -> removeBanOnNodes(task, childrenNodes));
// We remove bans after all child tasks are completed although in theory we can do it on a per-node basis.
.preserveContext(() -> removeBanOnChildConnections(task, childConnections));
// We remove bans after all child tasks are completed although in theory we can do it on a per-connection basis.
completedListener.whenComplete(r -> removeBansRunnable.run(), e -> removeBansRunnable.run());
// if wait_for_completion is true, then only return when (1) bans are placed on child nodes, (2) child tasks are
// completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child nodes.
// if wait_for_completion is true, then only return when (1) bans are placed on child connections, (2) child tasks are
// completed or failed, (3) the main task is cancelled. Otherwise, return after bans are placed on child connections.
if (waitForCompletion) {
completedListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
} else {
banOnNodesListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
setBanListener.whenComplete(r -> listener.onResponse(null), listener::onFailure);
}
} else {
logger.trace("task [{}] doesn't have any children that should be cancelled", taskId);
Expand All @@ -102,47 +134,48 @@ void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitF
}
}

private void setBanOnNodes(String reason, boolean waitForCompletion, CancellableTask task,
Collection<DiscoveryNode> childNodes, ActionListener<Void> listener) {
if (childNodes.isEmpty()) {
private void setBanOnChildConnections(String reason, boolean waitForCompletion, CancellableTask task,
Collection<Transport.Connection> childConnections, ActionListener<Void> listener) {
if (childConnections.isEmpty()) {
listener.onResponse(null);
return;
}
final TaskId taskId = new TaskId(localNodeId(), task.getId());
logger.trace("cancelling child tasks of [{}] on child nodes {}", taskId, childNodes);
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(listener.map(r -> null), childNodes.size());
logger.trace("cancelling child tasks of [{}] on child connections {}", taskId, childConnections);
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(listener.map(r -> null), childConnections.size());
final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(taskId, reason, waitForCompletion);
for (DiscoveryNode node : childNodes) {
transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, banRequest,
for (Transport.Connection connection : childConnections) {
transportService.sendRequest(connection, BAN_PARENT_ACTION_NAME, banRequest, TransportRequestOptions.EMPTY,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override
public void handleResponse(TransportResponse.Empty response) {
logger.trace("sent ban for tasks with the parent [{}] to the node [{}]", taskId, node);
logger.trace("sent ban for tasks with the parent [{}] for connection [{}]", taskId, connection);
groupedListener.onResponse(null);
}

@Override
public void handleException(TransportException exp) {
assert ExceptionsHelper.unwrapCause(exp) instanceof ElasticsearchSecurityException == false;
logger.warn("Cannot send ban for tasks with the parent [{}] to the node [{}]", taskId, node);
logger.warn("Cannot send ban for tasks with the parent [{}] for connection [{}]", taskId, connection);
groupedListener.onFailure(exp);
}
});
}
}

private void removeBanOnNodes(CancellableTask task, Collection<DiscoveryNode> childNodes) {
private void removeBanOnChildConnections(CancellableTask task, Collection<Transport.Connection> childConnections) {
final BanParentTaskRequest request =
BanParentTaskRequest.createRemoveBanParentTaskRequest(new TaskId(localNodeId(), task.getId()));
for (DiscoveryNode node : childNodes) {
logger.trace("Sending remove ban for tasks with the parent [{}] to the node [{}]", request.parentTaskId, node);
transportService.sendRequest(node, BAN_PARENT_ACTION_NAME, request, new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override
public void handleException(TransportException exp) {
assert ExceptionsHelper.unwrapCause(exp) instanceof ElasticsearchSecurityException == false;
logger.info("failed to remove the parent ban for task {} on node {}", request.parentTaskId, node);
}
});
for (Transport.Connection connection : childConnections) {
logger.trace("Sending remove ban for tasks with the parent [{}] for connection [{}]", request.parentTaskId, connection);
transportService.sendRequest(connection, BAN_PARENT_ACTION_NAME, request, TransportRequestOptions.EMPTY,
new EmptyTransportResponseHandler(ThreadPool.Names.SAME) {
@Override
public void handleException(TransportException exp) {
assert ExceptionsHelper.unwrapCause(exp) instanceof ElasticsearchSecurityException == false;
logger.info("failed to remove the parent ban for task {} for connection {}", request.parentTaskId, connection);
}
});
}
}

Expand Down
Loading