Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.InternalTestCluster;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
Expand Down Expand Up @@ -307,6 +308,58 @@ public void testCancelOrphanedTasks() throws Exception {
}
}

public void testRemoveBanParentsOnDisconnect() throws Exception {
Set<DiscoveryNode> nodes = StreamSupport.stream(clusterService().state().nodes().spliterator(), false).collect(Collectors.toSet());
final TestRequest rootRequest = generateTestRequest(nodes, 0, between(1, 4));
client().execute(TransportTestAction.ACTION, rootRequest);
Set<TestRequest> pendingRequests = allowPartialRequest(rootRequest);
TaskId rootTaskId = getRootTaskId(rootRequest);
ActionFuture<CancelTasksResponse> cancelFuture = client().admin().cluster().prepareCancelTasks()
.setTaskId(rootTaskId).waitForCompletion(true).execute();
try {
assertBusy(() -> {
for (DiscoveryNode node : nodes) {
TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager();
Set<TaskId> expectedBans = new HashSet<>();
for (TestRequest req : pendingRequests) {
if (req.node.equals(node)) {
List<Task> childTasks = taskManager.getTasks().values().stream()
.filter(t -> t.getParentTaskId() != null && t.getDescription().equals(req.taskDescription()))
.collect(Collectors.toList());
assertThat(childTasks, hasSize(1));
CancellableTask childTask = (CancellableTask) childTasks.get(0);
assertTrue(childTask.isCancelled());
expectedBans.add(childTask.getParentTaskId());
}
}
assertThat(taskManager.getBannedTaskIds(), equalTo(expectedBans));
}
}, 30, TimeUnit.SECONDS);

final Set<TaskId> bannedParents = new HashSet<>();
for (DiscoveryNode node : nodes) {
TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager();
bannedParents.addAll(taskManager.getBannedTaskIds());
}
// Disconnect some outstanding child connections
for (DiscoveryNode node : nodes) {
TaskManager taskManager = internalCluster().getInstance(TransportService.class, node.getName()).getTaskManager();
for (TaskId bannedParent : bannedParents) {
if (bannedParent.getNodeId().equals(node.getId()) && randomBoolean()) {
Collection<Transport.Connection> childConns = taskManager.startBanOnChildTasks(bannedParent.getId(), () -> {});
for (Transport.Connection connection : randomSubsetOf(childConns)) {
connection.close();
}
}
}
}
} finally {
allowEntireRequest(rootRequest);
cancelFuture.actionGet();
ensureAllBansRemoved();
}
}

