1616import org .elasticsearch .action .support .DefaultShardOperationFailedException ;
1717import org .elasticsearch .action .support .HandledTransportAction ;
1818import org .elasticsearch .action .support .IndicesOptions ;
19+ import org .elasticsearch .action .support .NodeResponseTracker ;
1920import org .elasticsearch .action .support .TransportActions ;
2021import org .elasticsearch .action .support .broadcast .BroadcastRequest ;
2122import org .elasticsearch .action .support .broadcast .BroadcastResponse ;
5152import java .util .List ;
5253import java .util .Map ;
5354import java .util .concurrent .atomic .AtomicInteger ;
54- import java .util .concurrent .atomic .AtomicReferenceArray ;
5555import java .util .function .Consumer ;
5656
5757/**
@@ -118,28 +118,29 @@ public TransportBroadcastByNodeAction(
118118
119119 private Response newResponse (
120120 Request request ,
121- AtomicReferenceArray <?> responses ,
121+ NodeResponseTracker nodeResponseTracker ,
122122 int unavailableShardCount ,
123123 Map <String , List <ShardRouting >> nodes ,
124124 ClusterState clusterState
125- ) {
125+ ) throws NodeResponseTracker . DiscardedResponsesException {
126126 int totalShards = 0 ;
127127 int successfulShards = 0 ;
128128 List <ShardOperationResult > broadcastByNodeResponses = new ArrayList <>();
129129 List <DefaultShardOperationFailedException > exceptions = new ArrayList <>();
130- for (int i = 0 ; i < responses .length (); i ++) {
131- if (responses .get (i )instanceof FailedNodeException exception ) {
130+ for (int i = 0 ; i < nodeResponseTracker .getExpectedResponseCount (); i ++) {
131+ Object response = nodeResponseTracker .getResponse (i );
132+ if (response instanceof FailedNodeException exception ) {
132133 totalShards += nodes .get (exception .nodeId ()).size ();
133134 for (ShardRouting shard : nodes .get (exception .nodeId ())) {
134135 exceptions .add (new DefaultShardOperationFailedException (shard .getIndexName (), shard .getId (), exception ));
135136 }
136137 } else {
137138 @ SuppressWarnings ("unchecked" )
138- NodeResponse response = (NodeResponse ) responses . get ( i ) ;
139- broadcastByNodeResponses .addAll (response .results );
140- totalShards += response .getTotalShards ();
141- successfulShards += response .getSuccessfulShards ();
142- for (BroadcastShardOperationFailedException throwable : response .getExceptions ()) {
139+ NodeResponse nodeResponse = (NodeResponse ) response ;
140+ broadcastByNodeResponses .addAll (nodeResponse .results );
141+ totalShards += nodeResponse .getTotalShards ();
142+ successfulShards += nodeResponse .getSuccessfulShards ();
143+ for (BroadcastShardOperationFailedException throwable : nodeResponse .getExceptions ()) {
143144 if (TransportActions .isShardNotAvailableException (throwable ) == false ) {
144145 exceptions .add (
145146 new DefaultShardOperationFailedException (
@@ -256,16 +257,15 @@ protected void doExecute(Task task, Request request, ActionListener<Response> li
256257 new AsyncAction (task , request , listener ).start ();
257258 }
258259
259- protected class AsyncAction {
260+ protected class AsyncAction implements CancellableTask . CancellationListener {
260261 private final Task task ;
261262 private final Request request ;
262263 private final ActionListener <Response > listener ;
263264 private final ClusterState clusterState ;
264265 private final DiscoveryNodes nodes ;
265266 private final Map <String , List <ShardRouting >> nodeIds ;
266- private final AtomicReferenceArray <Object > responses ;
267- private final AtomicInteger counter = new AtomicInteger ();
268267 private final int unavailableShardCount ;
268+ private final NodeResponseTracker nodeResponseTracker ;
269269
270270 protected AsyncAction (Task task , Request request , ActionListener <Response > listener ) {
271271 this .task = task ;
@@ -312,10 +312,13 @@ protected AsyncAction(Task task, Request request, ActionListener<Response> liste
312312
313313 }
314314 this .unavailableShardCount = unavailableShardCount ;
315- responses = new AtomicReferenceArray <> (nodeIds .size ());
315+ nodeResponseTracker = new NodeResponseTracker (nodeIds .size ());
316316 }
317317
318318 public void start () {
319+ if (task instanceof CancellableTask cancellableTask ) {
320+ cancellableTask .addListener (this );
321+ }
319322 if (nodeIds .size () == 0 ) {
320323 try {
321324 onCompletion ();
@@ -373,38 +376,34 @@ protected void onNodeResponse(DiscoveryNode node, int nodeIndex, NodeResponse re
373376 logger .trace ("received response for [{}] from node [{}]" , actionName , node .getId ());
374377 }
375378
376- // this is defensive to protect against the possibility of double invocation
377- // the current implementation of TransportService#sendRequest guards against this
378- // but concurrency is hard, safety is important, and the small performance loss here does not matter
379- if (responses .compareAndSet (nodeIndex , null , response )) {
380- if (counter .incrementAndGet () == responses .length ()) {
381- onCompletion ();
382- }
379+ if (nodeResponseTracker .trackResponseAndCheckIfLast (nodeIndex , response )) {
380+ onCompletion ();
383381 }
384382 }
385383
386384 protected void onNodeFailure (DiscoveryNode node , int nodeIndex , Throwable t ) {
387385 String nodeId = node .getId ();
388386 logger .debug (new ParameterizedMessage ("failed to execute [{}] on node [{}]" , actionName , nodeId ), t );
389-
390- // this is defensive to protect against the possibility of double invocation
391- // the current implementation of TransportService#sendRequest guards against this
392- // but concurrency is hard, safety is important, and the small performance loss here does not matter
393- if (responses .compareAndSet (nodeIndex , null , new FailedNodeException (nodeId , "Failed node [" + nodeId + "]" , t ))) {
394- if (counter .incrementAndGet () == responses .length ()) {
395- onCompletion ();
396- }
387+ if (nodeResponseTracker .trackResponseAndCheckIfLast (
388+ nodeIndex ,
389+ new FailedNodeException (nodeId , "Failed node [" + nodeId + "]" , t )
390+ )) {
391+ onCompletion ();
397392 }
398393 }
399394
400395 protected void onCompletion () {
401- if (task instanceof CancellableTask && (( CancellableTask ) task ) .notifyIfCancelled (listener )) {
396+ if (( task instanceof CancellableTask t ) && t .notifyIfCancelled (listener )) {
402397 return ;
403398 }
404399
405400 Response response = null ;
406401 try {
407- response = newResponse (request , responses , unavailableShardCount , nodeIds , clusterState );
402+ response = newResponse (request , nodeResponseTracker , unavailableShardCount , nodeIds , clusterState );
403+ } catch (NodeResponseTracker .DiscardedResponsesException e ) {
404+ // We propagate the reason that the results were discarded, in this case the task cancellation, in case the listener needs to take
405+ // follow-up actions
406+ listener .onFailure ((Exception ) e .getCause ());
408407 } catch (Exception e ) {
409408 logger .debug ("failed to combine responses from nodes" , e );
410409 listener .onFailure (e );
@@ -417,6 +416,21 @@ protected void onCompletion() {
417416 }
418417 }
419418 }
419+
420+ @ Override
421+ public void onCancelled () {
422+ assert task instanceof CancellableTask : "task must be cancellable" ;
423+ try {
424+ ((CancellableTask ) task ).ensureNotCancelled ();
425+ } catch (TaskCancelledException e ) {
426+ nodeResponseTracker .discardIntermediateResponses (e );
427+ }
428+ }
429+
430+ // For testing purposes
431+ public NodeResponseTracker getNodeResponseTracker () {
432+ return nodeResponseTracker ;
433+ }
420434 }
421435
422436 class BroadcastByNodeTransportRequestHandler implements TransportRequestHandler <NodeRequest > {
0 commit comments