Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,13 @@ public int skipBytes(int n) throws IOException {

@Override
public byte readByte() throws IOException {
return buffer.readByte();
try {
return buffer.readByte();
} catch (IndexOutOfBoundsException ex) {
EOFException eofException = new EOFException();
eofException.initCause(ex);
throw eofException;
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,13 @@ public int skipBytes(int n) throws IOException {

@Override
public byte readByte() throws IOException {
return buffer.readByte();
try {
return buffer.readByte();
} catch (IndexOutOfBoundsException ex) {
EOFException eofException = new EOFException();
eofException.initCause(ex);
throw eofException;
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements
private volatile Map<String, RequestHandlerRegistry<? extends TransportRequest>> requestHandlers = Collections.emptyMap();
private final ResponseHandlers responseHandlers = new ResponseHandlers();
private final TransportLogger transportLogger;
private final TcpTransportHandshaker handshaker;
private final TransportHandshaker handshaker;
private final TransportKeepAlive keepAlive;
private final String nodeName;

Expand All @@ -224,12 +224,12 @@ public TcpTransport(String transportName, Settings settings, Version version, T
this.networkService = networkService;
this.transportName = transportName;
this.transportLogger = new TransportLogger();
this.handshaker = new TcpTransportHandshaker(version, threadPool,
this.handshaker = new TransportHandshaker(version, threadPool,
(node, channel, requestId, v) -> sendRequestToChannel(node, channel, requestId,
TcpTransportHandshaker.HANDSHAKE_ACTION_NAME, TransportRequest.Empty.INSTANCE, TransportRequestOptions.EMPTY, v,
TransportStatus.setHandshake((byte) 0)),
TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
TransportRequestOptions.EMPTY, v, TransportStatus.setHandshake((byte) 0)),
(v, features, channel, response, requestId) -> sendResponse(v, features, channel, response, requestId,
TcpTransportHandshaker.HANDSHAKE_ACTION_NAME, TransportResponseOptions.EMPTY, TransportStatus.setHandshake((byte) 0)));
TransportHandshaker.HANDSHAKE_ACTION_NAME, TransportResponseOptions.EMPTY, TransportStatus.setHandshake((byte) 0)));
this.keepAlive = new TransportKeepAlive(threadPool, this::internalSendMessage);
this.nodeName = Node.NODE_NAME_SETTING.get(settings);

Expand Down Expand Up @@ -1287,7 +1287,7 @@ protected String handleRequest(TcpChannel channel, String profileName, final Str
TransportChannel transportChannel = null;
try {
if (TransportStatus.isHandshake(status)) {
handshaker.handleHandshake(version, features, channel, requestId);
handshaker.handleHandshake(version, features, channel, requestId, stream);
} else {
final RequestHandlerRegistry reg = getRequestHandler(action);
if (reg == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.metrics.CounterMetric;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.threadpool.ThreadPool;

import java.io.EOFException;
import java.io.IOException;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -37,7 +40,7 @@
* Sends and receives transport-level connection handshakes. This class will send the initial handshake,
* manage state/timeouts while the handshake is in transit, and handle the eventual response.
*/
final class TcpTransportHandshaker {
final class TransportHandshaker {

static final String HANDSHAKE_ACTION_NAME = "internal:tcp/handshake";
private final ConcurrentMap<Long, HandshakeResponseHandler> pendingHandshakes = new ConcurrentHashMap<>();
Expand All @@ -48,8 +51,8 @@ final class TcpTransportHandshaker {
private final HandshakeRequestSender handshakeRequestSender;
private final HandshakeResponseSender handshakeResponseSender;

TcpTransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender,
HandshakeResponseSender handshakeResponseSender) {
TransportHandshaker(Version version, ThreadPool threadPool, HandshakeRequestSender handshakeRequestSender,
HandshakeResponseSender handshakeResponseSender) {
this.version = version;
this.threadPool = threadPool;
this.handshakeRequestSender = handshakeRequestSender;
Expand Down Expand Up @@ -83,11 +86,19 @@ void sendHandshake(long requestId, DiscoveryNode node, TcpChannel channel, TimeV
}
}

void handleHandshake(Version version, Set<String> features, TcpChannel channel, long requestId) throws IOException {
handshakeResponseSender.sendResponse(version, features, channel, new VersionHandshakeResponse(this.version), requestId);
void handleHandshake(Version version, Set<String> features, TcpChannel channel, long requestId, StreamInput stream) throws IOException {
// Must read the handshake request to exhaust the stream
HandshakeRequest handshakeRequest = new HandshakeRequest(stream);
final int nextByte = stream.read();
if (nextByte != -1) {
throw new IllegalStateException("Handshake request not fully read for requestId [" + requestId + "], action ["
+ TransportHandshaker.HANDSHAKE_ACTION_NAME + "], available [" + stream.available() + "]; resetting");
}
HandshakeResponse response = new HandshakeResponse(this.version);
handshakeResponseSender.sendResponse(version, features, channel, response, requestId);
}

TransportResponseHandler<VersionHandshakeResponse> removeHandlerForHandshake(long requestId) {
TransportResponseHandler<HandshakeResponse> removeHandlerForHandshake(long requestId) {
return pendingHandshakes.remove(requestId);
}

Expand All @@ -99,7 +110,7 @@ long getNumHandshakes() {
return numHandshakes.count();
}

private class HandshakeResponseHandler implements TransportResponseHandler<VersionHandshakeResponse> {
private class HandshakeResponseHandler implements TransportResponseHandler<HandshakeResponse> {

private final long requestId;
private final Version currentVersion;
Expand All @@ -113,14 +124,14 @@ private HandshakeResponseHandler(long requestId, Version currentVersion, ActionL
}

@Override
public VersionHandshakeResponse read(StreamInput in) throws IOException {
return new VersionHandshakeResponse(in);
public HandshakeResponse read(StreamInput in) throws IOException {
return new HandshakeResponse(in);
}

@Override
public void handleResponse(VersionHandshakeResponse response) {
public void handleResponse(HandshakeResponse response) {
if (isDone.compareAndSet(false, true)) {
Version version = response.version;
Version version = response.responseVersion;
if (currentVersion.isCompatible(version) == false) {
listener.onFailure(new IllegalStateException("Received message from unsupported version: [" + version
+ "] minimal compatible version is: [" + currentVersion.minimumCompatibilityVersion() + "]"));
Expand Down Expand Up @@ -149,24 +160,75 @@ public String executor() {
}
}

static final class VersionHandshakeResponse extends TransportResponse {
static final class HandshakeRequest extends TransportRequest {

private final Version version;

VersionHandshakeResponse(Version version) {
HandshakeRequest(Version version) {
this.version = version;
}

private VersionHandshakeResponse(StreamInput in) throws IOException {
HandshakeRequest(StreamInput streamInput) throws IOException {
super(streamInput);
BytesReference remainingMessage;
try {
remainingMessage = streamInput.readBytesReference();
} catch (EOFException e) {
remainingMessage = null;
}
if (remainingMessage == null) {
version = null;
} else {
try (StreamInput messageStreamInput = remainingMessage.streamInput()) {
this.version = Version.readVersion(messageStreamInput);
}
}
}

@Override
public void readFrom(StreamInput in) {
throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable");
}

@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
super.writeTo(streamOutput);
assert version != null;
try (BytesStreamOutput messageStreamOutput = new BytesStreamOutput(4)) {
Version.writeVersion(version, messageStreamOutput);
BytesReference reference = messageStreamOutput.bytes();
streamOutput.writeBytesReference(reference);
}
}
}

static final class HandshakeResponse extends TransportResponse {

private final Version responseVersion;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the plan is to replace internal:transport/handshake with this message, does it also need the ClusterName? Today we rely on the fact that the transport handshake validates that the cluster names match during discovery - I think that we don't otherwise verify that we're talking to nodes from the right cluster.

Copy link
Contributor Author

@Tim-Brooks Tim-Brooks Dec 11, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is follow-up work. We need the version the the request to bootstrap the follow-up work. Once we have the version in the request, we can serialize arbitrary responses that are compatible with the remote node's version.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, thanks.


HandshakeResponse(Version responseVersion) {
this.responseVersion = responseVersion;
}

private HandshakeResponse(StreamInput in) throws IOException {
super.readFrom(in);
version = Version.readVersion(in);
responseVersion = Version.readVersion(in);
}

@Override
public void readFrom(StreamInput in) {
throw new UnsupportedOperationException("usage of Streamable is to be replaced by Writeable");
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
assert version != null;
Version.writeVersion(version, out);
assert responseVersion != null;
Version.writeVersion(responseVersion, out);
}

Version getResponseVersion() {
return responseVersion;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,4 @@ static byte setHandshake(byte value) { // pkg private since it's only used inter
value |= STATUS_HANDSHAKE;
return value;
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
import org.elasticsearch.Version;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.mockito.ArgumentCaptor;
Expand All @@ -38,24 +41,24 @@

public class TransportHandshakerTests extends ESTestCase {

private TcpTransportHandshaker handshaker;
private TransportHandshaker handshaker;
private DiscoveryNode node;
private TcpChannel channel;
private TestThreadPool threadPool;
private TcpTransportHandshaker.HandshakeRequestSender requestSender;
private TcpTransportHandshaker.HandshakeResponseSender responseSender;
private TransportHandshaker.HandshakeRequestSender requestSender;
private TransportHandshaker.HandshakeResponseSender responseSender;

@Override
public void setUp() throws Exception {
super.setUp();
String nodeId = "node-id";
channel = mock(TcpChannel.class);
requestSender = mock(TcpTransportHandshaker.HandshakeRequestSender.class);
responseSender = mock(TcpTransportHandshaker.HandshakeResponseSender.class);
requestSender = mock(TransportHandshaker.HandshakeRequestSender.class);
responseSender = mock(TransportHandshaker.HandshakeResponseSender.class);
node = new DiscoveryNode(nodeId, nodeId, nodeId, "host", "host_address", buildNewFakeTransportAddress(), Collections.emptyMap(),
Collections.emptySet(), Version.CURRENT);
threadPool = new TestThreadPool("thread-poll");
handshaker = new TcpTransportHandshaker(Version.CURRENT, threadPool, requestSender, responseSender);
handshaker = new TransportHandshaker(Version.CURRENT, threadPool, requestSender, responseSender);
}

@Override
Expand All @@ -74,20 +77,63 @@ public void testHandshakeRequestAndResponse() throws IOException {
assertFalse(versionFuture.isDone());

TcpChannel mockChannel = mock(TcpChannel.class);
handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId);
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT);
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
handshakeRequest.writeTo(bytesStreamOutput);
StreamInput input = bytesStreamOutput.bytes().streamInput();
handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId, input);


ArgumentCaptor<TransportResponse> responseCaptor = ArgumentCaptor.forClass(TransportResponse.class);
verify(responseSender).sendResponse(eq(Version.CURRENT), eq(Collections.emptySet()), eq(mockChannel), responseCaptor.capture(),
eq(reqId));

TransportResponseHandler<TcpTransportHandshaker.VersionHandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
handler.handleResponse((TcpTransportHandshaker.VersionHandshakeResponse) responseCaptor.getValue());
TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
handler.handleResponse((TransportHandshaker.HandshakeResponse) responseCaptor.getValue());

assertTrue(versionFuture.isDone());
assertEquals(Version.CURRENT, versionFuture.actionGet());
}

public void testHandshakeRequestFutureVersionsCompatibility() throws IOException {
long reqId = randomLongBetween(1, 10);
handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), PlainActionFuture.newFuture());

verify(requestSender).sendRequest(node, channel, reqId, Version.CURRENT.minimumCompatibilityVersion());

TcpChannel mockChannel = mock(TcpChannel.class);
TransportHandshaker.HandshakeRequest handshakeRequest = new TransportHandshaker.HandshakeRequest(Version.CURRENT);
BytesStreamOutput currentHandshakeBytes = new BytesStreamOutput();
handshakeRequest.writeTo(currentHandshakeBytes);

BytesStreamOutput lengthCheckingHandshake = new BytesStreamOutput();
BytesStreamOutput futureHandshake = new BytesStreamOutput();
TaskId.EMPTY_TASK_ID.writeTo(lengthCheckingHandshake);
TaskId.EMPTY_TASK_ID.writeTo(futureHandshake);
try (BytesStreamOutput internalMessage = new BytesStreamOutput()) {
Version.writeVersion(Version.CURRENT, internalMessage);
lengthCheckingHandshake.writeBytesReference(internalMessage.bytes());
internalMessage.write(new byte[1024]);
futureHandshake.writeBytesReference(internalMessage.bytes());
}
StreamInput futureHandshakeStream = futureHandshake.bytes().streamInput();
// We check that the handshake we serialize for this test equals the actual request.
// Otherwise, we need to update the test.
assertEquals(currentHandshakeBytes.bytes().length(), lengthCheckingHandshake.bytes().length());
assertEquals(1031, futureHandshakeStream.available());
handshaker.handleHandshake(Version.CURRENT, Collections.emptySet(), mockChannel, reqId, futureHandshakeStream);
assertEquals(0, futureHandshakeStream.available());


ArgumentCaptor<TransportResponse> responseCaptor = ArgumentCaptor.forClass(TransportResponse.class);
verify(responseSender).sendResponse(eq(Version.CURRENT), eq(Collections.emptySet()), eq(mockChannel), responseCaptor.capture(),
eq(reqId));

TransportHandshaker.HandshakeResponse response = (TransportHandshaker.HandshakeResponse) responseCaptor.getValue();

assertEquals(Version.CURRENT, response.getResponseVersion());
}

public void testHandshakeError() throws IOException {
PlainActionFuture<Version> versionFuture = PlainActionFuture.newFuture();
long reqId = randomLongBetween(1, 10);
Expand All @@ -97,7 +143,7 @@ public void testHandshakeError() throws IOException {

assertFalse(versionFuture.isDone());

TransportResponseHandler<TcpTransportHandshaker.VersionHandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
TransportResponseHandler<TransportHandshaker.HandshakeResponse> handler = handshaker.removeHandlerForHandshake(reqId);
handler.handleException(new TransportException("failed"));

assertTrue(versionFuture.isDone());
Expand All @@ -113,7 +159,6 @@ public void testSendRequestThrowsException() throws IOException {

handshaker.sendHandshake(reqId, node, channel, new TimeValue(30, TimeUnit.SECONDS), versionFuture);


assertTrue(versionFuture.isDone());
ConnectTransportException cte = expectThrows(ConnectTransportException.class, versionFuture::actionGet);
assertThat(cte.getMessage(), containsString("failure to send internal:tcp/handshake"));
Expand Down
Loading