diff --git a/docs/changelog/82685.yaml b/docs/changelog/82685.yaml
new file mode 100644
index 0000000000000..3ef9e7841ba6e
--- /dev/null
+++ b/docs/changelog/82685.yaml
@@ -0,0 +1,6 @@
+pr: 82685
+summary: Discard intermediate results upon cancellation for stats endpoints
+area: Stats
+type: bug
+issues:
+ - 82337
diff --git a/server/src/main/java/org/elasticsearch/action/support/NodeResponseTracker.java b/server/src/main/java/org/elasticsearch/action/support/NodeResponseTracker.java
new file mode 100644
index 0000000000000..aafd6166cb364
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/action/support/NodeResponseTracker.java
@@ -0,0 +1,97 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.support;
+
+import java.util.Collection;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReferenceArray;
+
+/**
+ * This class tracks the intermediate responses that will be used to create aggregated cluster response to a request. It also gives the
+ * possibility to discard the intermediate results when asked, for example when the initial request is cancelled, in order to release the
+ * resources.
+ */
+public class NodeResponseTracker {
+
+    private final AtomicInteger counter = new AtomicInteger();
+    private final int expectedResponsesCount;
+    private volatile AtomicReferenceArray<Object> responses;
+    private volatile Exception causeOfDiscarding;
+
+    public NodeResponseTracker(int size) {
+        this.expectedResponsesCount = size;
+        this.responses = new AtomicReferenceArray<>(size);
+    }
+
+    public NodeResponseTracker(Collection<?> array) {
+        this.expectedResponsesCount = array.size();
+        this.responses = new AtomicReferenceArray<>(array.toArray());
+    }
+
+    /**
+     * This method discards the results collected so far to free up the resources.
+     * @param cause the reason for discarding, it is communicated to callers that try to access the discarded results
+     */
+    public void discardIntermediateResponses(Exception cause) {
+        if (responses != null) {
+            this.causeOfDiscarding = cause;
+            responses = null;
+        }
+    }
+
+    public boolean responsesDiscarded() {
+        return responses == null;
+    }
+
+    /**
+     * This method stores a new node response if the intermediate responses haven't been discarded yet. If the responses are not discarded
+     * the method asserts that this is the first response encountered from this node to protect from miscounting the responses in case of a
+     * double invocation. If the responses have been discarded we accept this risk for simplicity.
+     * @param nodeIndex the index that represents a single node of the cluster
+     * @param response a response can be either a NodeResponse or an error
+     * @return true if all the nodes' responses have been received, else false
+     */
+    public boolean trackResponseAndCheckIfLast(int nodeIndex, Object response) {
+        AtomicReferenceArray<Object> responses = this.responses;
+        // Snapshot before checking: the field only goes from non-null to null, so a passed check implies a non-null snapshot
+        if (responsesDiscarded() == false) {
+            boolean firstEncounter = responses.compareAndSet(nodeIndex, null, response);
+            assert firstEncounter : "a response should be tracked only once";
+        }
+        return counter.incrementAndGet() == getExpectedResponseCount();
+    }
+
+    /**
+     * Returns the tracked response or null if the response hasn't been received yet for a specific index that represents a node of the
+     * cluster.
+     * @throws DiscardedResponsesException if the responses have been discarded
+     */
+    public Object getResponse(int nodeIndex) throws DiscardedResponsesException {
+        AtomicReferenceArray<Object> responses = this.responses;
+        if (responsesDiscarded()) {
+            throw new DiscardedResponsesException(causeOfDiscarding);
+        }
+        return responses.get(nodeIndex);
+    }
+
+    public int getExpectedResponseCount() {
+        return expectedResponsesCount;
+    }
+
+    /**
+     * This exception is thrown when the {@link NodeResponseTracker} is asked to give information about the responses after they have been
+     * discarded.
+     */
+    public static class DiscardedResponsesException extends Exception {
+
+        public DiscardedResponsesException(Exception cause) {
+            super(cause);
+        }
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java b/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java
index 5c5594aa094d6..382c9cf01693e 100644
--- a/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java
+++ b/server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java
@@ -16,6 +16,7 @@
 import org.elasticsearch.action.support.DefaultShardOperationFailedException;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.action.support.IndicesOptions;
+import org.elasticsearch.action.support.NodeResponseTracker;
 import org.elasticsearch.action.support.TransportActions;
 import org.elasticsearch.action.support.broadcast.BroadcastRequest;
 import org.elasticsearch.action.support.broadcast.BroadcastResponse;
@@ -51,7 +52,6 @@
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicReferenceArray;
 import java.util.function.Consumer;
 
 /**
@@ -118,28 +118,29 @@ public TransportBroadcastByNodeAction(
 
     private Response newResponse(
         Request request,
-        AtomicReferenceArray responses,
+        NodeResponseTracker nodeResponseTracker,
         int unavailableShardCount,
         Map> nodes,
         ClusterState clusterState
-    ) {
+    ) throws NodeResponseTracker.DiscardedResponsesException {
         int totalShards = 0;
         int successfulShards = 0;
         List broadcastByNodeResponses = new ArrayList<>();
         List exceptions = new ArrayList<>();
-        for (int i = 0; i < responses.length(); i++) {
-            if (responses.get(i)instanceof FailedNodeException exception) {
+        for (int i = 0; i < nodeResponseTracker.getExpectedResponseCount(); i++) {
+            Object response = nodeResponseTracker.getResponse(i);
+            if (response instanceof FailedNodeException exception) {
                 totalShards += nodes.get(exception.nodeId()).size();
                 for (ShardRouting shard : nodes.get(exception.nodeId())) {
                     exceptions.add(new DefaultShardOperationFailedException(shard.getIndexName(), shard.getId(), exception));
                 }
             } else {
                 @SuppressWarnings("unchecked")
-                NodeResponse response = (NodeResponse) responses.get(i);
-                broadcastByNodeResponses.addAll(response.results);
-                totalShards += response.getTotalShards();
-                successfulShards += response.getSuccessfulShards();
-                for (BroadcastShardOperationFailedException throwable : response.getExceptions()) {
+                NodeResponse nodeResponse = (NodeResponse) response;
+                broadcastByNodeResponses.addAll(nodeResponse.results);
+                totalShards += nodeResponse.getTotalShards();
+                successfulShards += nodeResponse.getSuccessfulShards();
+                for (BroadcastShardOperationFailedException throwable : nodeResponse.getExceptions()) {
                     if (TransportActions.isShardNotAvailableException(throwable) == false) {
                         exceptions.add(
                             new DefaultShardOperationFailedException(
@@ -256,16 +257,15 @@ protected void doExecute(Task task, Request request, ActionListener li
         new AsyncAction(task, request, listener).start();
     }
 
-    protected class AsyncAction {
+    protected class AsyncAction implements CancellableTask.CancellationListener {
         private final Task task;
         private final Request request;
         private final ActionListener listener;
         private final ClusterState clusterState;
         private final DiscoveryNodes nodes;
         private final Map> nodeIds;
-        private final AtomicReferenceArray responses;
-        private final AtomicInteger counter = new AtomicInteger();
         private final int unavailableShardCount;
+        private final NodeResponseTracker nodeResponseTracker;
 
         protected AsyncAction(Task task, Request request, ActionListener listener) {
             this.task = task;
@@ -312,10 +312,13 @@ protected AsyncAction(Task task, Request request, ActionListener liste
             }
             this.unavailableShardCount = unavailableShardCount;
 
-            responses = new AtomicReferenceArray<>(nodeIds.size());
+            nodeResponseTracker = new NodeResponseTracker(nodeIds.size());
         }
 
         public void start() {
+            if (task instanceof CancellableTask cancellableTask) {
+                cancellableTask.addListener(this);
+            }
             if (nodeIds.size() == 0) {
                 try {
                     onCompletion();
@@ -373,38 +376,34 @@ protected void onNodeResponse(DiscoveryNode node, int nodeIndex, NodeResponse re
                 logger.trace("received response for [{}] from node [{}]", actionName, node.getId());
             }
 
-            // this is defensive to protect against the possibility of double invocation
-            // the current implementation of TransportService#sendRequest guards against this
-            // but concurrency is hard, safety is important, and the small performance loss here does not matter
-            if (responses.compareAndSet(nodeIndex, null, response)) {
-                if (counter.incrementAndGet() == responses.length()) {
-                    onCompletion();
-                }
+            if (nodeResponseTracker.trackResponseAndCheckIfLast(nodeIndex, response)) {
+                onCompletion();
             }
         }
 
         protected void onNodeFailure(DiscoveryNode node, int nodeIndex, Throwable t) {
             String nodeId = node.getId();
             logger.debug(new ParameterizedMessage("failed to execute [{}] on node [{}]", actionName, nodeId), t);
-
-            // this is defensive to protect against the possibility of double invocation
-            // the current implementation of TransportService#sendRequest guards against this
-            // but concurrency is hard, safety is important, and the small performance loss here does not matter
-            if (responses.compareAndSet(nodeIndex, null, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t))) {
-                if (counter.incrementAndGet() == responses.length()) {
-                    onCompletion();
-                }
+            if (nodeResponseTracker.trackResponseAndCheckIfLast(
+                nodeIndex,
+                new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t)
+            )) {
+                onCompletion();
             }
         }
 
         protected void onCompletion() {
-            if (task instanceof CancellableTask && ((CancellableTask) task).notifyIfCancelled(listener)) {
+            if ((task instanceof CancellableTask t) && t.notifyIfCancelled(listener)) {
                 return;
             }
 
             Response response = null;
             try {
-                response = newResponse(request, responses, unavailableShardCount, nodeIds, clusterState);
+                response = newResponse(request, nodeResponseTracker, unavailableShardCount, nodeIds, clusterState);
+            } catch (NodeResponseTracker.DiscardedResponsesException e) {
+                // The results were discarded; propagate the reason (here the task cancellation) in case the listener needs to take
+                // follow-up actions
+                listener.onFailure((Exception) e.getCause());
             } catch (Exception e) {
                 logger.debug("failed to combine responses from nodes", e);
                 listener.onFailure(e);
@@ -417,6 +416,21 @@ protected void onCompletion() {
                 }
             }
         }
+
+        @Override
+        public void onCancelled() {
+            assert task instanceof CancellableTask : "task must be cancellable";
+            try {
+                ((CancellableTask) task).ensureNotCancelled();
+            } catch (TaskCancelledException e) {
+                nodeResponseTracker.discardIntermediateResponses(e);
+            }
+        }
+
+        // For testing purposes
+        public NodeResponseTracker getNodeResponseTracker() {
+            return nodeResponseTracker;
+        }
     }
 
     class BroadcastByNodeTransportRequestHandler implements TransportRequestHandler {
diff --git a/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java b/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java
index 5b13f3aab917d..c93f688b5a16d 100644
--- a/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java
+++ b/server/src/main/java/org/elasticsearch/action/support/nodes/TransportNodesAction.java
@@ -13,6 +13,7 @@
 import org.elasticsearch.action.FailedNodeException;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
+import org.elasticsearch.action.support.NodeResponseTracker;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.node.DiscoveryNode;
 import org.elasticsearch.cluster.service.ClusterService;
@@ -20,6 +21,7 @@
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.Task;
+import org.elasticsearch.tasks.TaskCancelledException;
 import org.elasticsearch.threadpool.ThreadPool;
 import org.elasticsearch.transport.TransportChannel;
 import org.elasticsearch.transport.TransportException;
@@ -34,8 +36,6 @@
 import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicReferenceArray;
 
 public abstract class TransportNodesAction<
     NodesRequest extends BaseNodesRequest,
@@ -128,14 +128,15 @@ protected void doExecute(Task task, NodesRequest request, ActionListener nodesResponses, ActionListener listener) {
+    void newResponse(Task task, NodesRequest request, NodeResponseTracker nodeResponseTracker, ActionListener listener)
+        throws NodeResponseTracker.DiscardedResponsesException {
-        if (nodesResponses == null) {
+        if (nodeResponseTracker == null) {
             listener.onFailure(new NullPointerException("nodesResponses"));
             return;
         }
@@ -143,11 +144,10 @@ void newResponse(Task task, NodesRequest request, AtomicReferenceArray nodesR
         final List responses = new ArrayList<>();
         final List failures = new ArrayList<>();
 
-        for (int i = 0; i < nodesResponses.length(); ++i) {
-            Object response = nodesResponses.get(i);
-
-            if (response instanceof FailedNodeException) {
-                failures.add((FailedNodeException) response);
+        for (int i = 0; i < nodeResponseTracker.getExpectedResponseCount(); ++i) {
+            Object response = nodeResponseTracker.getResponse(i);
+            if (response instanceof FailedNodeException failedNodeException) {
+                failures.add(failedNodeException);
             } else {
                 responses.add(nodeResponseClass.cast(response));
             }
@@ -203,12 +203,11 @@ protected String getTransportNodeAction(DiscoveryNode node) {
         return transportNodeAction;
     }
 
-    class AsyncAction {
+    class AsyncAction implements CancellableTask.CancellationListener {
 
         private final NodesRequest request;
         private final ActionListener listener;
-        private final AtomicReferenceArray responses;
-        private final AtomicInteger counter = new AtomicInteger();
+        private final NodeResponseTracker nodeResponseTracker;
         private final Task task;
 
         AsyncAction(Task task, NodesRequest request, ActionListener listener) {
@@ -219,10 +218,13 @@ class AsyncAction {
                 resolveRequest(request, clusterService.state());
                 assert request.concreteNodes() != null;
             }
-            this.responses = new AtomicReferenceArray<>(request.concreteNodes().length);
+            this.nodeResponseTracker = new NodeResponseTracker(request.concreteNodes().length);
         }
 
         void start() {
+            if (task instanceof CancellableTask cancellableTask) {
+                cancellableTask.addListener(this);
+            }
             final DiscoveryNode[] nodes = request.concreteNodes();
             if (nodes.length == 0) {
                 finishHim();
@@ -267,28 +269,49 @@ public void handleException(TransportException exp) {
             }
         }
 
+        // For testing purposes
+        NodeResponseTracker getNodeResponseTracker() {
+            return nodeResponseTracker;
+        }
+
         private void onOperation(int idx, NodeResponse nodeResponse) {
-            responses.set(idx, nodeResponse);
-            if (counter.incrementAndGet() == responses.length()) {
+            if (nodeResponseTracker.trackResponseAndCheckIfLast(idx, nodeResponse)) {
                 finishHim();
             }
         }
 
         private void onFailure(int idx, String nodeId, Throwable t) {
             logger.debug(new ParameterizedMessage("failed to execute on node [{}]", nodeId), t);
-            responses.set(idx, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t));
-            if (counter.incrementAndGet() == responses.length()) {
+            if (nodeResponseTracker.trackResponseAndCheckIfLast(idx, new FailedNodeException(nodeId, "Failed node [" + nodeId + "]", t))) {
                 finishHim();
             }
         }
 
         private void finishHim() {
-            if (task instanceof CancellableTask && ((CancellableTask) task).notifyIfCancelled(listener)) {
+            if ((task instanceof CancellableTask t) && t.notifyIfCancelled(listener)) {
                 return;
             }
 
             final String executor = finalExecutor.equals(ThreadPool.Names.SAME) ? ThreadPool.Names.GENERIC : finalExecutor;
-            threadPool.executor(executor).execute(() -> newResponse(task, request, responses, listener));
+            threadPool.executor(executor).execute(() -> {
+                try {
+                    newResponse(task, request, nodeResponseTracker, listener);
+                } catch (NodeResponseTracker.DiscardedResponsesException e) {
+                    // The results were discarded; propagate the reason (here the task cancellation) in case the listener needs to take
+                    // follow-up actions
+                    listener.onFailure((Exception) e.getCause());
+                }
+            });
+        }
+
+        @Override
+        public void onCancelled() {
+            assert task instanceof CancellableTask : "task must be cancellable";
+            try {
+                ((CancellableTask) task).ensureNotCancelled();
+            } catch (TaskCancelledException e) {
+                nodeResponseTracker.discardIntermediateResponses(e);
+            }
         }
     }
 
diff --git a/server/src/main/java/org/elasticsearch/tasks/CancellableTask.java b/server/src/main/java/org/elasticsearch/tasks/CancellableTask.java
index b318d485317b7..9010a9d99d3c4 100644
--- a/server/src/main/java/org/elasticsearch/tasks/CancellableTask.java
+++ b/server/src/main/java/org/elasticsearch/tasks/CancellableTask.java
@@ -12,6 +12,7 @@
 import org.elasticsearch.core.Nullable;
 
 import java.util.Map;
+import java.util.concurrent.ConcurrentLinkedQueue;
 
 /**
  * A task that can be cancelled
@@ -20,6 +21,7 @@ public class CancellableTask extends Task {
 
     private volatile String reason;
     private volatile boolean isCancelled;
+    private final ConcurrentLinkedQueue<CancellationListener> listeners = new ConcurrentLinkedQueue<>();
 
     public CancellableTask(long id, String type, String action, String description, TaskId parentTaskId, Map headers) {
         super(id, type, action, description, parentTaskId, headers);
@@ -37,6 +39,7 @@ final void cancel(String reason) {
             this.isCancelled = true;
             this.reason = reason;
         }
+        listeners.forEach(CancellationListener::onCancelled);
         onCancelled();
     }
 
@@ -67,6 +70,20 @@ public final String getReasonCancelled() {
         return reason;
     }
 
+    /**
+     * Adds a listener to be notified when this task is cancelled; if it is already cancelled the listener is notified immediately.
+     */
+    public final void addListener(CancellationListener listener) {
+        synchronized (this) {
+            if (this.isCancelled == false) {
+                listeners.add(listener);
+                return;
+            }
+        }
+        // Already cancelled: notify directly. Re-checking the flag here instead could notify a just-registered listener twice.
+        listener.onCancelled();
+    }
+
     /**
      * Called after the task is cancelled so that it can take any actions that it has to take.
      */
@@ -103,4 +120,11 @@ private TaskCancelledException getTaskCancelledException() {
         assert reason != null;
         return new TaskCancelledException("task cancelled [" + reason + ']');
     }
+
+    /**
+     * This interface is implemented by any class that needs to react to the cancellation of this task.
+     */
+    public interface CancellationListener {
+        void onCancelled();
+    }
 }
diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java
index ef04ad960e607..82677663b01c0 100644
--- a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java
+++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/tasks/CancellableTasksTests.java
@@ -187,6 +187,19 @@ protected NodeResponse nodeOperation(CancellableNodeRequest request, Task task)
         }
     }
 
