Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
Expand All @@ -51,12 +50,10 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeUnit;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpChannel;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.TransportRequestOptions;

Expand Down Expand Up @@ -239,9 +236,8 @@ protected final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
}

@Override
protected NettyTcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> listener)
throws IOException {
ChannelFuture channelFuture = bootstrap.connect(node.getAddress().address());
protected NettyTcpChannel initiateChannel(InetSocketAddress address, ActionListener<Void> listener) throws IOException {
ChannelFuture channelFuture = bootstrap.connect(address);
Channel channel = channelFuture.channel();
if (channel == null) {
Netty4Utils.maybeDie(channelFuture.cause());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.common.util.concurrent.EsExecutors;
Expand Down Expand Up @@ -93,9 +91,8 @@ protected TcpNioServerSocketChannel bind(String name, InetSocketAddress address)
}

@Override
protected TcpNioSocketChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException {
TcpNioSocketChannel channel = nioGroup.openChannel(node.getAddress().address(), clientChannelFactory);
protected TcpNioSocketChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
TcpNioSocketChannel channel = nioGroup.openChannel(address, clientChannelFactory);
channel.addConnectListener(ActionListener.toBiConsumer(connectListener));
return channel;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile c
try {
PlainActionFuture<Void> connectFuture = PlainActionFuture.newFuture();
connectionFutures.add(connectFuture);
TcpChannel channel = initiateChannel(node, connectionProfile.getConnectTimeout(), connectFuture);
TcpChannel channel = initiateChannel(node.getAddress().address(), connectFuture);
logger.trace(() -> new ParameterizedMessage("Tcp transport client channel opened: {}", channel));
channels.add(channel);
} catch (Exception e) {
Expand Down Expand Up @@ -1057,17 +1057,14 @@ protected void serverAcceptedChannel(TcpChannel channel) {
protected abstract TcpChannel bind(String name, InetSocketAddress address) throws IOException;

/**
* Initiate a single tcp socket channel to a node. Implementations do not have to observe the connectTimeout.
* It is provided for synchronous connection implementations.
* Initiate a single tcp socket channel.
*
* @param node the node
* @param connectTimeout the connection timeout
* @param connectListener listener to be called when connection complete
* @param address address for the initiated connection
* @param connectListener listener to be called when connection complete
* @return the pending connection
* @throws IOException if an I/O exception occurs while opening the channel
*/
protected abstract TcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException;
protected abstract TcpChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException;

/**
* Called to tear down internal resources
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.compress.CompressorFactory;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
Expand All @@ -41,15 +40,13 @@
import java.io.IOException;
import java.io.StreamCorruptedException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.core.IsInstanceOf.instanceOf;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;

/** Unit tests for {@link TcpTransport} */
public class TcpTransportTests extends ESTestCase {
Expand Down Expand Up @@ -193,8 +190,7 @@ protected FakeChannel bind(String name, InetSocketAddress address) throws IOExce
}

@Override
protected FakeChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException {
protected FakeChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
return new FakeChannel(messageCaptor);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.lucene.util.IOUtils;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.InputStreamStreamInput;
Expand All @@ -30,7 +29,6 @@
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.CancellableThreads;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
Expand All @@ -49,7 +47,6 @@
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
Expand All @@ -61,7 +58,6 @@
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;

/**
* This is a socket based blocking TcpTransport implementation that is used for tests
Expand Down Expand Up @@ -164,28 +160,32 @@ private void readMessage(MockChannel mockChannel, StreamInput input) throws IOEx
}

@Override
protected MockChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException {
InetSocketAddress address = node.getAddress().address();
protected MockChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
final MockSocket socket = new MockSocket();
final MockChannel channel = new MockChannel(socket, address, "none");

boolean success = false;
try {
configureSocket(socket);
try {
socket.connect(address, Math.toIntExact(connectTimeout.millis()));
} catch (SocketTimeoutException ex) {
throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", ex);
}
MockChannel channel = new MockChannel(socket, address, "none", (c) -> {});
channel.loopRead(executor);
success = true;
connectListener.onResponse(null);
return channel;
} finally {
if (success == false) {
IOUtils.close(socket);
}

}

executor.submit(() -> {
try {
socket.connect(address);
channel.loopRead(executor);
connectListener.onResponse(null);
} catch (Exception ex) {
connectListener.onFailure(ex);
}
});

return channel;
}

@Override
Expand Down Expand Up @@ -218,7 +218,6 @@ public final class MockChannel implements Closeable, TcpChannel {
private final Socket activeChannel;
private final String profile;
private final CancellableThreads cancellableThreads = new CancellableThreads();
private final Closeable onClose;
private final CompletableFuture<Void> closeFuture = new CompletableFuture<>();

/**
Expand All @@ -227,14 +226,12 @@ public final class MockChannel implements Closeable, TcpChannel {
* @param socket The client socket. Mut not be null.
* @param localAddress Address associated with the corresponding local server socket. Must not be null.
* @param profile The associated profile name.
* @param onClose Callback to execute when this channel is closed.
*/
public MockChannel(Socket socket, InetSocketAddress localAddress, String profile, Consumer<MockChannel> onClose) {
public MockChannel(Socket socket, InetSocketAddress localAddress, String profile) {
this.localAddress = localAddress;
this.activeChannel = socket;
this.serverSocket = null;
this.profile = profile;
this.onClose = () -> onClose.accept(this);
synchronized (openChannels) {
openChannels.add(this);
}
Expand All @@ -246,12 +243,11 @@ public MockChannel(Socket socket, InetSocketAddress localAddress, String profile
* @param serverSocket The associated server socket. Must not be null.
* @param profile The associated profile name.
*/
public MockChannel(ServerSocket serverSocket, String profile) {
MockChannel(ServerSocket serverSocket, String profile) {
this.localAddress = (InetSocketAddress) serverSocket.getLocalSocketAddress();
this.serverSocket = serverSocket;
this.profile = profile;
this.activeChannel = null;
this.onClose = null;
synchronized (openChannels) {
openChannels.add(this);
}
Expand All @@ -266,8 +262,19 @@ public void accept(Executor executor) throws IOException {
synchronized (this) {
if (isOpen.get()) {
incomingChannel = new MockChannel(incomingSocket,
new InetSocketAddress(incomingSocket.getLocalAddress(), incomingSocket.getPort()), profile,
workerChannels::remove);
new InetSocketAddress(incomingSocket.getLocalAddress(), incomingSocket.getPort()), profile);
MockChannel finalIncomingChannel = incomingChannel;
incomingChannel.addCloseListener(new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
workerChannels.remove(finalIncomingChannel);
}

@Override
public void onFailure(Exception e) {
workerChannels.remove(finalIncomingChannel);
}
});
serverAcceptedChannel(incomingChannel);
//establish a happens-before edge between closing and accepting a new connection
workerChannels.add(incomingChannel);
Expand All @@ -287,7 +294,7 @@ public void accept(Executor executor) throws IOException {
}
}

public void loopRead(Executor executor) {
void loopRead(Executor executor) {
executor.execute(new AbstractRunnable() {
@Override
public void onFailure(Exception e) {
Expand All @@ -312,7 +319,7 @@ protected void doRun() throws Exception {
});
}

public synchronized void close0() throws IOException {
synchronized void close0() throws IOException {
// establish a happens-before edge between closing and accepting a new connection
// we have to sync this entire block to ensure that our openChannels checks work correctly.
// The close block below will close all worker channels but if one of the worker channels runs into an exception
Expand All @@ -325,7 +332,7 @@ public synchronized void close0() throws IOException {
removedChannel = openChannels.remove(this);
}
IOUtils.close(serverSocket, activeChannel, () -> IOUtils.close(workerChannels),
() -> cancellableThreads.cancel("channel closed"), onClose);
() -> cancellableThreads.cancel("channel closed"));
assert removedChannel: "Channel was not removed or removed twice?";
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
Expand Down Expand Up @@ -83,9 +81,8 @@ protected MockServerChannel bind(String name, InetSocketAddress address) throws
}

@Override
protected MockSocketChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException {
MockSocketChannel channel = nioGroup.openChannel(node.getAddress().address(), clientChannelFactory);
protected MockSocketChannel initiateChannel(InetSocketAddress address, ActionListener<Void> connectListener) throws IOException {
MockSocketChannel channel = nioGroup.openChannel(address, clientChannelFactory);
channel.addConnectListener(ActionListener.toBiConsumer(connectListener));
return channel;
}
Expand Down