diff --git a/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java b/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java index 094f82ae31f63..3157a63c8a31b 100644 --- a/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java +++ b/server/src/main/java/org/elasticsearch/action/support/PlainActionFuture.java @@ -19,12 +19,20 @@ package org.elasticsearch.action.support; +import org.elasticsearch.common.CheckedConsumer; + public class PlainActionFuture extends AdapterActionFuture { public static PlainActionFuture newFuture() { return new PlainActionFuture<>(); } + public static T get(CheckedConsumer, E> e) throws E { + PlainActionFuture fut = newFuture(); + e.accept(fut); + return fut.actionGet(); + } + @Override protected T convert(T listenerResponse) { return listenerResponse; diff --git a/server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java b/server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java index 8de27337f1cba..2fceb76ccc1f4 100644 --- a/server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java +++ b/server/src/main/java/org/elasticsearch/cluster/coordination/Coordinator.java @@ -442,23 +442,22 @@ private void handleJoinRequest(JoinRequest joinRequest, JoinHelper.JoinCallback return; } - transportService.connectToNode(joinRequest.getSourceNode()); - - final ClusterState stateForJoinValidation = getStateForMasterService(); - - if (stateForJoinValidation.nodes().isLocalNodeElectedMaster()) { - onJoinValidators.forEach(a -> a.accept(joinRequest.getSourceNode(), stateForJoinValidation)); - if (stateForJoinValidation.getBlocks().hasGlobalBlock(STATE_NOT_RECOVERED_BLOCK) == false) { - // we do this in a couple of places including the cluster update thread. This one here is really just best effort - // to ensure we fail as fast as possible. - JoinTaskExecutor.ensureMajorVersionBarrier(joinRequest.getSourceNode().getVersion(), - stateForJoinValidation.getNodes().getMinNodeVersion()); + transportService.connectToNode(joinRequest.getSourceNode(), ActionListener.wrap(ignore -> { + final ClusterState stateForJoinValidation = getStateForMasterService(); + + if (stateForJoinValidation.nodes().isLocalNodeElectedMaster()) { + onJoinValidators.forEach(a -> a.accept(joinRequest.getSourceNode(), stateForJoinValidation)); + if (stateForJoinValidation.getBlocks().hasGlobalBlock(STATE_NOT_RECOVERED_BLOCK) == false) { + // we do this in a couple of places including the cluster update thread. This one here is really just best effort + // to ensure we fail as fast as possible. + JoinTaskExecutor.ensureMajorVersionBarrier(joinRequest.getSourceNode().getVersion(), + stateForJoinValidation.getNodes().getMinNodeVersion()); + } + sendValidateJoinRequest(stateForJoinValidation, joinRequest, joinCallback); + } else { + processJoinRequest(joinRequest, joinCallback); } - sendValidateJoinRequest(stateForJoinValidation, joinRequest, joinCallback); - - } else { - processJoinRequest(joinRequest, joinCallback); - } + }, joinCallback::onFailure)); } // package private for tests diff --git a/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java b/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java index 4e90ae02e12ac..bca3cc7037175 100644 --- a/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java +++ b/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java @@ -24,6 +24,7 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.NotifyOnceListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Randomness; import org.elasticsearch.common.UUIDs; @@ -70,7 +71,7 @@ public HandshakingTransportAddressConnector(Settings settings, TransportService public void connectToRemoteMasterNode(TransportAddress transportAddress, ActionListener listener) { transportService.getThreadPool().generic().execute(new AbstractRunnable() { @Override - protected void doRun() throws Exception { + protected void doRun() { // TODO if transportService is already connected to this address then skip the handshaking @@ -80,38 +81,68 @@ protected void doRun() throws Exception { emptySet(), Version.CURRENT.minimumCompatibilityVersion()); logger.trace("[{}] opening probe connection", this); - final Connection connection = transportService.openConnection(targetNode, + transportService.openConnection(targetNode, ConnectionProfile.buildSingleChannelProfile(Type.REG, probeConnectTimeout, probeHandshakeTimeout, - TimeValue.MINUS_ONE, null)); - logger.trace("[{}] opened probe connection", this); - - final DiscoveryNode remoteNode; - try { - remoteNode = transportService.handshake(connection, probeHandshakeTimeout.millis()); - // success means (amongst other things) that the cluster names match - logger.trace("[{}] handshake successful: {}", this, remoteNode); - } catch (Exception e) { - // we opened a connection and successfully performed a low-level handshake, so we were definitely talking to an - // Elasticsearch node, but the high-level handshake failed indicating some kind of mismatched configurations - // (e.g. cluster name) that the user should address - logger.warn(new ParameterizedMessage("handshake failed for [{}]", this), e); - listener.onFailure(e); - return; - } finally { - IOUtils.closeWhileHandlingException(connection); - } - - if (remoteNode.equals(transportService.getLocalNode())) { - // TODO cache this result for some time? forever? - listener.onFailure(new ConnectTransportException(remoteNode, "local node found")); - } else if (remoteNode.isMasterNode() == false) { - // TODO cache this result for some time? - listener.onFailure(new ConnectTransportException(remoteNode, "non-master-eligible node found")); - } else { - transportService.connectToNode(remoteNode); - logger.trace("[{}] full connection successful: {}", this, remoteNode); - listener.onResponse(remoteNode); - } + TimeValue.MINUS_ONE, null), new ActionListener<>() { + @Override + public void onResponse(Connection connection) { + logger.trace("[{}] opened probe connection", this); + + // use NotifyOnceListener to make sure the following line does not result in onFailure being called when + // the connection is closed in the onResponse handler + transportService.handshake(connection, probeHandshakeTimeout.millis(), new NotifyOnceListener() { + + @Override + protected void innerOnResponse(DiscoveryNode remoteNode) { + try { + // success means (amongst other things) that the cluster names match + logger.trace("[{}] handshake successful: {}", this, remoteNode); + IOUtils.closeWhileHandlingException(connection); + + if (remoteNode.equals(transportService.getLocalNode())) { + // TODO cache this result for some time? forever? + listener.onFailure(new ConnectTransportException(remoteNode, "local node found")); + } else if (remoteNode.isMasterNode() == false) { + // TODO cache this result for some time? + listener.onFailure(new ConnectTransportException(remoteNode, "non-master-eligible node found")); + } else { + transportService.connectToNode(remoteNode, new ActionListener() { + @Override + public void onResponse(Void ignored) { + logger.trace("[{}] full connection successful: {}", this, remoteNode); + listener.onResponse(remoteNode); + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }); + } + } catch (Exception e) { + listener.onFailure(e); + } + } + + @Override + protected void innerOnFailure(Exception e) { + // we opened a connection and successfully performed a low-level handshake, so we were definitely + // talking to an Elasticsearch node, but the high-level handshake failed indicating some kind of + // mismatched configurations (e.g. cluster name) that the user should address + logger.warn(new ParameterizedMessage("handshake failed for [{}]", this), e); + IOUtils.closeWhileHandlingException(connection); + listener.onFailure(e); + } + + }); + + } + + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + }); } @Override diff --git a/server/src/main/java/org/elasticsearch/transport/ConnectionManager.java b/server/src/main/java/org/elasticsearch/transport/ConnectionManager.java index da86ed076e396..3c31cddb39945 100644 --- a/server/src/main/java/org/elasticsearch/transport/ConnectionManager.java +++ b/server/src/main/java/org/elasticsearch/transport/ConnectionManager.java @@ -22,24 +22,24 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.AbstractRefCounted; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; import org.elasticsearch.common.util.concurrent.KeyedLock; +import org.elasticsearch.common.util.concurrent.RunOnce; import org.elasticsearch.core.internal.io.IOUtils; import java.io.Closeable; -import java.io.IOException; +import java.util.ArrayList; import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.locks.ReadWriteLock; -import java.util.concurrent.locks.ReentrantReadWriteLock; /** * This class manages node connections. The connection is opened by the underlying transport. Once the @@ -51,11 +51,18 @@ public class ConnectionManager implements Closeable { private static final Logger logger = LogManager.getLogger(ConnectionManager.class); private final ConcurrentMap connectedNodes = ConcurrentCollections.newConcurrentMap(); - private final KeyedLock connectionLock = new KeyedLock<>(); + private final KeyedLock connectionLock = new KeyedLock<>(); // protects concurrent access to connectingNodes + private final Map>> connectingNodes = ConcurrentCollections.newConcurrentMap(); + private final AbstractRefCounted connectingRefCounter = new AbstractRefCounted("connection manager") { + @Override + protected void closeInternal() { + closeLatch.countDown(); + } + }; private final Transport transport; private final ConnectionProfile defaultProfile; - private final AtomicBoolean isClosed = new AtomicBoolean(false); - private final ReadWriteLock closeLock = new ReentrantReadWriteLock(); + private final AtomicBoolean closing = new AtomicBoolean(false); + private final CountDownLatch closeLatch = new CountDownLatch(1); private final DelegatingNodeConnectionListener connectionListener = new DelegatingNodeConnectionListener(); public ConnectionManager(Settings settings, Transport transport) { @@ -75,66 +82,119 @@ public void removeListener(TransportConnectionListener listener) { this.connectionListener.listeners.remove(listener); } - public Transport.Connection openConnection(DiscoveryNode node, ConnectionProfile connectionProfile) { + public void openConnection(DiscoveryNode node, ConnectionProfile connectionProfile, ActionListener listener) { ConnectionProfile resolvedProfile = ConnectionProfile.resolveConnectionProfile(connectionProfile, defaultProfile); - return internalOpenConnection(node, resolvedProfile); + internalOpenConnection(node, resolvedProfile, listener); + } + + @FunctionalInterface + public interface ConnectionValidator { + void validate(Transport.Connection connection, ConnectionProfile profile, ActionListener listener); } /** * Connects to a node with the given connection profile. If the node is already connected this method has no effect. * Once a successful is established, it can be validated before being exposed. + * The ActionListener will be called on the calling thread or the generic thread pool. */ public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, - CheckedBiConsumer connectionValidator) - throws ConnectTransportException { + ConnectionValidator connectionValidator, + ActionListener listener) throws ConnectTransportException { ConnectionProfile resolvedProfile = ConnectionProfile.resolveConnectionProfile(connectionProfile, defaultProfile); if (node == null) { - throw new ConnectTransportException(null, "can't connect to a null node"); + listener.onFailure(new ConnectTransportException(null, "can't connect to a null node")); + return; } - closeLock.readLock().lock(); // ensure we don't open connections while we are closing - try { - ensureOpen(); - try (Releasable ignored = connectionLock.acquire(node.getId())) { - Transport.Connection connection = connectedNodes.get(node); - if (connection != null) { - return; - } - boolean success = false; - try { - connection = internalOpenConnection(node, resolvedProfile); - connectionValidator.accept(connection, resolvedProfile); - // we acquire a connection lock, so no way there is an existing connection - connectedNodes.put(node, connection); - if (logger.isDebugEnabled()) { - logger.debug("connected to node [{}]", node); - } + + if (connectingRefCounter.tryIncRef() == false) { + listener.onFailure(new IllegalStateException("connection manager is closed")); + return; + } + + try (Releasable lock = connectionLock.acquire(node.getId())) { + Transport.Connection connection = connectedNodes.get(node); + if (connection != null) { + assert connectingNodes.containsKey(node) == false; + lock.close(); + connectingRefCounter.decRef(); + listener.onResponse(null); + return; + } + + final List> connectionListeners = connectingNodes.computeIfAbsent(node, n -> new ArrayList<>()); + connectionListeners.add(listener); + if (connectionListeners.size() > 1) { + // wait on previous entry to complete connection attempt + connectingRefCounter.decRef(); + return; + } + } + + final RunOnce releaseOnce = new RunOnce(connectingRefCounter::decRef); + + internalOpenConnection(node, resolvedProfile, ActionListener.wrap(conn -> { + connectionValidator.validate(conn, resolvedProfile, ActionListener.wrap( + ignored -> { + assert Transports.assertNotTransportThread("connection validator success"); + boolean success = false; + List> listeners = null; try { - connectionListener.onNodeConnected(node); + // we acquire a connection lock, so no way there is an existing connection + try (Releasable ignored2 = connectionLock.acquire(node.getId())) { + connectedNodes.put(node, conn); + if (logger.isDebugEnabled()) { + logger.debug("connected to node [{}]", node); + } + try { + connectionListener.onNodeConnected(node); + } finally { + final Transport.Connection finalConnection = conn; + conn.addCloseListener(ActionListener.wrap(() -> { + logger.trace("unregistering {} after connection close and marking as disconnected", node); + connectedNodes.remove(node, finalConnection); + connectionListener.onNodeDisconnected(node); + })); + } + if (conn.isClosed()) { + throw new NodeNotConnectedException(node, "connection concurrently closed"); + } + success = true; + listeners = connectingNodes.remove(node); + } + } catch (ConnectTransportException e) { + throw e; + } catch (Exception e) { + throw new ConnectTransportException(node, "general node connection failure", e); } finally { - final Transport.Connection finalConnection = connection; - connection.addCloseListener(ActionListener.wrap(() -> { - connectedNodes.remove(node, finalConnection); - connectionListener.onNodeDisconnected(node); - })); + if (success == false) { // close the connection if there is a failure + logger.trace(() -> new ParameterizedMessage("failed to connect to [{}], cleaning dangling connections", node)); + IOUtils.closeWhileHandlingException(conn); + } else { + releaseOnce.run(); + ActionListener.onResponse(listeners, null); + } } - if (connection.isClosed()) { - throw new NodeNotConnectedException(node, "connection concurrently closed"); + }, e -> { + assert Transports.assertNotTransportThread("connection validator failure"); + IOUtils.closeWhileHandlingException(conn); + final List> listeners; + try (Releasable ignored = connectionLock.acquire(node.getId())) { + listeners = connectingNodes.remove(node); } - success = true; - } catch (ConnectTransportException e) { - throw e; - } catch (Exception e) { - throw new ConnectTransportException(node, "general node connection failure", e); - } finally { - if (success == false) { // close the connection if there is a failure - logger.trace(() -> new ParameterizedMessage("failed to connect to [{}], cleaning dangling connections", node)); - IOUtils.closeWhileHandlingException(connection); - } - } + releaseOnce.run(); + ActionListener.onFailure(listeners, e); + })); + }, e -> { + assert Transports.assertNotTransportThread("internalOpenConnection failure"); + final List> listeners; + try (Releasable ignored = connectionLock.acquire(node.getId())) { + listeners = connectingNodes.remove(node); } - } finally { - closeLock.readLock().unlock(); - } + releaseOnce.run(); + if (listeners != null) { + ActionListener.onFailure(listeners, e); + } + })); } /** @@ -143,7 +203,7 @@ public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfil * maintained by this connection manager * * @throws NodeNotConnectedException if the node is not connected - * @see #connectToNode(DiscoveryNode, ConnectionProfile, CheckedBiConsumer) + * @see #connectToNode(DiscoveryNode, ConnectionProfile, ConnectionValidator, ActionListener) */ public Transport.Connection getConnection(DiscoveryNode node) { Transport.Connection connection = connectedNodes.get(node); @@ -180,55 +240,41 @@ public int size() { @Override public void close() { - Transports.assertNotTransportThread("Closing ConnectionManager"); - if (isClosed.compareAndSet(false, true)) { - closeLock.writeLock().lock(); + assert Transports.assertNotTransportThread("Closing ConnectionManager"); + if (closing.compareAndSet(false, true)) { + connectingRefCounter.decRef(); try { - // we are holding a write lock so nobody adds to the connectedNodes / openConnections map - it's safe to first close - // all instances and then clear them maps - Iterator> iterator = connectedNodes.entrySet().iterator(); - while (iterator.hasNext()) { - Map.Entry next = iterator.next(); - try { - IOUtils.closeWhileHandlingException(next.getValue()); - } finally { - iterator.remove(); - } + closeLatch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IllegalStateException(e); + } + Iterator> iterator = connectedNodes.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry next = iterator.next(); + try { + IOUtils.closeWhileHandlingException(next.getValue()); + } finally { + iterator.remove(); } - } finally { - closeLock.writeLock().unlock(); } } } - private Transport.Connection internalOpenConnection(DiscoveryNode node, ConnectionProfile connectionProfile) { - PlainActionFuture future = PlainActionFuture.newFuture(); - Releasable pendingConnection = transport.openConnection(node, connectionProfile, future); - Transport.Connection connection; - try { - connection = future.actionGet(); - } catch (IllegalStateException e) { - // If the future was interrupted we must cancel the pending connection to avoid channels leaking - if (e.getCause() instanceof InterruptedException) { - pendingConnection.close(); + private void internalOpenConnection(DiscoveryNode node, ConnectionProfile connectionProfile, + ActionListener listener) { + transport.openConnection(node, connectionProfile, ActionListener.map(listener, connection -> { + assert Transports.assertNotTransportThread("internalOpenConnection success"); + try { + connectionListener.onConnectionOpened(connection); + } finally { + connection.addCloseListener(ActionListener.wrap(() -> connectionListener.onConnectionClosed(connection))); } - throw e; - } - try { - connectionListener.onConnectionOpened(connection); - } finally { - connection.addCloseListener(ActionListener.wrap(() -> connectionListener.onConnectionClosed(connection))); - } - if (connection.isClosed()) { - throw new ConnectTransportException(node, "a channel closed while connecting"); - } - return connection; - } - - private void ensureOpen() { - if (isClosed.get()) { - throw new IllegalStateException("connection manager is closed"); - } + if (connection.isClosed()) { + throw new ConnectTransportException(node, "a channel closed while connecting"); + } + return connection; + })); } ConnectionProfile getConnectionProfile() { diff --git a/server/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java b/server/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java index ca03e59a75858..df56691ef80bc 100644 --- a/server/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java +++ b/server/src/main/java/org/elasticsearch/transport/RemoteClusterConnection.java @@ -29,6 +29,7 @@ import org.elasticsearch.action.admin.cluster.state.ClusterStateRequest; import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse; import org.elasticsearch.action.support.ContextPreservingActionListener; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; @@ -449,14 +450,16 @@ private void collectRemoteNodes(Iterator> seedNodes, fin logger.debug("[{}] opening connection to seed node: [{}] proxy address: [{}]", clusterAlias, seedNode, proxyAddress); final TransportService.HandshakeResponse handshakeResponse; - ConnectionProfile profile = ConnectionProfile.buildSingleChannelProfile(TransportRequestOptions.Type.REG); - Transport.Connection connection = manager.openConnection(seedNode, profile); + final ConnectionProfile profile = ConnectionProfile.buildSingleChannelProfile(TransportRequestOptions.Type.REG); + final Transport.Connection connection = PlainActionFuture.get( + fut -> manager.openConnection(seedNode, profile, fut)); boolean success = false; try { try { ConnectionProfile connectionProfile = connectionManager.getConnectionProfile(); - handshakeResponse = transportService.handshake(connection, connectionProfile.getHandshakeTimeout().millis(), - (c) -> remoteClusterName.get() == null ? true : c.equals(remoteClusterName.get())); + handshakeResponse = PlainActionFuture.get(fut -> + transportService.handshake(connection, connectionProfile.getHandshakeTimeout().millis(), + (c) -> remoteClusterName.get() == null ? true : c.equals(remoteClusterName.get()), fut)); } catch (IllegalStateException ex) { logger.warn(() -> new ParameterizedMessage("seed node {} cluster name mismatch expected " + "cluster name {}", connection.getNode(), remoteClusterName.get()), ex); @@ -465,7 +468,8 @@ private void collectRemoteNodes(Iterator> seedNodes, fin final DiscoveryNode handshakeNode = maybeAddProxyAddress(proxyAddress, handshakeResponse.getDiscoveryNode()); if (nodePredicate.test(handshakeNode) && connectedNodes.size() < maxNumRemoteConnections) { - manager.connectToNode(handshakeNode, null, transportService.connectionValidator(handshakeNode)); + PlainActionFuture.get(fut -> manager.connectToNode(handshakeNode, null, + transportService.connectionValidator(handshakeNode), ActionListener.map(fut, x -> null))); if (remoteClusterName.get() == null) { assert handshakeResponse.getClusterName().value() != null; remoteClusterName.set(handshakeResponse.getClusterName()); @@ -579,8 +583,9 @@ public void handleResponse(ClusterStateResponse response) { DiscoveryNode node = maybeAddProxyAddress(proxyAddress, n); if (nodePredicate.test(node) && connectedNodes.size() < maxNumRemoteConnections) { try { - connectionManager.connectToNode(node, null, - transportService.connectionValidator(node)); // noop if node is connected + // noop if node is connected + PlainActionFuture.get(fut -> connectionManager.connectToNode(node, null, + transportService.connectionValidator(node), ActionListener.map(fut, x -> null))); connectedNodes.add(node); } catch (ConnectTransportException | IllegalStateException ex) { // ISE if we fail the handshake with an version incompatible node diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java index ad9059dbc3757..7ebda8336bd7f 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -26,6 +26,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Booleans; import org.elasticsearch.common.Strings; @@ -35,7 +36,6 @@ import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.metrics.MeanMetric; import org.elasticsearch.common.network.CloseableChannel; import org.elasticsearch.common.network.NetworkAddress; @@ -254,7 +254,8 @@ protected ConnectionProfile maybeOverrideConnectionProfile(ConnectionProfile con } @Override - public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener listener) { + public void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener listener) { + Objects.requireNonNull(profile, "connection profile cannot be null"); if (node == null) { throw new ConnectTransportException(null, "can't open connection to a null node"); @@ -263,8 +264,7 @@ public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, closeLock.readLock().lock(); // ensure we don't open connections while we are closing try { ensureOpen(); - List pendingChannels = initiateConnection(node, finalProfile, listener); - return () -> CloseableChannel.closeChannels(pendingChannels, false); + initiateConnection(node, finalProfile, listener); } finally { closeLock.readLock().unlock(); } @@ -293,7 +293,8 @@ private List initiateConnection(DiscoveryNode node, ConnectionProfil } } - ChannelsConnectedListener channelsConnectedListener = new ChannelsConnectedListener(node, connectionProfile, channels, listener); + ChannelsConnectedListener channelsConnectedListener = new ChannelsConnectedListener(node, connectionProfile, channels, + new ThreadedActionListener<>(logger, threadPool, ThreadPool.Names.GENERIC, listener, false)); for (TcpChannel channel : channels) { channel.addConnectListener(channelsConnectedListener); diff --git a/server/src/main/java/org/elasticsearch/transport/Transport.java b/server/src/main/java/org/elasticsearch/transport/Transport.java index 0b79b6aecf093..e81fb9c380e9b 100644 --- a/server/src/main/java/org/elasticsearch/transport/Transport.java +++ b/server/src/main/java/org/elasticsearch/transport/Transport.java @@ -25,7 +25,6 @@ import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.component.LifecycleComponent; -import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.transport.BoundTransportAddress; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.concurrent.ConcurrentCollections; @@ -80,12 +79,10 @@ default CircuitBreaker getInFlightRequestBreaker() { } /** - * Opens a new connection to the given node. When the connection is fully connected, the listener is - * called. A {@link Releasable} is returned representing the pending connection. If the caller of this - * method decides to move on before the listener is called with the completed connection, they should - * release the pending connection to prevent hanging connections. + * Opens a new connection to the given node. When the connection is fully connected, the listener is called. + * The ActionListener will be called on the calling thread or the generic thread pool. */ - Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener listener); + void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener listener); TransportStats getStats(); diff --git a/server/src/main/java/org/elasticsearch/transport/TransportService.java b/server/src/main/java/org/elasticsearch/transport/TransportService.java index dca7f52e60474..c070233c1f2e6 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportService.java @@ -24,9 +24,10 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionListenerResponseHandler; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; import org.elasticsearch.common.component.AbstractLifecycleComponent; @@ -321,7 +322,7 @@ public boolean nodeConnected(DiscoveryNode node) { * @param node the node to connect to */ public void connectToNode(DiscoveryNode node) throws ConnectTransportException { - connectToNode(node, null); + connectToNode(node, (ConnectionProfile) null); } /** @@ -331,34 +332,74 @@ public void connectToNode(DiscoveryNode node) throws ConnectTransportException { * @param connectionProfile the connection profile to use when connecting to this node */ public void connectToNode(final DiscoveryNode node, ConnectionProfile connectionProfile) { + PlainActionFuture.get(fut -> connectToNode(node, connectionProfile, ActionListener.map(fut, x -> null))); + } + + /** + * Connect to the specified node with the given connection profile. + * The ActionListener will be called on the calling thread or the generic thread pool. + * + * @param node the node to connect to + * @param listener the action listener to notify + */ + public void connectToNode(DiscoveryNode node, ActionListener listener) throws ConnectTransportException { + connectToNode(node, null, listener); + } + + /** + * Connect to the specified node with the given connection profile. + * The ActionListener will be called on the calling thread or the generic thread pool. + * + * @param node the node to connect to + * @param connectionProfile the connection profile to use when connecting to this node + * @param listener the action listener to notify + */ + public void connectToNode(final DiscoveryNode node, ConnectionProfile connectionProfile, ActionListener listener) { if (isLocalNode(node)) { + listener.onResponse(null); return; } - connectionManager.connectToNode(node, connectionProfile, connectionValidator(node)); + connectionManager.connectToNode(node, connectionProfile, connectionValidator(node), listener); } - public CheckedBiConsumer connectionValidator(DiscoveryNode node) { - return (newConnection, actualProfile) -> { + public ConnectionManager.ConnectionValidator connectionValidator(DiscoveryNode node) { + return (newConnection, actualProfile, listener) -> { // We don't validate cluster names to allow for CCS connections. - final DiscoveryNode remote = handshake(newConnection, actualProfile.getHandshakeTimeout().millis(), cn -> true).discoveryNode; - if (node.equals(remote) == false) { - throw new ConnectTransportException(node, "handshake failed. unexpected remote node " + remote); - } + handshake(newConnection, actualProfile.getHandshakeTimeout().millis(), cn -> true, ActionListener.map(listener, resp -> { + final DiscoveryNode remote = resp.discoveryNode; + if (node.equals(remote) == false) { + throw new ConnectTransportException(node, "handshake failed. unexpected remote node " + remote); + } + return null; + })); }; - } /** * Establishes and returns a new connection to the given node. The connection is NOT maintained by this service, it's the callers * responsibility to close the connection once it goes out of scope. + * The ActionListener will be called on the calling thread or the generic thread pool. + * @param node the node to connect to + * @param connectionProfile the connection profile to use + */ + public Transport.Connection openConnection(final DiscoveryNode node, ConnectionProfile connectionProfile) { + return PlainActionFuture.get(fut -> openConnection(node, connectionProfile, fut)); + } + + /** + * Establishes a new connection to the given node. The connection is NOT maintained by this service, it's the callers + * responsibility to close the connection once it goes out of scope. + * The ActionListener will be called on the calling thread or the generic thread pool. * @param node the node to connect to * @param connectionProfile the connection profile to use + * @param listener the action listener to notify */ - public Transport.Connection openConnection(final DiscoveryNode node, ConnectionProfile connectionProfile) throws IOException { + public void openConnection(final DiscoveryNode node, ConnectionProfile connectionProfile, + ActionListener listener) { if (isLocalNode(node)) { - return localNodeConnection; + listener.onResponse(localNodeConnection); } else { - return connectionManager.openConnection(node, connectionProfile); + connectionManager.openConnection(node, connectionProfile, listener); } } @@ -367,17 +408,19 @@ public Transport.Connection openConnection(final DiscoveryNode node, ConnectionP * and returns the discovery node of the node the connection * was established with. The handshake will fail if the cluster * name on the target node mismatches the local cluster name. + * The ActionListener will be called on the calling thread or the generic thread pool. * * @param connection the connection to a specific node * @param handshakeTimeout handshake timeout - * @return the connected node + * @param listener action listener to notify * @throws ConnectTransportException if the connection failed * @throws IllegalStateException if the handshake failed */ - public DiscoveryNode handshake( - final Transport.Connection connection, - final long handshakeTimeout) throws ConnectTransportException { - return handshake(connection, handshakeTimeout, clusterName::equals).discoveryNode; + public void handshake( + final Transport.Connection connection, + final long handshakeTimeout, + final ActionListener listener) { + handshake(connection, handshakeTimeout, clusterName::equals, ActionListener.map(listener, HandshakeResponse::getDiscoveryNode)); } /** @@ -385,40 +428,43 @@ public DiscoveryNode handshake( * and returns the discovery node of the node the connection * was established with. The handshake will fail if the cluster * name on the target node doesn't match the local cluster name. + * The ActionListener will be called on the calling thread or the generic thread pool. * * @param connection the connection to a specific node * @param handshakeTimeout handshake timeout * @param clusterNamePredicate cluster name validation predicate - * @return the handshake response + * @param listener action listener to notify * @throws IllegalStateException if the handshake failed */ - public HandshakeResponse handshake( + public void handshake( final Transport.Connection connection, - final long handshakeTimeout, Predicate clusterNamePredicate) { - final HandshakeResponse response; + final long handshakeTimeout, Predicate clusterNamePredicate, + final ActionListener listener) { final DiscoveryNode node = connection.getNode(); - try { - PlainTransportFuture futureHandler = new PlainTransportFuture<>( - new FutureTransportResponseHandler() { - @Override - public HandshakeResponse read(StreamInput in) throws IOException { - return new HandshakeResponse(in); - } - }); - sendRequest(connection, HANDSHAKE_ACTION_NAME, HandshakeRequest.INSTANCE, - TransportRequestOptions.builder().withTimeout(handshakeTimeout).build(), futureHandler); - response = futureHandler.txGet(); - } catch (Exception e) { - throw new IllegalStateException("handshake failed with " + node, e); - } - - if (!clusterNamePredicate.test(response.clusterName)) { - throw new IllegalStateException("handshake failed, mismatched cluster name [" + response.clusterName + "] - " + node); - } else if (response.version.isCompatible(localNode.getVersion()) == false) { - throw new IllegalStateException("handshake failed, incompatible version [" + response.version + "] - " + node); - } + sendRequest(connection, HANDSHAKE_ACTION_NAME, HandshakeRequest.INSTANCE, + TransportRequestOptions.builder().withTimeout(handshakeTimeout).build(), + new ActionListenerResponseHandler<>( + new ActionListener<>() { + @Override + public void onResponse(HandshakeResponse response) { + if (!clusterNamePredicate.test(response.clusterName)) { + listener.onFailure(new IllegalStateException("handshake failed, mismatched cluster name [" + + response.clusterName + "] - " + node.toString())); + } else if (response.version.isCompatible(localNode.getVersion()) == false) { + listener.onFailure(new IllegalStateException("handshake failed, incompatible version [" + + response.version + "] - " + node)); + } else { + listener.onResponse(response); + } + } - return response; + @Override + public void onFailure(Exception e) { + listener.onFailure(e); + } + } + , HandshakeResponse::new, ThreadPool.Names.GENERIC + )); } public ConnectionManager getConnectionManager() { diff --git a/server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java index 151f4d9268523..26f92b1f83d48 100644 --- a/server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java @@ -31,7 +31,6 @@ import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.component.LifecycleListener; -import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.BoundTransportAddress; import org.elasticsearch.common.transport.TransportAddress; @@ -355,8 +354,9 @@ private TestTransportService(Transport transport, ThreadPool threadPool) { } @Override - public HandshakeResponse handshake(Transport.Connection connection, long timeout, Predicate clusterNamePredicate) { - return new HandshakeResponse(connection.getNode(), new ClusterName(""), Version.CURRENT); + public void handshake(Transport.Connection connection, long timeout, Predicate clusterNamePredicate, + ActionListener listener) { + listener.onResponse(new HandshakeResponse(connection.getNode(), new ClusterName(""), Version.CURRENT)); } @Override @@ -406,7 +406,7 @@ public TransportAddress[] addressesFromString(String address) { } @Override - public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener listener) { + public void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener listener) { if (profile == null && randomConnectionExceptions && randomBoolean()) { threadPool.generic().execute(() -> listener.onFailure(new ConnectTransportException(node, "simulated"))); } else { @@ -435,8 +435,6 @@ public boolean isClosed() { } })); } - return () -> { - }; } @Override diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index 4c02b205e76a3..6290c5d749299 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -753,7 +753,7 @@ public void clearNetworkDisruptions() { disconnectedNodes.forEach(nodeName -> { if (testClusterNodes.nodes.containsKey(nodeName)) { final DiscoveryNode node = testClusterNodes.nodes.get(nodeName).node; - testClusterNodes.nodes.values().forEach(n -> n.transportService.getConnectionManager().openConnection(node, null)); + testClusterNodes.nodes.values().forEach(n -> n.transportService.openConnection(node, null)); } }); } diff --git a/server/src/test/java/org/elasticsearch/transport/ConnectionManagerTests.java b/server/src/test/java/org/elasticsearch/transport/ConnectionManagerTests.java index c1dd512e0232d..d74aa88404d4d 100644 --- a/server/src/test/java/org/elasticsearch/transport/ConnectionManagerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/ConnectionManagerTests.java @@ -21,8 +21,8 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.TimeValue; @@ -31,8 +31,12 @@ import org.junit.After; import org.junit.Before; -import java.io.IOException; import java.net.InetAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.BrokenBarrierException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -94,8 +98,12 @@ public void onNodeDisconnected(DiscoveryNode node) { assertFalse(connectionManager.nodeConnected(node)); AtomicReference connectionRef = new AtomicReference<>(); - CheckedBiConsumer validator = (c, p) -> connectionRef.set(c); - connectionManager.connectToNode(node, connectionProfile, validator); + ConnectionManager.ConnectionValidator validator = (c, p, l) -> { + connectionRef.set(c); + l.onResponse(null); + }; + PlainActionFuture.get( + fut -> connectionManager.connectToNode(node, connectionProfile, validator, ActionListener.map(fut, x -> null))); assertFalse(connection.isClosed()); assertTrue(connectionManager.nodeConnected(node)); @@ -115,7 +123,78 @@ public void onNodeDisconnected(DiscoveryNode node) { assertEquals(1, nodeDisconnectedCount.get()); } - public void testConnectFails() { + public void testConcurrentConnectsAndDisconnects() throws BrokenBarrierException, InterruptedException { + DiscoveryNode node = new DiscoveryNode("", new TransportAddress(InetAddress.getLoopbackAddress(), 0), Version.CURRENT); + Transport.Connection connection = new TestConnect(node); + doAnswer(invocationOnMock -> { + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + if (rarely()) { + listener.onResponse(connection); + } if (frequently()) { + threadPool.generic().execute(() -> listener.onResponse(connection)); + } else { + threadPool.generic().execute(() -> listener.onFailure(new IllegalStateException("dummy exception"))); + } + return null; + }).when(transport).openConnection(eq(node), eq(connectionProfile), any(ActionListener.class)); + + assertFalse(connectionManager.nodeConnected(node)); + + ConnectionManager.ConnectionValidator validator = (c, p, l) -> { + if (rarely()) { + l.onResponse(null); + } if (frequently()) { + threadPool.generic().execute(() -> l.onResponse(null)); + } else { + threadPool.generic().execute(() -> l.onFailure(new IllegalStateException("dummy exception"))); + } + }; + + CyclicBarrier barrier = new CyclicBarrier(11); + List threads = new ArrayList<>(); + AtomicInteger nodeConnectedCount = new AtomicInteger(); + AtomicInteger nodeFailureCount = new AtomicInteger(); + for (int i = 0; i < 10; i++) { + Thread thread = new Thread(() -> { + try { + barrier.await(); + } catch (InterruptedException | BrokenBarrierException e) { + throw new RuntimeException(e); + } + CountDownLatch latch = new CountDownLatch(1); + connectionManager.connectToNode(node, connectionProfile, validator, + ActionListener.wrap(c -> { + nodeConnectedCount.incrementAndGet(); + assert latch.getCount() == 1; + latch.countDown(); + }, e -> { + nodeFailureCount.incrementAndGet(); + assert latch.getCount() == 1; + latch.countDown(); + })); + try { + latch.await(); + } catch (InterruptedException e) { + throw new IllegalStateException(e); + } + }); + threads.add(thread); + thread.start(); + } + + barrier.await(); + threads.forEach(t -> { + try { + t.join(); + } catch (InterruptedException e) { + throw new IllegalStateException(e); + } + }); + + assertEquals(10, nodeConnectedCount.get() + nodeFailureCount.get()); + } + + public void testConnectFailsDuringValidation() { AtomicInteger nodeConnectedCount = new AtomicInteger(); AtomicInteger nodeDisconnectedCount = new AtomicInteger(); connectionManager.addListener(new TransportConnectionListener() { @@ -141,11 +220,11 @@ public void onNodeDisconnected(DiscoveryNode node) { assertFalse(connectionManager.nodeConnected(node)); - CheckedBiConsumer validator = (c, p) -> { - throw new ConnectTransportException(node, ""); - }; + ConnectionManager.ConnectionValidator validator = (c, p, l) -> l.onFailure(new ConnectTransportException(node, "")); - expectThrows(ConnectTransportException.class, () -> connectionManager.connectToNode(node, connectionProfile, validator)); + PlainActionFuture fut = new PlainActionFuture<>(); + connectionManager.connectToNode(node, connectionProfile, validator, fut); + expectThrows(ConnectTransportException.class, () -> fut.actionGet()); assertTrue(connection.isClosed()); assertFalse(connectionManager.nodeConnected(node)); @@ -155,6 +234,44 @@ public void onNodeDisconnected(DiscoveryNode node) { assertEquals(0, nodeDisconnectedCount.get()); } + public void testConnectFailsDuringConnect() { + AtomicInteger nodeConnectedCount = new AtomicInteger(); + AtomicInteger nodeDisconnectedCount = new AtomicInteger(); + connectionManager.addListener(new TransportConnectionListener() { + @Override + public void onNodeConnected(DiscoveryNode node) { + nodeConnectedCount.incrementAndGet(); + } + + @Override + public void onNodeDisconnected(DiscoveryNode node) { + nodeDisconnectedCount.incrementAndGet(); + } + }); + + + DiscoveryNode node = new DiscoveryNode("", new TransportAddress(InetAddress.getLoopbackAddress(), 0), Version.CURRENT); + doAnswer(invocationOnMock -> { + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onFailure(new ConnectTransportException(node, "")); + return null; + }).when(transport).openConnection(eq(node), eq(connectionProfile), any(ActionListener.class)); + + assertFalse(connectionManager.nodeConnected(node)); + + ConnectionManager.ConnectionValidator validator = (c, p, l) -> l.onResponse(null); + + PlainActionFuture fut = new PlainActionFuture<>(); + connectionManager.connectToNode(node, connectionProfile, validator, fut); + expectThrows(ConnectTransportException.class, () -> fut.actionGet()); + + assertFalse(connectionManager.nodeConnected(node)); + expectThrows(NodeNotConnectedException.class, () -> connectionManager.getConnection(node)); + assertEquals(0, connectionManager.size()); + assertEquals(0, nodeConnectedCount.get()); + assertEquals(0, nodeDisconnectedCount.get()); + } + private static class TestConnect extends CloseableConnection { private final DiscoveryNode node; diff --git a/server/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java b/server/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java index e93755873c917..d62bd37564d74 100644 --- a/server/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java +++ b/server/src/test/java/org/elasticsearch/transport/RemoteClusterConnectionTests.java @@ -1261,35 +1261,35 @@ public static Transport getProxyTransport(ThreadPool threadPool, Map delegatedListener.onResponse( - new Transport.Connection() { - @Override - public DiscoveryNode getNode() { - return node; - } + t.openConnection(proxyNode, profile, ActionListener.delegateFailure(listener, + (delegatedListener, connection) -> delegatedListener.onResponse( + new Transport.Connection() { + @Override + public DiscoveryNode getNode() { + return node; + } - @Override - public void sendRequest(long requestId, String action, TransportRequest request, - TransportRequestOptions options) throws IOException { - connection.sendRequest(requestId, action, request, options); - } + @Override + public void sendRequest(long requestId, String action, TransportRequest request, + TransportRequestOptions options) throws IOException { + connection.sendRequest(requestId, action, request, options); + } - @Override - public void addCloseListener(ActionListener listener) { - connection.addCloseListener(listener); - } + @Override + public void addCloseListener(ActionListener listener) { + connection.addCloseListener(listener); + } - @Override - public boolean isClosed() { - return connection.isClosed(); - } + @Override + public boolean isClosed() { + return connection.isClosed(); + } - @Override - public void close() { - connection.close(); - } - }))); + @Override + public void close() { + connection.close(); + } + }))); }); return stubbableTransport; } diff --git a/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java b/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java index 65292918752b8..5e96dd8ec8385 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java @@ -20,6 +20,8 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; @@ -108,7 +110,7 @@ public void testConnectToNodeLight() throws IOException { emptySet(), Version.CURRENT.minimumCompatibilityVersion()); try (Transport.Connection connection = handleA.transportService.openConnection(discoveryNode, TestProfiles.LIGHT_PROFILE)){ - DiscoveryNode connectedNode = handleA.transportService.handshake(connection, timeout); + DiscoveryNode connectedNode = PlainActionFuture.get(fut -> handleA.transportService.handshake(connection, timeout, fut)); assertNotNull(connectedNode); // the name and version should be updated assertEquals(connectedNode.getName(), "TS_B"); @@ -130,7 +132,7 @@ public void testMismatchedClusterName() { IllegalStateException ex = expectThrows(IllegalStateException.class, () -> { try (Transport.Connection connection = handleA.transportService.openConnection(discoveryNode, TestProfiles.LIGHT_PROFILE)) { - handleA.transportService.handshake(connection, timeout); + PlainActionFuture.get(fut -> handleA.transportService.handshake(connection, timeout, ActionListener.map(fut, x -> null))); } }); assertThat(ex.getMessage(), containsString("handshake failed, mismatched cluster name [Cluster [b]]")); @@ -151,7 +153,7 @@ public void testIncompatibleVersions() { IllegalStateException ex = expectThrows(IllegalStateException.class, () -> { try (Transport.Connection connection = handleA.transportService.openConnection(discoveryNode, TestProfiles.LIGHT_PROFILE)) { - handleA.transportService.handshake(connection, timeout); + PlainActionFuture.get(fut -> handleA.transportService.handshake(connection, timeout, ActionListener.map(fut, x -> null))); } }); assertThat(ex.getMessage(), containsString("handshake failed, incompatible version")); diff --git a/test/framework/src/main/java/org/elasticsearch/test/disruption/DisruptableMockTransport.java b/test/framework/src/main/java/org/elasticsearch/test/disruption/DisruptableMockTransport.java index eb39b1c16d00c..d035d1ab713e1 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/disruption/DisruptableMockTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/test/disruption/DisruptableMockTransport.java @@ -23,7 +23,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Nullable; -import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.BoundTransportAddress; @@ -48,7 +47,6 @@ import java.util.function.Function; import static org.elasticsearch.test.ESTestCase.copyWriteable; -import static org.elasticsearch.transport.TransportService.HANDSHAKE_ACTION_NAME; public abstract class DisruptableMockTransport extends MockTransport { private final DiscoveryNode localNode; @@ -65,15 +63,6 @@ public DisruptableMockTransport(DiscoveryNode localNode, Logger logger) { protected abstract void execute(Runnable runnable); - protected final void execute(String action, Runnable runnable) { - // handshake needs to run inline as the caller blockingly waits on the result - if (action.equals(HANDSHAKE_ACTION_NAME)) { - runnable.run(); - } else { - execute(runnable); - } - } - public DiscoveryNode getLocalNode() { return localNode; } @@ -86,30 +75,30 @@ public TransportService createTransportService(Settings settings, ThreadPool thr } @Override - public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener listener) { + public void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener listener) { final Optional optionalMatchingTransport = getDisruptableMockTransport(node.getAddress()); if (optionalMatchingTransport.isPresent()) { final DisruptableMockTransport matchingTransport = optionalMatchingTransport.get(); final ConnectionStatus connectionStatus = getConnectionStatus(matchingTransport.getLocalNode()); if (connectionStatus != ConnectionStatus.CONNECTED) { - throw new ConnectTransportException(node, "node [" + node + "] is [" + connectionStatus + "] not [CONNECTED]"); - } - - listener.onResponse(new CloseableConnection() { - @Override - public DiscoveryNode getNode() { - return node; - } + listener.onFailure( + new ConnectTransportException(node, "node [" + node + "] is [" + connectionStatus + "] not [CONNECTED]")); + } else { + listener.onResponse(new CloseableConnection() { + @Override + public DiscoveryNode getNode() { + return node; + } - @Override - public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options) - throws TransportException { - onSendRequest(requestId, action, request, matchingTransport); - } - }); - return () -> {}; + @Override + public void sendRequest(long requestId, String action, TransportRequest request, TransportRequestOptions options) + throws TransportException { + onSendRequest(requestId, action, request, matchingTransport); + } + }); + } } else { - throw new ConnectTransportException(node, "node [" + node + "] does not exist"); + listener.onFailure(new ConnectTransportException(node, "node " + node + " does not exist")); } } @@ -119,7 +108,7 @@ protected void onSendRequest(long requestId, String action, TransportRequest req assert destinationTransport.getLocalNode().equals(getLocalNode()) == false : "non-local message from " + getLocalNode() + " to itself"; - destinationTransport.execute(action, new Runnable() { + destinationTransport.execute(new Runnable() { @Override public void run() { final ConnectionStatus connectionStatus = getConnectionStatus(destinationTransport.getLocalNode()); @@ -169,18 +158,11 @@ protected String getRequestDescription(long requestId, String action, DiscoveryN } protected void onBlackholedDuringSend(long requestId, String action, DisruptableMockTransport destinationTransport) { - if (action.equals(HANDSHAKE_ACTION_NAME)) { - logger.trace("ignoring blackhole and delivering {}", - getRequestDescription(requestId, action, destinationTransport.getLocalNode())); - // handshakes always have a timeout, and are sent in a blocking fashion, so we must respond with an exception. - destinationTransport.execute(action, getDisconnectException(requestId, action, destinationTransport.getLocalNode())); - } else { - logger.trace("dropping {}", getRequestDescription(requestId, action, destinationTransport.getLocalNode())); - } + logger.trace("dropping {}", getRequestDescription(requestId, action, destinationTransport.getLocalNode())); } protected void onDisconnectedDuringSend(long requestId, String action, DisruptableMockTransport destinationTransport) { - destinationTransport.execute(action, getDisconnectException(requestId, action, destinationTransport.getLocalNode())); + destinationTransport.execute(getDisconnectException(requestId, action, destinationTransport.getLocalNode())); } protected void onConnectedDuringSend(long requestId, String action, TransportRequest request, @@ -205,7 +187,7 @@ public String getChannelType() { @Override public void sendResponse(final TransportResponse response) { - execute(action, new Runnable() { + execute(new Runnable() { @Override public void run() { final ConnectionStatus connectionStatus = destinationTransport.getConnectionStatus(getLocalNode()); @@ -234,7 +216,8 @@ public String toString() { @Override public void sendResponse(Exception exception) { - execute(action, new Runnable() { + + execute(new Runnable() { @Override public void run() { final ConnectionStatus connectionStatus = destinationTransport.getConnectionStatus(getLocalNode()); diff --git a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransport.java b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransport.java index 8086289127ece..93832833b7ff4 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransport.java @@ -31,7 +31,6 @@ import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.BoundTransportAddress; @@ -164,9 +163,8 @@ public void handleError(final long requestId, final TransportException e) { } @Override - public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener listener) { + public void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener listener) { listener.onResponse(createConnection(node)); - return () -> {}; } public Connection createConnection(DiscoveryNode node) { diff --git a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java index 7cd706b3564ce..93ea8309294a1 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java +++ b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransportService.java @@ -30,7 +30,6 @@ import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Setting; @@ -222,10 +221,8 @@ public void addFailToSendNoConnectRule(TransportService transportService) { * is added to fail as well. */ public void addFailToSendNoConnectRule(TransportAddress transportAddress) { - transport().addConnectBehavior(transportAddress, (transport, discoveryNode, profile, listener) -> { - listener.onFailure(new ConnectTransportException(discoveryNode, "DISCONNECT: simulated")); - return () -> {}; - }); + transport().addConnectBehavior(transportAddress, (transport, discoveryNode, profile, listener) -> + listener.onFailure(new ConnectTransportException(discoveryNode, "DISCONNECT: simulated"))); transport().addSendBehavior(transportAddress, (connection, requestId, action, request, options) -> { connection.close(); @@ -278,10 +275,8 @@ public void addUnresponsiveRule(TransportService transportService) { * and failing to connect once the rule was added. */ public void addUnresponsiveRule(TransportAddress transportAddress) { - transport().addConnectBehavior(transportAddress, (transport, discoveryNode, profile, listener) -> { - listener.onFailure(new ConnectTransportException(discoveryNode, "UNRESPONSIVE: simulated")); - return () -> {}; - }); + transport().addConnectBehavior(transportAddress, (transport, discoveryNode, profile, listener) -> + listener.onFailure(new ConnectTransportException(discoveryNode, "UNRESPONSIVE: simulated"))); transport().addSendBehavior(transportAddress, new StubbableTransport.SendRequestBehavior() { private Set toClose = ConcurrentHashMap.newKeySet(); @@ -331,11 +326,12 @@ public void addUnresponsiveRule(TransportAddress transportAddress, final TimeVal transport().addConnectBehavior(transportAddress, new StubbableTransport.OpenConnectionBehavior() { private CountDownLatch stopLatch = new CountDownLatch(1); @Override - public Releasable openConnection(Transport transport, DiscoveryNode discoveryNode, + public void openConnection(Transport transport, DiscoveryNode discoveryNode, ConnectionProfile profile, ActionListener listener) { TimeValue delay = delaySupplier.get(); if (delay.millis() <= 0) { - return original.openConnection(discoveryNode, profile, listener); + original.openConnection(discoveryNode, profile, listener); + return; } // TODO: Replace with proper setting @@ -343,17 +339,13 @@ public Releasable openConnection(Transport transport, DiscoveryNode discoveryNod try { if (delay.millis() < connectingTimeout.millis()) { stopLatch.await(delay.millis(), TimeUnit.MILLISECONDS); - return original.openConnection(discoveryNode, profile, listener); + original.openConnection(discoveryNode, profile, listener); } else { stopLatch.await(connectingTimeout.millis(), TimeUnit.MILLISECONDS); listener.onFailure(new ConnectTransportException(discoveryNode, "UNRESPONSIVE: simulated")); - return () -> { - }; } } catch (InterruptedException e) { listener.onFailure(new ConnectTransportException(discoveryNode, "UNRESPONSIVE: simulated")); - return () -> { - }; } } @@ -524,7 +516,7 @@ public Transport getOriginalTransport() { } @Override - public Transport.Connection openConnection(DiscoveryNode node, ConnectionProfile profile) throws IOException { + public Transport.Connection openConnection(DiscoveryNode node, ConnectionProfile profile) { Transport.Connection connection = super.openConnection(node, profile); synchronized (openConnections) { diff --git a/test/framework/src/main/java/org/elasticsearch/test/transport/StubbableConnectionManager.java b/test/framework/src/main/java/org/elasticsearch/test/transport/StubbableConnectionManager.java index d01b91258d576..a14eaa691f43e 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/transport/StubbableConnectionManager.java +++ b/test/framework/src/main/java/org/elasticsearch/test/transport/StubbableConnectionManager.java @@ -18,8 +18,8 @@ */ package org.elasticsearch.test.transport; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.transport.ConnectTransportException; @@ -28,7 +28,6 @@ import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportConnectionListener; -import java.io.IOException; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -80,8 +79,8 @@ public void clearBehavior(TransportAddress transportAddress) { } @Override - public Transport.Connection openConnection(DiscoveryNode node, ConnectionProfile connectionProfile) { - return delegate.openConnection(node, connectionProfile); + public void openConnection(DiscoveryNode node, ConnectionProfile connectionProfile, ActionListener listener) { + delegate.openConnection(node, connectionProfile, listener); } @Override @@ -110,9 +109,9 @@ public void removeListener(TransportConnectionListener listener) { @Override public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, - CheckedBiConsumer connectionValidator) + ConnectionValidator connectionValidator, ActionListener listener) throws ConnectTransportException { - delegate.connectToNode(node, connectionProfile, connectionValidator); + delegate.connectToNode(node, connectionProfile, connectionValidator, listener); } @Override diff --git a/test/framework/src/main/java/org/elasticsearch/test/transport/StubbableTransport.java b/test/framework/src/main/java/org/elasticsearch/test/transport/StubbableTransport.java index d812fdffe9673..5fe67acde8ab7 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/transport/StubbableTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/test/transport/StubbableTransport.java @@ -24,7 +24,6 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.component.LifecycleListener; -import org.elasticsearch.common.lease.Releasable; import org.elasticsearch.common.transport.BoundTransportAddress; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.transport.ConnectionProfile; @@ -128,7 +127,7 @@ public List getDefaultSeedAddresses() { } @Override - public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener listener) { + public void openConnection(DiscoveryNode node, ConnectionProfile profile, ActionListener listener) { TransportAddress address = node.getAddress(); OpenConnectionBehavior behavior = connectBehaviors.getOrDefault(address, defaultConnectBehavior); @@ -137,9 +136,9 @@ public Releasable openConnection(DiscoveryNode node, ConnectionProfile profile, (delegatedListener, connection) -> delegatedListener.onResponse(new WrappedConnection(connection))); if (behavior == null) { - return delegate.openConnection(node, profile, wrappedListener); + delegate.openConnection(node, profile, wrappedListener); } else { - return behavior.openConnection(delegate, node, profile, wrappedListener); + behavior.openConnection(delegate, node, profile, wrappedListener); } } @@ -247,8 +246,8 @@ public Transport.Connection getConnection() { @FunctionalInterface public interface OpenConnectionBehavior { - Releasable openConnection(Transport transport, DiscoveryNode discoveryNode, ConnectionProfile profile, - ActionListener listener); + void openConnection(Transport transport, DiscoveryNode discoveryNode, ConnectionProfile profile, + ActionListener listener); default void clearCallback() {} } diff --git a/test/framework/src/test/java/org/elasticsearch/test/disruption/DisruptableMockTransportTests.java b/test/framework/src/test/java/org/elasticsearch/test/disruption/DisruptableMockTransportTests.java index 90a47d09e6ca9..cc85ae0bad3e7 100644 --- a/test/framework/src/test/java/org/elasticsearch/test/disruption/DisruptableMockTransportTests.java +++ b/test/framework/src/test/java/org/elasticsearch/test/disruption/DisruptableMockTransportTests.java @@ -20,6 +20,7 @@ package org.elasticsearch.test.disruption; import org.elasticsearch.Version; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.coordination.DeterministicTaskQueue; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.collect.Tuple; @@ -146,8 +147,13 @@ protected void execute(Runnable runnable) { service1.start(); service2.start(); - service1.connectToNode(node2); - service2.connectToNode(node1); + final PlainActionFuture fut1 = new PlainActionFuture<>(); + service1.connectToNode(node2, fut1); + final PlainActionFuture fut2 = new PlainActionFuture<>(); + service2.connectToNode(node1, fut2); + deterministicTaskQueue.runAllTasksInTimeOrder(); + assertTrue(fut1.isDone()); + assertTrue(fut2.isDone()); } private TransportRequestHandler requestHandlerShouldNotBeCalled() {