Skip to content

Commit 143620d

Browse files
committed
Track ban parent each channel
1 parent 1d306ef commit 143620d

File tree

3 files changed

+109
-28
lines changed

3 files changed

+109
-28
lines changed

server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ public void messageReceived(final BanParentTaskRequest request, final TransportC
240240
if (request.ban) {
241241
logger.debug("Received ban for the parent [{}] on the node [{}], reason: [{}]", request.parentTaskId,
242242
localNodeId(), request.reason);
243-
final List<CancellableTask> childTasks = taskManager.setBan(request.parentTaskId, request.reason);
243+
final List<CancellableTask> childTasks = taskManager.setBan(request.parentTaskId, request.reason, channel);
244244
final GroupedActionListener<Void> listener = new GroupedActionListener<>(
245245
new ChannelActionListener<>(channel, BAN_PARENT_ACTION_NAME, request).map(r -> TransportResponse.Empty.INSTANCE),
246246
childTasks.size() + 1);

server/src/main/java/org/elasticsearch/tasks/TaskManager.java

Lines changed: 106 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.elasticsearch.ElasticsearchException;
3030
import org.elasticsearch.ElasticsearchTimeoutException;
3131
import org.elasticsearch.ExceptionsHelper;
32+
import org.elasticsearch.Version;
3233
import org.elasticsearch.action.ActionListener;
3334
import org.elasticsearch.action.ActionRequest;
3435
import org.elasticsearch.action.ActionResponse;
@@ -47,14 +48,18 @@
4748
import org.elasticsearch.common.util.concurrent.ConcurrentMapLong;
4849
import org.elasticsearch.common.util.concurrent.ThreadContext;
4950
import org.elasticsearch.threadpool.ThreadPool;
51+
import org.elasticsearch.transport.TaskTransportChannel;
5052
import org.elasticsearch.transport.TcpChannel;
53+
import org.elasticsearch.transport.TcpTransportChannel;
5154
import org.elasticsearch.transport.Transport;
55+
import org.elasticsearch.transport.TransportChannel;
5256

5357
import java.io.IOException;
5458
import java.util.ArrayList;
5559
import java.util.Collection;
5660
import java.util.Collections;
5761
import java.util.HashMap;
62+
import java.util.HashSet;
5863
import java.util.Iterator;
5964
import java.util.List;
6065
import java.util.Map;
@@ -65,6 +70,7 @@
6570
import java.util.concurrent.atomic.AtomicBoolean;
6671
import java.util.concurrent.atomic.AtomicLong;
6772
import java.util.function.BiConsumer;
73+
import java.util.function.Consumer;
6874
import java.util.stream.Collectors;
6975
import java.util.stream.StreamSupport;
7076

@@ -91,7 +97,7 @@ public class TaskManager implements ClusterStateApplier {
9197

9298
private final AtomicLong taskIdGenerator = new AtomicLong();
9399

94-
private final Map<TaskId, String> banedParents = new ConcurrentHashMap<>();
100+
private final Map<TaskId, Ban> bannedParents = new ConcurrentHashMap<>();
95101

96102
private TaskResultsService taskResultsService;
97103

@@ -196,12 +202,12 @@ private void registerCancellableTask(Task task) {
196202
assert oldHolder == null;
197203
// Check if this task was banned before we start it. The empty check is used to avoid
198204
// computing the hash code of the parent taskId as most of the time bannedParents is empty.
199-
if (task.getParentTaskId().isSet() && banedParents.isEmpty() == false) {
200-
String reason = banedParents.get(task.getParentTaskId());
201-
if (reason != null) {
205+
if (task.getParentTaskId().isSet() && bannedParents.isEmpty() == false) {
206+
final Ban ban = bannedParents.get(task.getParentTaskId());
207+
if (ban != null) {
202208
try {
203-
holder.cancel(reason);
204-
throw new TaskCancelledException("Task cancelled before it started: " + reason);
209+
holder.cancel(ban.reason);
210+
throw new TaskCancelledException("Task cancelled before it started: " + ban.reason);
205211
} finally {
206212
// let's clean up the registration
207213
unregister(task);
@@ -381,7 +387,17 @@ public CancellableTask getCancellableTask(long id) {
381387
* Will be used in task manager stats and for debugging.
382388
*/
383389
public int getBanCount() {
384-
return banedParents.size();
390+
// The map is keyed by parent task id, so its size is the number of banned parents.
return bannedParents.size();
391+
}
392+
393+
/**
 * Extracts the underlying {@link TcpChannel} from a transport channel.
 * A {@link TaskTransportChannel} is unwrapped (recursively) through its inner channel;
 * a {@link TcpTransportChannel} yields its tcp channel directly.
 *
 * @param channel the transport channel to inspect
 * @return the underlying tcp channel, or {@code null} if the channel is neither of the two
 *         types above (callers must handle the {@code null} case)
 */
static TcpChannel getTcpChannel(TransportChannel channel) {
394+
if (channel instanceof TaskTransportChannel) {
395+
// Unwrap the wrapper and retry on the inner channel.
return getTcpChannel(((TaskTransportChannel) channel).getChannel());
396+
}
397+
if (channel instanceof TcpTransportChannel) {
398+
return ((TcpTransportChannel) channel).getChannel();
399+
}
400+
// No tcp channel can be derived from this kind of transport channel.
return null;
385401
}
386402

387403
/**
@@ -390,14 +406,27 @@ public int getBanCount() {
390406
* This method is called when a parent task that has children is cancelled.
391407
* @return a list of pending cancellable child tasks
392408
*/
393-
public List<CancellableTask> setBan(TaskId parentTaskId, String reason) {
409+
public List<CancellableTask> setBan(TaskId parentTaskId, String reason, TransportChannel channel) {
394410
logger.trace("setting ban for the parent task {} {}", parentTaskId, reason);
395-
396-
// Set the ban first, so the newly created tasks cannot be registered
397-
synchronized (banedParents) {
398-
if (lastDiscoveryNodes.nodeExists(parentTaskId.getNodeId())) {
399-
// Only set the ban if the node is the part of the cluster
400-
banedParents.put(parentTaskId, reason);
411+
if (channel.getVersion().onOrAfter(Version.V_8_0_0)) {
412+
// If the channel does not have a tcp channel, would we be in trouble here?
413+
final Ban ban = bannedParents.computeIfAbsent(parentTaskId, k -> new Ban(reason, true));
414+
assert ban.perChannel : "not a ban per channel";
415+
final TcpChannel tcpChannel = getTcpChannel(channel);
416+
if (tcpChannel != null) {
417+
startTrackingChannel(tcpChannel, ban::registerChannel);
418+
} else {
419+
// register a dummy channel to the ban so that the ban is not removed in this situation: the local channel and a tcp channel
420+
// both register a ban for the same parent task id and then the tcp channel gets disconnected.
421+
ban.registerChannel(new ChannelPendingTaskTracker());
422+
}
423+
} else {
424+
synchronized (bannedParents) {
425+
if (lastDiscoveryNodes.nodeExists(parentTaskId.getNodeId())) {
426+
// Only set the ban if the node is part of the cluster
427+
final Ban existing = bannedParents.put(parentTaskId, new Ban(reason, false));
428+
assert existing == null || existing.perChannel == false : "not a ban per node";
429+
}
401430
}
402431
}
403432
return cancellableTasks.values().stream()
@@ -413,12 +442,47 @@ public List<CancellableTask> setBan(TaskId parentTaskId, String reason) {
413442
*/
414443
public void removeBan(TaskId parentTaskId) {
415444
logger.trace("removing ban for the parent task {}", parentTaskId);
416-
banedParents.remove(parentTaskId);
445+
// Unconditional removal: the entry is dropped without checking whether any channels are still registered on the ban.
bannedParents.remove(parentTaskId);
417446
}
418447

419448
// for testing
420449
public Set<TaskId> getBannedTaskIds() {
421-
return Collections.unmodifiableSet(banedParents.keySet());
450+
// Read-only wrapper around the key set of the ban map; callers cannot modify the bans through it.
return Collections.unmodifiableSet(bannedParents.keySet());
451+
}
452+
453+
/**
 * A ban placed on a parent task. It carries the reason for the ban and, for per-channel bans,
 * the set of channel trackers that registered it. {@code setBan} creates per-channel bans when the
 * requesting channel's version is on or after 8.0.0 and per-node bans otherwise.
 */
private static class Ban {
454+
// Reason passed along when a child task is cancelled because of this ban.
final String reason;
455+
// true for a ban tracked per channel; false for a per-node ban.
final boolean perChannel;
456+
// Trackers of the channels that registered this ban; a mutable set for per-channel bans, an immutable empty set otherwise.
final Set<ChannelPendingTaskTracker> channels;
457+
458+
Ban(String reason, boolean perChannel) {
459+
this.reason = reason;
460+
this.perChannel = perChannel;
461+
if (perChannel) {
462+
this.channels = new HashSet<>();
463+
} else {
464+
// Per-node bans never register channels, so an immutable empty set is enough.
this.channels = Set.of();
465+
}
466+
}
467+
468+
/**
 * Adds a channel tracker to this ban. Only valid for per-channel bans.
 *
 * @return {@code true} if the tracker was not registered before
 */
synchronized boolean registerChannel(ChannelPendingTaskTracker channel) {
469+
assert perChannel : "not a ban per channel";
470+
return channels.add(channel);
471+
}
472+
473+
/**
 * Removes a channel tracker from this ban. Only valid for per-channel bans.
 *
 * @return {@code true} if the tracker was registered on this ban
 */
synchronized boolean unregisterChannel(ChannelPendingTaskTracker channel) {
474+
assert perChannel : "not a ban per channel";
475+
return channels.remove(channel);
476+
}
477+
478+
/**
 * @return the number of channel trackers currently registered on this ban (always 0 for per-node bans)
 */
synchronized int registeredChannels() {
479+
return channels.size();
480+
}
481+
482+
// NOTE(review): unlike the accessors above, this reads `channels` without holding the monitor — confirm that is acceptable.
@Override
483+
public String toString() {
484+
return "Ban{" + "reason='" + reason + '\'' + ", perChannel=" + perChannel + ", channels=" + channels + '}';
485+
}
422486
}
423487

424488
/**
@@ -442,15 +506,15 @@ public Collection<Transport.Connection> startBanOnChildTasks(long taskId, Runnab
442506
public void applyClusterState(ClusterChangedEvent event) {
443507
lastDiscoveryNodes = event.state().getNodes();
444508
if (event.nodesRemoved()) {
445-
synchronized (banedParents) {
509+
synchronized (bannedParents) {
446510
lastDiscoveryNodes = event.state().getNodes();
447511
// Remove all bans that were registered by nodes that are no longer in the cluster state
448-
Iterator<TaskId> banIterator = banedParents.keySet().iterator();
512+
final Iterator<Map.Entry<TaskId, Ban>> banIterator = bannedParents.entrySet().iterator();
449513
while (banIterator.hasNext()) {
450-
TaskId taskId = banIterator.next();
451-
if (lastDiscoveryNodes.nodeExists(taskId.getNodeId()) == false) {
452-
logger.debug("Removing ban for the parent [{}] on the node [{}], reason: the parent node is gone", taskId,
453-
event.state().getNodes().getLocalNode());
514+
final Map.Entry<TaskId, Ban> ban = banIterator.next();
515+
if (ban.getValue().registeredChannels() == 0 && lastDiscoveryNodes.nodeExists(ban.getKey().getNodeId()) == false) {
516+
logger.debug("Removing ban for the parent [{}] on the node [{}], reason: the parent node is gone",
517+
ban.getKey(), event.state().getNodes().getLocalNode());
454518
banIterator.remove();
455519
}
456520
}
@@ -617,25 +681,30 @@ Set<Transport.Connection> startBan(Runnable onChildTasksCompleted) {
617681
*/
618682
public Releasable startTrackingCancellableChannelTask(TcpChannel channel, CancellableTask task) {
619683
assert cancellableTasks.containsKey(task.getId()) : "task [" + task.getId() + "] is not registered yet";
684+
// Add the task to the tracker of this channel; the tracker is looked up or created by startTrackingChannel.
final ChannelPendingTaskTracker tracker = startTrackingChannel(channel, trackerChannel -> trackerChannel.addTask(task));
685+
// Releasing the returned handle removes the task from that same tracker.
return () -> tracker.removeTask(task);
686+
}
687+
688+
/**
 * Looks up, or creates, the tracker associated with the given channel, runs {@code onRegister} on it,
 * and makes sure a close listener is attached to the channel exactly once per tracker.
 *
 * @param channel    the channel to track
 * @param onRegister callback invoked with the tracker from inside the map's {@code compute} call
 * @return the tracker associated with the channel
 */
private ChannelPendingTaskTracker startTrackingChannel(TcpChannel channel, Consumer<ChannelPendingTaskTracker> onRegister) {
620689
final ChannelPendingTaskTracker tracker = channelPendingTaskTrackers.compute(channel, (k, curr) -> {
621690
if (curr == null) {
622691
curr = new ChannelPendingTaskTracker();
623692
}
624-
curr.addTask(task);
693+
// Runs inside compute(); presumably so the registration cannot interleave with the close listener's remove — TODO confirm.
onRegister.accept(curr);
625694
return curr;
626695
});
627696
// The compare-and-set ensures only the first caller for this tracker installs the close listener.
if (tracker.registered.compareAndSet(false, true)) {
628697
channel.addCloseListener(ActionListener.wrap(
629698
r -> {
630699
// On close, drop the tracker from the map and hand it over for close handling.
final ChannelPendingTaskTracker removedTracker = channelPendingTaskTrackers.remove(channel);
631700
assert removedTracker == tracker;
632-
cancelTasksOnChannelClosed(tracker.drainTasks());
701+
onChannelClosed(tracker);
633702
},
634703
e -> {
635704
// The failure path of the close listener is not expected to be invoked.
assert false : new AssertionError("must not be here", e);
636705
}));
637706
}
638-
return () -> tracker.removeTask(task);
707+
return tracker;
639708
}
640709

641710
// for testing
@@ -676,7 +745,8 @@ void removeTask(CancellableTask task) {
676745
}
677746
}
678747

679-
private void cancelTasksOnChannelClosed(Set<CancellableTask> tasks) {
748+
private void onChannelClosed(ChannelPendingTaskTracker channel) {
749+
final Set<CancellableTask> tasks = channel.drainTasks();
680750
if (tasks.isEmpty() == false) {
681751
threadPool.generic().execute(new AbstractRunnable() {
682752
@Override
@@ -692,6 +762,16 @@ protected void doRun() {
692762
}
693763
});
694764
}
765+
766+
// Unregister the closing channel and remove bans that have no registered channels left
767+
final Iterator<Map.Entry<TaskId, Ban>> iterator = bannedParents.entrySet().iterator();
768+
while (iterator.hasNext()) {
769+
final Map.Entry<TaskId, Ban> entry = iterator.next();
770+
final Ban ban = entry.getValue();
771+
if (ban.perChannel && ban.unregisterChannel(channel) && entry.getValue().registeredChannels() == 0) {
772+
removeBan(entry.getKey());
773+
}
774+
}
695775
}
696776

697777
public void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {

server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
import org.elasticsearch.tasks.TaskInfo;
4242
import org.elasticsearch.tasks.TaskManager;
4343
import org.elasticsearch.threadpool.ThreadPool;
44+
import org.elasticsearch.transport.TestTransportChannel;
4445
import org.elasticsearch.transport.TransportRequest;
4546
import org.elasticsearch.transport.TransportService;
4647

@@ -356,7 +357,7 @@ public void testRegisterAndExecuteChildTaskWhileParentTaskIsBeingCanceled() thro
356357
CancellableNodesRequest parentRequest = new CancellableNodesRequest("parent");
357358
final Task parentTask = taskManager.register("test", "test", parentRequest);
358359
final TaskId parentTaskId = parentTask.taskInfo(testNodes[0].getNodeId(), false).getTaskId();
359-
taskManager.setBan(new TaskId(testNodes[0].getNodeId(), parentTask.getId()), "test");
360+
taskManager.setBan(new TaskId(testNodes[0].getNodeId(), parentTask.getId()), "test", new TestTransportChannel(null));
360361
CancellableNodesRequest childRequest = new CancellableNodesRequest("child");
361362
childRequest.setParentTask(parentTaskId);
362363
CancellableTestNodesAction testAction = new CancellableTestNodesAction("internal:testAction", threadPool, testNodes[1]

0 commit comments

Comments
 (0)