Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion core/src/main/java/org/elasticsearch/transport/TcpTransport.java
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,9 @@ public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile c
nodeChannels = new NodeChannels(nodeChannels, version); // clone the channels - we now have the correct version
transportService.onConnectionOpened(nodeChannels);
connectionRef.set(nodeChannels);
if (Arrays.stream(nodeChannels.channels).allMatch(this::isOpen) == false) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this solution solves the majority of the cases where the original issue arises but don't we still have an issue if this check occurs and immediately after it succeeds, concurrently the other side disconnects? The event loop is a different thread so the close listener could still be executed prior the nodes channels being returned from this scope.

In order to be safe don't we need to check if all the channels are still open AFTER we put it in the connectedNodes map in the connectToNode method?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's enough to check they are open after we've set the reference to the channels because the problem arises if we close before that reference is set; after that reference is set we are covered by the close listener?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah that’s probably right. I was thinking about the connection map that is not manipulated until later.

throw new ConnectTransportException(node, "a channel closed while connecting");
}
success = true;
return nodeChannels;
} catch (ConnectTransportException e) {
Expand Down Expand Up @@ -1034,7 +1037,18 @@ protected void innerOnFailure(Exception e) {
*/
protected abstract void sendMessage(Channel channel, BytesReference reference, ActionListener<Channel> listener);

protected abstract NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile connectionProfile,
/**
* Connect to the node with channels as defined by the specified connection profile. Implementations must invoke the specified channel
* close callback when a channel is closed.
*
* @param node the node to connect to
* @param connectionProfile the connection profile
* @param onChannelClose callback to invoke when a channel is closed
* @return the channels
* @throws IOException if an I/O exception occurs while opening channels
*/
protected abstract NodeChannels connectToChannels(DiscoveryNode node,
ConnectionProfile connectionProfile,
Consumer<Channel> onChannelClose) throws IOException;

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import java.util.Objects;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.function.Function;
import java.util.function.Predicate;
Expand Down Expand Up @@ -187,6 +188,15 @@ protected TaskManager createTaskManager() {
return new TaskManager(settings);
}

/**
* The executor service for this transport service.
*
* @return the executor service
*/
protected ExecutorService getExecutorService() {
return threadPool.generic();
}

void setTracerLogInclude(List<String> tracerLogInclude) {
this.tracerLogInclude = tracerLogInclude.toArray(Strings.EMPTY_ARRAY);
}
Expand Down Expand Up @@ -232,7 +242,7 @@ protected void doStop() {
if (holderToNotify != null) {
// callback that an exception happened, but on a different thread since we don't
// want handlers to worry about stack overflows
threadPool.generic().execute(new AbstractRunnable() {
getExecutorService().execute(new AbstractRunnable() {
@Override
public void onRejection(Exception e) {
// if we get rejected during node shutdown we don't wanna bubble it up
Expand Down Expand Up @@ -879,20 +889,20 @@ void onNodeConnected(final DiscoveryNode node) {
// connectToNode(); connection is completed successfully
// addConnectionListener(); this listener shouldn't be called
final Stream<TransportConnectionListener> listenersToNotify = TransportService.this.connectionListeners.stream();
threadPool.generic().execute(() -> listenersToNotify.forEach(listener -> listener.onNodeConnected(node)));
getExecutorService().execute(() -> listenersToNotify.forEach(listener -> listener.onNodeConnected(node)));
}

void onConnectionOpened(Transport.Connection connection) {
// capture listeners before spawning the background callback so the following pattern won't trigger a call
// connectToNode(); connection is completed successfully
// addConnectionListener(); this listener shouldn't be called
final Stream<TransportConnectionListener> listenersToNotify = TransportService.this.connectionListeners.stream();
threadPool.generic().execute(() -> listenersToNotify.forEach(listener -> listener.onConnectionOpened(connection)));
getExecutorService().execute(() -> listenersToNotify.forEach(listener -> listener.onConnectionOpened(connection)));
}

public void onNodeDisconnected(final DiscoveryNode node) {
try {
threadPool.generic().execute( () -> {
getExecutorService().execute( () -> {
for (final TransportConnectionListener connectionListener : connectionListeners) {
connectionListener.onNodeDisconnected(node);
}
Expand All @@ -911,7 +921,7 @@ void onConnectionClosed(Transport.Connection connection) {
if (holderToNotify != null) {
// callback that an exception happened, but on a different thread since we don't
// want handlers to worry about stack overflows
threadPool.generic().execute(() -> holderToNotify.handler().handleException(new NodeDisconnectedException(
getExecutorService().execute(() -> holderToNotify.handler().handleException(new NodeDisconnectedException(
connection.getNode(), holderToNotify.action())));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,8 @@ protected void sendMessage(Object o, BytesReference reference, ActionListener li
}

@Override
protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile,
Consumer onChannelClose) throws IOException {
protected NodeChannels connectToChannels(
DiscoveryNode node, ConnectionProfile profile, Consumer onChannelClose) throws IOException {
return new NodeChannels(node, new Object[profile.getNumConnections()], profile);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
public class SimpleNetty4TransportTests extends AbstractSimpleTransportTestCase {

public static MockTransportService nettyFromThreadPool(Settings settings, ThreadPool threadPool, final Version version,
ClusterSettings clusterSettings, boolean doHandshake) {
ClusterSettings clusterSettings, boolean doHandshake) {
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
Transport transport = new Netty4Transport(settings, threadPool, new NetworkService(Collections.emptyList()),
BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) {
Expand Down Expand Up @@ -86,6 +86,13 @@ protected MockTransportService build(Settings settings, Version version, Cluster
return transportService;
}

@Override
protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException {
final Netty4Transport t = (Netty4Transport) transport;
@SuppressWarnings("unchecked") final TcpTransport<Channel>.NodeChannels channels = (TcpTransport<Channel>.NodeChannels) connection;
t.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true, false);
}

public void testConnectException() throws UnknownHostException {
try {
serviceA.connectToNode(new DiscoveryNode("C", new TransportAddress(InetAddress.getByName("localhost"), 9876),
Expand All @@ -108,7 +115,8 @@ public void testBindUnavailableAddress() {
.build();
ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
BindTransportException bindTransportException = expectThrows(BindTransportException.class, () -> {
MockTransportService transportService = nettyFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true);
MockTransportService transportService =
nettyFromThreadPool(settings, threadPool, Version.CURRENT, clusterSettings, true);
try {
transportService.start();
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import java.util.Set;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
Expand Down Expand Up @@ -167,6 +168,17 @@ protected TaskManager createTaskManager() {
}
}

private volatile String executorName;

public void setExecutorName(final String executorName) {
this.executorName = executorName;
}

@Override
protected ExecutorService getExecutorService() {
return executorName == null ? super.getExecutorService() : getThreadPool().executor(executorName);
}

/**
* Clears all the registered rules.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,10 @@

import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasToString;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.startsWith;
Expand Down Expand Up @@ -147,14 +149,14 @@ public void onNodeDisconnected(DiscoveryNode node) {
private MockTransportService buildService(final String name, final Version version, ClusterSettings clusterSettings,
Settings settings, boolean acceptRequests, boolean doHandshake) {
MockTransportService service = build(
Settings.builder()
.put(settings)
.put(Node.NODE_NAME_SETTING.getKey(), name)
.put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "")
.put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING")
.build(),
version,
clusterSettings, doHandshake);
Settings.builder()
.put(settings)
.put(Node.NODE_NAME_SETTING.getKey(), name)
.put(TransportService.TRACE_LOG_INCLUDE_SETTING.getKey(), "")
.put(TransportService.TRACE_LOG_EXCLUDE_SETTING.getKey(), "NOTHING")
.build(),
version,
clusterSettings, doHandshake);
if (acceptRequests) {
service.acceptIncomingRequests();
}
Expand Down Expand Up @@ -2612,4 +2614,33 @@ public void testProfilesIncludesDefault() {
assertEquals(new HashSet<>(Arrays.asList("default", "test")), profileSettings.stream().map(s -> s.profileName).collect(Collectors
.toSet()));
}

public void testChannelCloseWhileConnecting() throws IOException {
try (MockTransportService service = build(Settings.builder().put("name", "close").build(), version0, null, true)) {
service.setExecutorName(ThreadPool.Names.SAME); // make sure stuff is executed in a blocking fashion
service.addConnectionListener(new TransportConnectionListener() {
@Override
public void onConnectionOpened(final Transport.Connection connection) {
try {
closeConnectionChannel(service.getOriginalTransport(), connection);
} catch (final IOException e) {
throw new AssertionError(e);
}
}
});
final ConnectionProfile.Builder builder = new ConnectionProfile.Builder();
builder.addConnections(1,
TransportRequestOptions.Type.BULK,
TransportRequestOptions.Type.PING,
TransportRequestOptions.Type.RECOVERY,
TransportRequestOptions.Type.REG,
TransportRequestOptions.Type.STATE);
final ConnectTransportException e =
expectThrows(ConnectTransportException.class, () -> service.openConnection(nodeA, builder.build()));
assertThat(e, hasToString(containsString(("a channel closed while connecting"))));
}
}

protected abstract void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException;

}
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ private void readMessage(MockChannel mockChannel, StreamInput input) throws IOEx
}

@Override
protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile,
protected NodeChannels connectToChannels(DiscoveryNode node,
ConnectionProfile profile,
Consumer<MockChannel> onChannelClose) throws IOException {
final MockChannel[] mockChannels = new MockChannel[1];
final NodeChannels nodeChannels = new NodeChannels(node, mockChannels, LIGHT_PROFILE); // we always use light here
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ public NioClient(Logger logger, OpenChannels openChannels, Supplier<SocketSelect
this.channelFactory = channelFactory;
}

public boolean connectToChannels(DiscoveryNode node, NioSocketChannel[] channels, TimeValue connectTimeout,
public boolean connectToChannels(DiscoveryNode node,
NioSocketChannel[] channels,
TimeValue connectTimeout,
Consumer<NioChannel> closeListener) throws IOException {
boolean allowedToConnect = semaphore.tryAcquire();
if (allowedToConnect == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.Collections;

public class MockTcpTransportTests extends AbstractSimpleTransportTestCase {

@Override
protected MockTransportService build(Settings settings, Version version, ClusterSettings clusterSettings, boolean doHandshake) {
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
Expand All @@ -53,4 +54,13 @@ protected Version executeHandshake(DiscoveryNode node, MockChannel mockChannel,
mockTransportService.start();
return mockTransportService;
}

@Override
protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException {
final MockTcpTransport t = (MockTcpTransport) transport;
@SuppressWarnings("unchecked") final TcpTransport<MockTcpTransport.MockChannel>.NodeChannels channels =
(TcpTransport<MockTcpTransport.MockChannel>.NodeChannels) connection;
t.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true, false);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
public class SimpleNioTransportTests extends AbstractSimpleTransportTestCase {

public static MockTransportService nioFromThreadPool(Settings settings, ThreadPool threadPool, final Version version,
ClusterSettings clusterSettings, boolean doHandshake) {
ClusterSettings clusterSettings, boolean doHandshake) {
NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList());
NetworkService networkService = new NetworkService(Collections.emptyList());
Transport transport = new NioTransport(settings, threadPool,
Expand Down Expand Up @@ -96,6 +96,13 @@ protected MockTransportService build(Settings settings, Version version, Cluster
return transportService;
}

@Override
protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException {
final NioTransport t = (NioTransport) transport;
@SuppressWarnings("unchecked") TcpTransport<NioChannel>.NodeChannels channels = (TcpTransport<NioChannel>.NodeChannels) connection;
t.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true, false);
}

public void testConnectException() throws UnknownHostException {
try {
serviceA.connectToNode(new DiscoveryNode("C", new TransportAddress(InetAddress.getByName("localhost"), 9876),
Expand Down