static TaskId getRootTaskId(TestRequest request) throws Exception {
SetOnce<TaskId> taskId = new SetOnce<>();
assertBusy(() -> {
Expand All @@ -326,6 +379,7 @@ static void waitForRootTask(ActionFuture<TestResponse> rootTask) {
rootTask.actionGet();
} catch (Exception e) {
final Throwable cause = ExceptionsHelper.unwrap(e, TaskCancelledException.class);
assertNotNull(cause);
assertThat(cause.getMessage(), anyOf(
equalTo("The parent task was cancelled, shouldn't start any child tasks"),
containsString("Task cancelled before it started:"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ public void messageReceived(final BanParentTaskRequest request, final TransportC
if (request.ban) {
logger.debug("Received ban for the parent [{}] on the node [{}], reason: [{}]", request.parentTaskId,
localNodeId(), request.reason);
final List<CancellableTask> childTasks = taskManager.setBan(request.parentTaskId, request.reason);
final List<CancellableTask> childTasks = taskManager.setBan(request.parentTaskId, request.reason, channel);
final GroupedActionListener<Void> listener = new GroupedActionListener<>(
new ChannelActionListener<>(channel, BAN_PARENT_ACTION_NAME, request).map(r -> TransportResponse.Empty.INSTANCE),
childTasks.size() + 1);
Expand Down
135 changes: 108 additions & 27 deletions server/src/main/java/org/elasticsearch/tasks/TaskManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchTimeoutException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.cluster.ClusterChangedEvent;
Expand All @@ -45,15 +46,19 @@
import org.elasticsearch.common.util.concurrent.ConcurrentMapLong;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TaskTransportChannel;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransportChannel;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
Expand All @@ -63,6 +68,7 @@
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

Expand All @@ -89,7 +95,7 @@ public class TaskManager implements ClusterStateApplier {

private final AtomicLong taskIdGenerator = new AtomicLong();

private final Map<TaskId, String> banedParents = new ConcurrentHashMap<>();
private final Map<TaskId, Ban> bannedParents = new ConcurrentHashMap<>();

private TaskResultsService taskResultsService;

Expand Down Expand Up @@ -154,13 +160,13 @@ private void registerCancellableTask(Task task) {
CancellableTaskHolder oldHolder = cancellableTasks.put(task.getId(), holder);
assert oldHolder == null;
// Check if this task was banned before we start it. The empty check is used to avoid
// computing the hash code of the parent taskId as most of the time banedParents is empty.
if (task.getParentTaskId().isSet() && banedParents.isEmpty() == false) {
String reason = banedParents.get(task.getParentTaskId());
if (reason != null) {
// computing the hash code of the parent taskId as most of the time bannedParents is empty.
if (task.getParentTaskId().isSet() && bannedParents.isEmpty() == false) {
final Ban ban = bannedParents.get(task.getParentTaskId());
if (ban != null) {
try {
holder.cancel(reason);
throw new TaskCancelledException("Task cancelled before it started: " + reason);
holder.cancel(ban.reason);
throw new TaskCancelledException("Task cancelled before it started: " + ban.reason);
} finally {
// let's clean up the registration
unregister(task);
Expand Down Expand Up @@ -345,7 +351,7 @@ public CancellableTask getCancellableTask(long id) {
* Will be used in task manager stats and for debugging.
*/
public int getBanCount() {
return banedParents.size();
return bannedParents.size();
}

/**
Expand All @@ -354,14 +360,27 @@ public int getBanCount() {
* This method is called when a parent task that has children is cancelled.
* @return a list of pending cancellable child tasks
*/
public List<CancellableTask> setBan(TaskId parentTaskId, String reason) {
public List<CancellableTask> setBan(TaskId parentTaskId, String reason, TransportChannel channel) {
logger.trace("setting ban for the parent task {} {}", parentTaskId, reason);

// Set the ban first, so the newly created tasks cannot be registered
synchronized (banedParents) {
if (lastDiscoveryNodes.nodeExists(parentTaskId.getNodeId())) {
// Only set the ban if the node is the part of the cluster
banedParents.put(parentTaskId, reason);
synchronized (bannedParents) {
if (channel.getVersion().onOrAfter(Version.V_7_12_0)) {
final Ban ban = bannedParents.computeIfAbsent(parentTaskId, k -> new Ban(reason, true));
assert ban.perChannel : "not a ban per channel";
while (channel instanceof TaskTransportChannel) {
channel = ((TaskTransportChannel) channel).getChannel();
}
if (channel instanceof TcpTransportChannel) {
startTrackingChannel(((TcpTransportChannel) channel).getChannel(), ban::registerChannel);
} else {
assert channel.getChannelType().equals("direct") : "expect direct channel; got [" + channel + "]";
ban.registerChannel(DIRECT_CHANNEL_TRACKER);
}
} else {
if (lastDiscoveryNodes.nodeExists(parentTaskId.getNodeId())) {
// Only set the ban if the node is the part of the cluster
final Ban existing = bannedParents.put(parentTaskId, new Ban(reason, false));
assert existing == null || existing.perChannel == false : "not a ban per node";
}
}
}
return cancellableTasks.values().stream()
Expand All @@ -377,12 +396,52 @@ public List<CancellableTask> setBan(TaskId parentTaskId, String reason) {
*/
public void removeBan(TaskId parentTaskId) {
logger.trace("removing ban for the parent task {}", parentTaskId);
banedParents.remove(parentTaskId);
bannedParents.remove(parentTaskId);
}

// for testing
public Set<TaskId> getBannedTaskIds() {
return Collections.unmodifiableSet(banedParents.keySet());
return Collections.unmodifiableSet(bannedParents.keySet());
}

private class Ban {
final String reason;
final boolean perChannel; // TODO: Remove this in 8.0
final Set<ChannelPendingTaskTracker> channels;

Ban(String reason, boolean perChannel) {
assert Thread.holdsLock(bannedParents);
this.reason = reason;
this.perChannel = perChannel;
if (perChannel) {
this.channels = new HashSet<>();
} else {
this.channels = Collections.emptySet();
}
}

void registerChannel(ChannelPendingTaskTracker channel) {
assert Thread.holdsLock(bannedParents);
assert perChannel : "not a ban per channel";
channels.add(channel);
}

boolean unregisterChannel(ChannelPendingTaskTracker channel) {
assert Thread.holdsLock(bannedParents);
assert perChannel : "not a ban per channel";
return channels.remove(channel);
}

int registeredChannels() {
assert Thread.holdsLock(bannedParents);
assert perChannel : "not a ban per channel";
return channels.size();
}

@Override
public String toString() {
return "Ban{" + "reason='" + reason + '\'' + ", perChannel=" + perChannel + ", channels=" + channels + '}';
}
}

/**
Expand All @@ -406,15 +465,15 @@ public Collection<Transport.Connection> startBanOnChildTasks(long taskId, Runnab
public void applyClusterState(ClusterChangedEvent event) {
lastDiscoveryNodes = event.state().getNodes();
if (event.nodesRemoved()) {
synchronized (banedParents) {
synchronized (bannedParents) {
lastDiscoveryNodes = event.state().getNodes();
// Remove all bans that were registered by nodes that are no longer in the cluster state
Iterator<TaskId> banIterator = banedParents.keySet().iterator();
final Iterator<Map.Entry<TaskId, Ban>> banIterator = bannedParents.entrySet().iterator();
while (banIterator.hasNext()) {
TaskId taskId = banIterator.next();
if (lastDiscoveryNodes.nodeExists(taskId.getNodeId()) == false) {
logger.debug("Removing ban for the parent [{}] on the node [{}], reason: the parent node is gone", taskId,
event.state().getNodes().getLocalNode());
final Map.Entry<TaskId, Ban> ban = banIterator.next();
if (ban.getValue().perChannel == false && lastDiscoveryNodes.nodeExists(ban.getKey().getNodeId()) == false) {
logger.debug("Removing ban for the parent [{}] on the node [{}], reason: the parent node is gone",
ban.getKey(), event.state().getNodes().getLocalNode());
banIterator.remove();
}
}
Expand Down Expand Up @@ -581,32 +640,39 @@ Set<Transport.Connection> startBan(Runnable onChildTasksCompleted) {
*/
public Releasable startTrackingCancellableChannelTask(TcpChannel channel, CancellableTask task) {
assert cancellableTasks.containsKey(task.getId()) : "task [" + task.getId() + "] is not registered yet";
final ChannelPendingTaskTracker tracker = startTrackingChannel(channel, trackerChannel -> trackerChannel.addTask(task));
return () -> tracker.removeTask(task);
}

private ChannelPendingTaskTracker startTrackingChannel(TcpChannel channel, Consumer<ChannelPendingTaskTracker> onRegister) {
final ChannelPendingTaskTracker tracker = channelPendingTaskTrackers.compute(channel, (k, curr) -> {
if (curr == null) {
curr = new ChannelPendingTaskTracker();
}
curr.addTask(task);
onRegister.accept(curr);
return curr;
});
if (tracker.registered.compareAndSet(false, true)) {
channel.addCloseListener(ActionListener.wrap(
r -> {
final ChannelPendingTaskTracker removedTracker = channelPendingTaskTrackers.remove(channel);
assert removedTracker == tracker;
cancelTasksOnChannelClosed(tracker.drainTasks());
onChannelClosed(tracker);
},
e -> {
assert false : new AssertionError("must not be here", e);
}));
}
return () -> tracker.removeTask(task);
return tracker;
}

// for testing
final int numberOfChannelPendingTaskTrackers() {
return channelPendingTaskTrackers.size();
}

private static final ChannelPendingTaskTracker DIRECT_CHANNEL_TRACKER = new ChannelPendingTaskTracker();

private static class ChannelPendingTaskTracker {
final AtomicBoolean registered = new AtomicBoolean();
final Semaphore permits = Assertions.ENABLED ? new Semaphore(Integer.MAX_VALUE) : null;
Expand Down Expand Up @@ -640,7 +706,8 @@ void removeTask(CancellableTask task) {
}
}

private void cancelTasksOnChannelClosed(Set<CancellableTask> tasks) {
private void onChannelClosed(ChannelPendingTaskTracker channel) {
final Set<CancellableTask> tasks = channel.drainTasks();
if (tasks.isEmpty() == false) {
threadPool.generic().execute(new AbstractRunnable() {
@Override
Expand All @@ -656,6 +723,20 @@ protected void doRun() {
}
});
}

// Unregister the closing channel and remove bans whose has no registered channels
synchronized (bannedParents) {
final Iterator<Map.Entry<TaskId, Ban>> iterator = bannedParents.entrySet().iterator();
while (iterator.hasNext()) {
final Map.Entry<TaskId, Ban> entry = iterator.next();
final Ban ban = entry.getValue();
if (ban.perChannel) {
if (ban.unregisterChannel(channel) && entry.getValue().registeredChannels() == 0) {
iterator.remove();
}
}
}
}
}

public void cancelTaskAndDescendants(CancellableTask task, String reason, boolean waitForCompletion, ActionListener<Void> listener) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import com.carrotsearch.randomizedtesting.RandomizedContext;
import com.carrotsearch.randomizedtesting.generators.RandomNumbers;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksAction;
import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
Expand All @@ -41,7 +42,10 @@
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskInfo;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.test.VersionUtils;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.FakeTcpChannel;
import org.elasticsearch.transport.TestTransportChannels;
import org.elasticsearch.transport.TransportService;

import java.io.IOException;
Expand Down Expand Up @@ -360,7 +364,10 @@ public void testRegisterAndExecuteChildTaskWhileParentTaskIsBeingCanceled() thro
CancellableNodesRequest parentRequest = new CancellableNodesRequest("parent");
final Task parentTask = taskManager.register("test", "test", parentRequest);
final TaskId parentTaskId = parentTask.taskInfo(testNodes[0].getNodeId(), false).getTaskId();
taskManager.setBan(new TaskId(testNodes[0].getNodeId(), parentTask.getId()), "test");
taskManager.setBan(new TaskId(testNodes[0].getNodeId(), parentTask.getId()), "test",
TestTransportChannels.newFakeTcpTransportChannel(
testNodes[0].getNodeId(), new FakeTcpChannel(), threadPool,
"test", randomNonNegativeLong(), Version.CURRENT));
CancellableNodesRequest childRequest = new CancellableNodesRequest("child");
childRequest.setParentTask(parentTaskId);
CancellableTestNodesAction testAction = new CancellableTestNodesAction("internal:testAction", threadPool, testNodes[1]
Expand All @@ -374,7 +381,7 @@ public void testRegisterAndExecuteChildTaskWhileParentTaskIsBeingCanceled() thro
}

public void testTaskCancellationOnCoordinatingNodeLeavingTheCluster() throws Exception {
setupTestNodes(Settings.EMPTY);
setupTestNodes(Settings.EMPTY, VersionUtils.randomVersionBetween(random(), Version.V_7_0_0, Version.V_7_11_0));
connectNodes(testNodes);
CountDownLatch responseLatch = new CountDownLatch(1);
final AtomicReference<NodesResponse> responseReference = new AtomicReference<>();
Expand Down
Loading