2929import org .elasticsearch .ElasticsearchException ;
3030import org .elasticsearch .ElasticsearchTimeoutException ;
3131import org .elasticsearch .ExceptionsHelper ;
32+ import org .elasticsearch .Version ;
3233import org .elasticsearch .action .ActionListener ;
3334import org .elasticsearch .action .ActionRequest ;
3435import org .elasticsearch .action .ActionResponse ;
4748import org .elasticsearch .common .util .concurrent .ConcurrentMapLong ;
4849import org .elasticsearch .common .util .concurrent .ThreadContext ;
4950import org .elasticsearch .threadpool .ThreadPool ;
51+ import org .elasticsearch .transport .TaskTransportChannel ;
5052import org .elasticsearch .transport .TcpChannel ;
53+ import org .elasticsearch .transport .TcpTransportChannel ;
5154import org .elasticsearch .transport .Transport ;
55+ import org .elasticsearch .transport .TransportChannel ;
5256
5357import java .io .IOException ;
5458import java .util .ArrayList ;
5559import java .util .Collection ;
5660import java .util .Collections ;
5761import java .util .HashMap ;
62+ import java .util .HashSet ;
5863import java .util .Iterator ;
5964import java .util .List ;
6065import java .util .Map ;
6570import java .util .concurrent .atomic .AtomicBoolean ;
6671import java .util .concurrent .atomic .AtomicLong ;
6772import java .util .function .BiConsumer ;
73+ import java .util .function .Consumer ;
6874import java .util .stream .Collectors ;
6975import java .util .stream .StreamSupport ;
7076
@@ -91,7 +97,7 @@ public class TaskManager implements ClusterStateApplier {
9197
9298 private final AtomicLong taskIdGenerator = new AtomicLong ();
9399
94- private final Map <TaskId , String > banedParents = new ConcurrentHashMap <>();
100+ private final Map <TaskId , Ban > bannedParents = new ConcurrentHashMap <>();
95101
96102 private TaskResultsService taskResultsService ;
97103
@@ -196,12 +202,12 @@ private void registerCancellableTask(Task task) {
196202 assert oldHolder == null ;
197203 // Check if this task was banned before we start it. The empty check is used to avoid
198204 // computing the hash code of the parent taskId as most of the time banedParents is empty.
199- if (task .getParentTaskId ().isSet () && banedParents .isEmpty () == false ) {
200- String reason = banedParents .get (task .getParentTaskId ());
201- if (reason != null ) {
205+ if (task .getParentTaskId ().isSet () && bannedParents .isEmpty () == false ) {
206+ final Ban ban = bannedParents .get (task .getParentTaskId ());
207+ if (ban != null ) {
202208 try {
203- holder .cancel (reason );
204- throw new TaskCancelledException ("Task cancelled before it started: " + reason );
209+ holder .cancel (ban . reason );
210+ throw new TaskCancelledException ("Task cancelled before it started: " + ban . reason );
205211 } finally {
206212 // let's clean up the registration
207213 unregister (task );
@@ -381,7 +387,17 @@ public CancellableTask getCancellableTask(long id) {
381387 * Will be used in task manager stats and for debugging.
382388 */
383389 public int getBanCount () {
384- return banedParents .size ();
390+ return bannedParents .size ();
391+ }
392+
393+ static TcpChannel getTcpChannel (TransportChannel channel ) {
394+ if (channel instanceof TaskTransportChannel ) {
395+ return getTcpChannel (((TaskTransportChannel ) channel ).getChannel ());
396+ }
397+ if (channel instanceof TcpTransportChannel ) {
398+ return ((TcpTransportChannel ) channel ).getChannel ();
399+ }
400+ return null ;
385401 }
386402
387403 /**
@@ -390,14 +406,27 @@ public int getBanCount() {
390406 * This method is called when a parent task that has children is cancelled.
391407 * @return a list of pending cancellable child tasks
392408 */
393- public List <CancellableTask > setBan (TaskId parentTaskId , String reason ) {
409+ public List <CancellableTask > setBan (TaskId parentTaskId , String reason , TransportChannel channel ) {
394410 logger .trace ("setting ban for the parent task {} {}" , parentTaskId , reason );
395-
396- // Set the ban first, so the newly created tasks cannot be registered
397- synchronized (banedParents ) {
398- if (lastDiscoveryNodes .nodeExists (parentTaskId .getNodeId ())) {
399- // Only set the ban if the node is the part of the cluster
400- banedParents .put (parentTaskId , reason );
411+ if (channel .getVersion ().onOrAfter (Version .V_8_0_0 )) {
412+ // If it does not have tcp channel, then we would be trouble here?
413+ final Ban ban = bannedParents .computeIfAbsent (parentTaskId , k -> new Ban (reason , true ));
414+ assert ban .perChannel : "not a ban per channel" ;
415+ final TcpChannel tcpChannel = getTcpChannel (channel );
416+ if (tcpChannel != null ) {
417+ startTrackingChannel (tcpChannel , ban ::registerChannel );
418+ } else {
419+ // register a dummy channel to the bane so that we remove it in this situation: the local channel and a tcp channel
420+ // register a ban for the same parent task id and the tcp channel gets disconnected.
421+ ban .registerChannel (new ChannelPendingTaskTracker ());
422+ }
423+ } else {
424+ synchronized (bannedParents ) {
425+ if (lastDiscoveryNodes .nodeExists (parentTaskId .getNodeId ())) {
426+ // Only set the ban if the node is the part of the cluster
427+ final Ban existing = bannedParents .put (parentTaskId , new Ban (reason , false ));
428+ assert existing == null || existing .perChannel == false : "not a ban per node" ;
429+ }
401430 }
402431 }
403432 return cancellableTasks .values ().stream ()
@@ -413,12 +442,47 @@ public List<CancellableTask> setBan(TaskId parentTaskId, String reason) {
413442 */
414443 public void removeBan (TaskId parentTaskId ) {
415444 logger .trace ("removing ban for the parent task {}" , parentTaskId );
416- banedParents .remove (parentTaskId );
445+ bannedParents .remove (parentTaskId );
417446 }
418447
419448 // for testing
420449 public Set <TaskId > getBannedTaskIds () {
421- return Collections .unmodifiableSet (banedParents .keySet ());
450+ return Collections .unmodifiableSet (bannedParents .keySet ());
451+ }
452+
453+ private static class Ban {
454+ final String reason ;
455+ final boolean perChannel ;
456+ final Set <ChannelPendingTaskTracker > channels ;
457+
458+ Ban (String reason , boolean perChannel ) {
459+ this .reason = reason ;
460+ this .perChannel = perChannel ;
461+ if (perChannel ) {
462+ this .channels = new HashSet <>();
463+ } else {
464+ this .channels = Set .of ();
465+ }
466+ }
467+
468+ synchronized boolean registerChannel (ChannelPendingTaskTracker channel ) {
469+ assert perChannel : "not a ban per channel" ;
470+ return channels .add (channel );
471+ }
472+
473+ synchronized boolean unregisterChannel (ChannelPendingTaskTracker channel ) {
474+ assert perChannel : "not a ban per channel" ;
475+ return channels .remove (channel );
476+ }
477+
478+ synchronized int registeredChannels () {
479+ return channels .size ();
480+ }
481+
482+ @ Override
483+ public String toString () {
484+ return "Ban{" + "reason='" + reason + '\'' + ", perChannel=" + perChannel + ", channels=" + channels + '}' ;
485+ }
422486 }
423487
424488 /**
@@ -442,15 +506,15 @@ public Collection<Transport.Connection> startBanOnChildTasks(long taskId, Runnab
442506 public void applyClusterState (ClusterChangedEvent event ) {
443507 lastDiscoveryNodes = event .state ().getNodes ();
444508 if (event .nodesRemoved ()) {
445- synchronized (banedParents ) {
509+ synchronized (bannedParents ) {
446510 lastDiscoveryNodes = event .state ().getNodes ();
447511 // Remove all bans that were registered by nodes that are no longer in the cluster state
448- Iterator <TaskId > banIterator = banedParents . keySet ().iterator ();
512+ final Iterator <Map . Entry < TaskId , Ban >> banIterator = bannedParents . entrySet ().iterator ();
449513 while (banIterator .hasNext ()) {
450- TaskId taskId = banIterator .next ();
451- if (lastDiscoveryNodes .nodeExists (taskId .getNodeId ()) == false ) {
452- logger .debug ("Removing ban for the parent [{}] on the node [{}], reason: the parent node is gone" , taskId ,
453- event .state ().getNodes ().getLocalNode ());
514+ final Map . Entry < TaskId , Ban > ban = banIterator .next ();
515+ if (ban . getValue (). registeredChannels () == 0 && lastDiscoveryNodes .nodeExists (ban . getKey () .getNodeId ()) == false ) {
516+ logger .debug ("Removing ban for the parent [{}] on the node [{}], reason: the parent node is gone" ,
517+ ban . getKey (), event .state ().getNodes ().getLocalNode ());
454518 banIterator .remove ();
455519 }
456520 }
@@ -617,25 +681,30 @@ Set<Transport.Connection> startBan(Runnable onChildTasksCompleted) {
617681 */
618682 public Releasable startTrackingCancellableChannelTask (TcpChannel channel , CancellableTask task ) {
619683 assert cancellableTasks .containsKey (task .getId ()) : "task [" + task .getId () + "] is not registered yet" ;
684+ final ChannelPendingTaskTracker tracker = startTrackingChannel (channel , trackerChannel -> trackerChannel .addTask (task ));
685+ return () -> tracker .removeTask (task );
686+ }
687+
688+ private ChannelPendingTaskTracker startTrackingChannel (TcpChannel channel , Consumer <ChannelPendingTaskTracker > onRegister ) {
620689 final ChannelPendingTaskTracker tracker = channelPendingTaskTrackers .compute (channel , (k , curr ) -> {
621690 if (curr == null ) {
622691 curr = new ChannelPendingTaskTracker ();
623692 }
624- curr . addTask ( task );
693+ onRegister . accept ( curr );
625694 return curr ;
626695 });
627696 if (tracker .registered .compareAndSet (false , true )) {
628697 channel .addCloseListener (ActionListener .wrap (
629698 r -> {
630699 final ChannelPendingTaskTracker removedTracker = channelPendingTaskTrackers .remove (channel );
631700 assert removedTracker == tracker ;
632- cancelTasksOnChannelClosed (tracker . drainTasks () );
701+ onChannelClosed (tracker );
633702 },
634703 e -> {
635704 assert false : new AssertionError ("must not be here" , e );
636705 }));
637706 }
638- return () -> tracker . removeTask ( task ) ;
707+ return tracker ;
639708 }
640709
641710 // for testing
@@ -676,7 +745,8 @@ void removeTask(CancellableTask task) {
676745 }
677746 }
678747
679- private void cancelTasksOnChannelClosed (Set <CancellableTask > tasks ) {
748+ private void onChannelClosed (ChannelPendingTaskTracker channel ) {
749+ final Set <CancellableTask > tasks = channel .drainTasks ();
680750 if (tasks .isEmpty () == false ) {
681751 threadPool .generic ().execute (new AbstractRunnable () {
682752 @ Override
@@ -692,6 +762,16 @@ protected void doRun() {
692762 }
693763 });
694764 }
765+
766+ // Unregister the closing channel and remove bans whose has no registered channels
767+ final Iterator <Map .Entry <TaskId , Ban >> iterator = bannedParents .entrySet ().iterator ();
768+ while (iterator .hasNext ()) {
769+ final Map .Entry <TaskId , Ban > entry = iterator .next ();
770+ final Ban ban = entry .getValue ();
771+ if (ban .perChannel && ban .unregisterChannel (channel ) && entry .getValue ().registeredChannels () == 0 ) {
772+ removeBan (entry .getKey ());
773+ }
774+ }
695775 }
696776
697777 public void cancelTaskAndDescendants (CancellableTask task , String reason , boolean waitForCompletion , ActionListener <Void > listener ) {
0 commit comments