+    /**
+     * Simulates a cancellation listener and sets a flag to true if the task was cancelled
+     */
+    static class CancellableTestCancellationListener implements CancellableTask.CancellationListener {
+
+        final AtomicBoolean calledUponCancellation = new AtomicBoolean(false);
+
+        @Override
+        public void onCancelled() {
+            calledUponCancellation.set(true);
+        }
+    }
+
     private Task startCancellableTestNodesAction(
         boolean waitForActionToStart,
         int runNodesCount,
@@ -252,6 +265,7 @@ public void testBasicTaskCancellation() throws Exception {
         setupTestNodes(Settings.EMPTY);
         connectNodes(testNodes);
         CountDownLatch responseLatch = new CountDownLatch(1);
+        AtomicBoolean listenerCalledUponCancellation = new AtomicBoolean(false);
         boolean waitForActionToStart = randomBoolean();
         logger.info("waitForActionToStart is set to {}", waitForActionToStart);
         final AtomicReference responseReference = new AtomicReference<>();
@@ -260,24 +274,23 @@ public void testBasicTaskCancellation() throws Exception {
         // Block at least 1 node, otherwise it's quite easy to end up in a race condition where the node tasks
         // have finished before the cancel request has arrived
         int blockedNodesCount = randomIntBetween(1, runNodesCount);
-        Task mainTask = startCancellableTestNodesAction(
-            waitForActionToStart,
-            runNodesCount,
-            blockedNodesCount,
-            new ActionListener() {
-                @Override
-                public void onResponse(NodesResponse listTasksResponse) {
-                    responseReference.set(listTasksResponse);
-                    responseLatch.countDown();
-                }
+        Task mainTask = startCancellableTestNodesAction(waitForActionToStart, runNodesCount, blockedNodesCount, new ActionListener<>() {
+            @Override
+            public void onResponse(NodesResponse listTasksResponse) {
+                responseReference.set(listTasksResponse);
+                responseLatch.countDown();
+            }
 
-                @Override
-                public void onFailure(Exception e) {
-                    throwableReference.set(e);
-                    responseLatch.countDown();
-                }
+            @Override
+            public void onFailure(Exception e) {
+                throwableReference.set(e);
+                responseLatch.countDown();
             }
-        );
+        });
+
+        assert mainTask instanceof CancellableTask;
+        CancellableTestCancellationListener listenerAddedBeforeCancellation = new CancellableTestCancellationListener();
+        ((CancellableTask) mainTask).addListener(listenerAddedBeforeCancellation);
 
         // Cancel main task
         CancelTasksRequest request = new CancelTasksRequest();
@@ -311,6 +324,13 @@ public void onFailure(Exception e) {
             for (TaskInfo taskInfo : response.getTasks()) {
                 assertTrue(taskInfo.cancellable());
             }
+
+            CancellableTestCancellationListener listenerAddedAfterCancellation = new CancellableTestCancellationListener();
+            ((CancellableTask) mainTask).addListener(listenerAddedAfterCancellation);
+
+            // Verify both cancellation listeners have been notified
+            assertTrue(listenerAddedBeforeCancellation.calledUponCancellation.get());
+            assertTrue(listenerAddedAfterCancellation.calledUponCancellation.get());
         }
 
         // Make sure that tasks are no longer running
@@ -337,7 +357,7 @@ public void testChildTasksCancellation() throws Exception {
         final AtomicReference throwableReference = new AtomicReference<>();
         int runNodesCount = randomIntBetween(1, nodesCount);
         int blockedNodesCount = randomIntBetween(0, runNodesCount);
-        Task mainTask = startCancellableTestNodesAction(true, runNodesCount, blockedNodesCount, new ActionListener() {
+        Task mainTask = startCancellableTestNodesAction(true, runNodesCount, blockedNodesCount, new ActionListener<>() {
             @Override
             public void onResponse(NodesResponse listTasksResponse) {
                 responseReference.set(listTasksResponse);
diff --git a/server/src/test/java/org/elasticsearch/action/support/NodeResponseTrackerTests.java b/server/src/test/java/org/elasticsearch/action/support/NodeResponseTrackerTests.java
new file mode 100644
index 0000000000000..11d2ee1f12a04
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/action/support/NodeResponseTrackerTests.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.action.support;
+
+import org.elasticsearch.test.ESTestCase;
+
+public class NodeResponseTrackerTests extends ESTestCase {
+
+    public void testAllResponsesReceived() throws Exception {
+        int nodes = randomIntBetween(1, 10);
+        NodeResponseTracker intermediateNodeResponses = new NodeResponseTracker(nodes);
+        for (int i = 0; i < nodes; i++) {
+            boolean isLast = i == nodes - 1;
+            assertEquals(
+                isLast,
+                intermediateNodeResponses.trackResponseAndCheckIfLast(i, randomBoolean() ? i : new Exception("from node " + i))
+            );
+        }
+
+        assertFalse(intermediateNodeResponses.responsesDiscarded());
+        assertEquals(nodes, intermediateNodeResponses.getExpectedResponseCount());
+        for (int i = 0; i < nodes; i++) {
+            assertNotNull(intermediateNodeResponses.getResponse(i));
+            if (intermediateNodeResponses.getResponse(i)instanceof Integer nodeResponse) {
+                assertEquals(i, nodeResponse.intValue());
+            }
+        }
+    }
+
+    public void testDiscardingResults() {
+        int nodes = randomIntBetween(1, 10);
+        int cancelAt = randomIntBetween(0, Math.max(0, nodes - 2));
+        NodeResponseTracker intermediateNodeResponses = new NodeResponseTracker(nodes);
+        for (int i = 0; i < nodes; i++) {
+            if (i == cancelAt) {
+                intermediateNodeResponses.discardIntermediateResponses(new Exception("simulated"));
+            }
+            boolean isLast = i == nodes - 1;
+            assertEquals(
+                isLast,
+                intermediateNodeResponses.trackResponseAndCheckIfLast(i, randomBoolean() ? i : new Exception("from node " + i))
+            );
+        }
+
+        assertTrue(intermediateNodeResponses.responsesDiscarded());
+        assertEquals(nodes, intermediateNodeResponses.getExpectedResponseCount());
+        expectThrows(NodeResponseTracker.DiscardedResponsesException.class, () -> intermediateNodeResponses.getResponse(0));
+    }
+
+    public void testResponseIsRegisteredOnlyOnce() {
+        NodeResponseTracker intermediateNodeResponses = new NodeResponseTracker(1);
+        assertTrue(intermediateNodeResponses.trackResponseAndCheckIfLast(0, "response1"));
+        expectThrows(AssertionError.class, () -> intermediateNodeResponses.trackResponseAndCheckIfLast(0, "response2"));
+    }
+}
diff --git a/server/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java b/server/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java
index 948288fe06281..93defb70ec466 100644
--- a/server/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java
+++ b/server/src/test/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeActionTests.java
@@ -537,14 +537,23 @@ public void testResultAggregation() throws ExecutionException, InterruptedExcept
     public void testNoResultAggregationIfTaskCancelled() {
         Request request = new Request(new String[] { TEST_INDEX });
         PlainActionFuture listener = new PlainActionFuture<>();
-        action.new AsyncAction(cancelledTask(), request, listener).start();
+        final CancellableTask task = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap());
+        TransportBroadcastByNodeAction.AsyncAction asyncAction =
+            action.new AsyncAction(task, request, listener);
+        asyncAction.start();
         Map> capturedRequests = transport.getCapturedRequestsByTargetNodeAndClear();
-
+        int cancelAt = randomIntBetween(0, Math.max(0, capturedRequests.size() - 2));
+        int i = 0;
         for (Map.Entry> entry : capturedRequests.entrySet()) {
+            if (cancelAt == i) {
+                TaskCancelHelper.cancel(task, "simulated");
+            }
             transport.handleRemoteError(entry.getValue().get(0).requestId(), new ElasticsearchException("simulated"));
+            i++;
         }
 
         assertTrue(listener.isDone());
+        assertTrue(asyncAction.getNodeResponseTracker().responsesDiscarded());
         expectThrows(ExecutionException.class, TaskCancelledException.class, listener::get);
     }
 
diff --git a/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java b/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java
index ee43aaa5b5e90..def2e4558bd23 100644
--- a/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java
+++ b/server/src/test/java/org/elasticsearch/action/support/nodes/TransportNodesActionTests.java
@@ -11,6 +11,7 @@
 import org.elasticsearch.Version;
 import org.elasticsearch.action.FailedNodeException;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.NodeResponseTracker;
 import org.elasticsearch.action.support.PlainActionFuture;
 import org.elasticsearch.action.support.broadcast.node.TransportBroadcastByNodeActionTests;
 import org.elasticsearch.cluster.ClusterName;
@@ -47,7 +48,6 @@
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReferenceArray;
 import java.util.function.Supplier;
 
 import static java.util.Collections.emptyMap;
@@ -94,14 +94,14 @@ public void testNodesSelectors() {
         assertEquals(clusterService.state().nodes().resolveNodes(finalNodesIds).length, capturedRequests.size());
     }
 
-    public void testNewResponseNullArray() {
+    public void testNewResponseNullArray() throws Exception {
         TransportNodesAction action = getTestTransportNodesAction();
         final PlainActionFuture future = new PlainActionFuture<>();
         action.newResponse(new Task(1, "test", "test", "", null, emptyMap()), new TestNodesRequest(), null, future);
         expectThrows(NullPointerException.class, future::actionGet);
     }
 
-    public void testNewResponse() {
+    public void testNewResponse() throws Exception {
         TestTransportNodesAction action = getTestTransportNodesAction();
         TestNodesRequest request = new TestNodesRequest();
         List expectedNodeResponses = mockList(TestNodeResponse::new, randomIntBetween(0, 2));
@@ -120,10 +120,10 @@ public void testNewResponse() {
 
         Collections.shuffle(allResponses, random());
 
-        AtomicReferenceArray atomicArray = new AtomicReferenceArray<>(allResponses.toArray());
+        NodeResponseTracker nodeResponseCollector = new NodeResponseTracker(allResponses);
 
         final PlainActionFuture future = new PlainActionFuture<>();
-        action.newResponse(new Task(1, "test", "test", "", null, emptyMap()), request, atomicArray, future);
+        action.newResponse(new Task(1, "test", "test", "", null, emptyMap()), request, nodeResponseCollector, future);
         TestNodesResponse response = future.actionGet();
 
         assertSame(request, response.request);
@@ -146,7 +146,7 @@ public void testCustomResolving() throws Exception {
         assertEquals(clusterService.state().nodes().getDataNodes().size(), capturedRequests.size());
     }
 
-    public void testTaskCancellationThrowsException() {
+    public void testTaskCancellation() {
         TransportNodesAction action = getTestTransportNodesAction();
         List nodeIds = new ArrayList<>();
         for (DiscoveryNode node : clusterService.state().nodes()) {
@@ -156,10 +156,16 @@ public void testTaskCancellationThrowsException() {
         TestNodesRequest request = new TestNodesRequest(nodeIds.toArray(new String[0]));
         PlainActionFuture listener = new PlainActionFuture<>();
         CancellableTask cancellableTask = new CancellableTask(randomLong(), "transport", "action", "", null, emptyMap());
-        TaskCancelHelper.cancel(cancellableTask, "simulated");
-        action.doExecute(cancellableTask, request, listener);
+        TransportNodesAction.AsyncAction asyncAction =
+            action.new AsyncAction(cancellableTask, request, listener);
+        asyncAction.start();
         Map> capturedRequests = transport.getCapturedRequestsByTargetNodeAndClear();
+        int cancelAt = randomIntBetween(0, Math.max(0, capturedRequests.values().size() - 2));
+        int requestCount = 0;
         for (List requests : capturedRequests.values()) {
+            if (requestCount == cancelAt) {
+                TaskCancelHelper.cancel(cancellableTask, "simulated");
+            }
             for (CapturingTransport.CapturedRequest capturedRequest : requests) {
                 if (randomBoolean()) {
                     transport.handleResponse(capturedRequest.requestId(), new TestNodeResponse(capturedRequest.node()));
@@ -167,9 +173,11 @@ public void testTaskCancellationThrowsException() {
                     transport.handleRemoteError(capturedRequest.requestId(), new TaskCancelledException("simulated"));
                 }
             }
+            requestCount++;
         }
 
         assertTrue(listener.isDone());
+        assertTrue(asyncAction.getNodeResponseTracker().responsesDiscarded());
         expectThrows(ExecutionException.class, TaskCancelledException.class, listener::get);
     }
 