From f4afa6463b597e02b25936c5d8e100f0b91bdb5c Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Tue, 19 Mar 2019 13:13:57 -0600 Subject: [PATCH 1/5] WIP --- .../DelegatingTransportMessageListener.java | 74 +++++++++++ .../transport/OutboundHandler.java | 64 ++++++++- .../elasticsearch/transport/TcpTransport.java | 125 ++++-------------- .../transport/TcpTransportChannel.java | 28 ++-- .../elasticsearch/transport/Transport.java | 2 +- .../transport/FailAndRetryMockTransport.java | 2 +- .../cluster/NodeConnectionsServiceTests.java | 2 +- .../transport/OutboundHandlerTests.java | 3 +- .../test/transport/MockTransport.java | 4 +- .../test/transport/StubbableTransport.java | 4 +- 10 files changed, 184 insertions(+), 124 deletions(-) create mode 100644 server/src/main/java/org/elasticsearch/transport/DelegatingTransportMessageListener.java diff --git a/server/src/main/java/org/elasticsearch/transport/DelegatingTransportMessageListener.java b/server/src/main/java/org/elasticsearch/transport/DelegatingTransportMessageListener.java new file mode 100644 index 0000000000000..df1e2979ee88c --- /dev/null +++ b/server/src/main/java/org/elasticsearch/transport/DelegatingTransportMessageListener.java @@ -0,0 +1,74 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport; + +import org.elasticsearch.cluster.node.DiscoveryNode; + +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; + +final class DelegatingTransportMessageListener implements TransportMessageListener { + + private final List listeners = new CopyOnWriteArrayList<>(); + + @Override + public void onRequestReceived(long requestId, String action) { + for (TransportMessageListener listener : listeners) { + listener.onRequestReceived(requestId, action); + } + } + + @Override + public void onResponseSent(long requestId, String action, TransportResponse response) { + for (TransportMessageListener listener : listeners) { + listener.onResponseSent(requestId, action, response); + } + } + + @Override + public void onResponseSent(long requestId, String action, Exception error) { + for (TransportMessageListener listener : listeners) { + listener.onResponseSent(requestId, action, error); + } + } + + @Override + public void onRequestSent(DiscoveryNode node, long requestId, String action, TransportRequest request, + TransportRequestOptions finalOptions) { + for (TransportMessageListener listener : listeners) { + listener.onRequestSent(node, requestId, action, request, finalOptions); + } + } + + @Override + public void onResponseReceived(long requestId, Transport.ResponseContext holder) { + for (TransportMessageListener listener : listeners) { + listener.onResponseReceived(requestId, holder); + } + } + + public void addListener(TransportMessageListener listener) { + listeners.add(listener); + } + + public boolean removeListener(TransportMessageListener listener) { + return listeners.remove(listener); + } +} diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java index 9431258f3230c..c1532900a42e7 100644 --- a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java +++ b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java @@ -22,8 +22,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.NotifyOnceListener; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.CheckedSupplier; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; @@ -32,22 +34,33 @@ import org.elasticsearch.common.metrics.MeanMetric; import org.elasticsearch.common.network.CloseableChannel; import org.elasticsearch.common.transport.NetworkExceptionHelper; +import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.threadpool.ThreadPool; import java.io.IOException; +import java.util.Set; final class OutboundHandler { private static final Logger logger = LogManager.getLogger(OutboundHandler.class); private final MeanMetric transmittedBytesMetric = new MeanMetric(); + + private final String nodeName; + private final Version version; + private final String[] features; private final ThreadPool threadPool; private final BigArrays bigArrays; private final TransportLogger transportLogger; + private final DelegatingTransportMessageListener messageListener = new DelegatingTransportMessageListener(); - OutboundHandler(ThreadPool threadPool, BigArrays bigArrays, TransportLogger transportLogger) { + OutboundHandler(String nodeName, Version version, String[] features, ThreadPool threadPool, BigArrays bigArrays, + TransportLogger transportLogger) { + this.nodeName = nodeName; + this.version = version; + this.features = features; this.threadPool = threadPool; this.bigArrays = bigArrays; this.transportLogger = transportLogger; @@ -64,6 +77,47 @@ void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener li } } + public void sendRequest(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action, + final TransportRequest request, final TransportRequestOptions options, final Version channelVersion, + final boolean compressRequest, final boolean isHandshake) throws IOException, TransportException { + Version version = Version.min(this.version, channelVersion); + OutboundMessage.Request message = new OutboundMessage.Request(threadPool.getThreadContext(), features, request, version, action, + requestId, isHandshake, compressRequest); + ActionListener listener = ActionListener.wrap(() -> + messageListener.onRequestSent(node, requestId, action, request, options)); + sendMessage(channel, message, listener); + } + + /** + * Sends the response to the given channel. This method should be used to send {@link TransportResponse} + * objects back to the caller. + * + * @see #sendErrorResponse(Version, Set, TcpChannel, Exception, long, String) for sending error responses + */ + public void sendResponse(final Version nodeVersion, final Set features, final TcpChannel channel, + final TransportResponse response, final long requestId, final String action, + final boolean compress, final boolean isHandshake) throws IOException { + Version version = Version.min(this.version, nodeVersion); + OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, response, version, + requestId, isHandshake, compress); + ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response)); + sendMessage(channel, message, listener); + } + + /** + * Sends back an error response to the caller via the given channel + */ + public void sendErrorResponse(final Version nodeVersion, final Set features, final TcpChannel channel, final Exception error, + final long requestId, final String action) throws IOException { + Version version = Version.min(this.version, nodeVersion); + TransportAddress address = new TransportAddress(channel.getLocalAddress()); + RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error); + OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, tx, version, requestId, + false, false); + ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error)); + sendMessage(channel, message, listener); + } + void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener listener) throws IOException { channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays); @@ -91,6 +145,14 @@ MeanMetric getTransmittedBytes() { return transmittedBytesMetric; } + public void addMessageListener(TransportMessageListener listener) { + messageListener.addListener(listener); + } + + public void removeMessageListener(TransportMessageListener listener) { + messageListener.removeListener(listener); + } + private static class MessageSerializer implements CheckedSupplier, Releasable { private final OutboundMessage message; diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java index 9b0a1dafd14ff..804d5a6089cbf 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -82,7 +82,6 @@ import java.util.Set; import java.util.TreeSet; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -152,17 +151,7 @@ public TcpTransport(Settings settings, Version version, ThreadPool threadPool, P this.circuitBreakerService = circuitBreakerService; this.networkService = networkService; this.transportLogger = new TransportLogger(); - this.outboundHandler = new OutboundHandler(threadPool, bigArrays, transportLogger); - this.handshaker = new TransportHandshaker(version, threadPool, - (node, channel, requestId, v) -> sendRequestToChannel(node, channel, requestId, - TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version), - TransportRequestOptions.EMPTY, v, false, true), - (v, features, channel, response, requestId) -> sendResponse(v, features, channel, response, requestId, - TransportHandshaker.HANDSHAKE_ACTION_NAME, false, true)); - this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes); - this.reader = new InboundMessage.Reader(version, namedWriteableRegistry, threadPool.getThreadContext()); this.nodeName = Node.NODE_NAME_SETTING.get(settings); - final Settings defaultFeatures = TransportSettings.DEFAULT_FEATURES_SETTING.get(settings); if (defaultFeatures == null) { this.features = new String[0]; @@ -175,6 +164,16 @@ public TcpTransport(Settings settings, Version version, ThreadPool threadPool, P // use a sorted set to present the features in a consistent order this.features = new TreeSet<>(defaultFeatures.names()).toArray(new String[defaultFeatures.names().size()]); } + this.outboundHandler = new OutboundHandler(nodeName, version, this.features, threadPool, bigArrays, transportLogger); + + this.handshaker = new TransportHandshaker(version, threadPool, + (node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId, + TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version), + TransportRequestOptions.EMPTY, v, false, true), + (v, features, channel, response, requestId) -> outboundHandler.sendResponse(v, features, channel, response, requestId, + TransportHandshaker.HANDSHAKE_ACTION_NAME, false, true)); + this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes); + this.reader = new InboundMessage.Reader(version, namedWriteableRegistry, threadPool.getThreadContext()); } @Override @@ -182,11 +181,15 @@ protected void doStart() { } public void addMessageListener(TransportMessageListener listener) { - messageListener.listeners.add(listener); + messageListener.addListener(listener); + // The outbound handler handles the send listeners + outboundHandler.addMessageListener(listener); } - public boolean removeMessageListener(TransportMessageListener listener) { - return messageListener.listeners.remove(listener); + public void removeMessageListener(TransportMessageListener listener) { + messageListener.removeListener(listener); + // The outbound handler handles the send listeners + outboundHandler.removeMessageListener(listener); } @Override @@ -267,7 +270,7 @@ public void sendRequest(long requestId, String action, TransportRequest request, throw new NodeNotConnectedException(node, "connection already closed"); } TcpChannel channel = channel(options.type()); - sendRequestToChannel(this.node, channel, requestId, action, request, options, getVersion(), compress); + outboundHandler.sendRequest(node, channel, requestId, action, request, options, getVersion(), compress, false); } } @@ -661,23 +664,6 @@ protected void serverAcceptedChannel(TcpChannel channel) { */ protected abstract void stopInternal(); - private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action, - final TransportRequest request, TransportRequestOptions options, Version channelVersion, - boolean compressRequest) throws IOException, TransportException { - sendRequestToChannel(node, channel, requestId, action, request, options, channelVersion, compressRequest, false); - } - - private void sendRequestToChannel(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action, - final TransportRequest request, TransportRequestOptions options, Version channelVersion, - boolean compressRequest, boolean isHandshake) throws IOException, TransportException { - Version version = Version.min(this.version, channelVersion); - OutboundMessage.Request message = new OutboundMessage.Request(threadPool.getThreadContext(), features, request, version, action, - requestId, isHandshake, compressRequest); - ActionListener listener = ActionListener.wrap(() -> - messageListener.onRequestSent(node, requestId, action, request, options)); - outboundHandler.sendMessage(channel, message, listener); - } - /** * Sends back an error response to the caller via the given channel * @@ -695,13 +681,7 @@ public void sendErrorResponse( final Exception error, final long requestId, final String action) throws IOException { - Version version = Version.min(this.version, nodeVersion); - TransportAddress address = new TransportAddress(channel.getLocalAddress()); - RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error); - OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, tx, version, requestId, - false, false); - ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error)); - outboundHandler.sendMessage(channel, message, listener); + outboundHandler.sendErrorResponse(nodeVersion, features, channel, error, requestId, action); } /** @@ -717,23 +697,7 @@ public void sendResponse( final long requestId, final String action, final boolean compress) throws IOException { - sendResponse(nodeVersion, features, channel, response, requestId, action, compress, false); - } - - private void sendResponse( - final Version nodeVersion, - final Set features, - final TcpChannel channel, - final TransportResponse response, - final long requestId, - final String action, - boolean compress, - boolean isHandshake) throws IOException { - Version version = Version.min(this.version, nodeVersion); - OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, response, version, - requestId, isHandshake, compress); - ActionListener listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response)); - outboundHandler.sendMessage(channel, message, listener); + outboundHandler.sendResponse(nodeVersion, features, channel, response, requestId, action, compress, false); } /** @@ -1021,8 +985,8 @@ protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage m } else { getInFlightRequestBreaker().addWithoutBreaking(messageLengthBytes); } - transportChannel = new TcpTransportChannel(this, channel, action, requestId, version, features, profileName, - messageLengthBytes, message.isCompress()); + transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features, + circuitBreakerService, messageLengthBytes, message.isCompress()); final TransportRequest request = reg.newRequest(stream); request.remoteAddress(new TransportAddress(channel.getRemoteAddress())); // in case we throw an exception, i.e. when the limit is hit, we don't want to verify @@ -1032,8 +996,8 @@ protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage m } catch (Exception e) { // the circuit breaker tripped if (transportChannel == null) { - transportChannel = new TcpTransportChannel(this, channel, action, requestId, version, features, - profileName, 0, message.isCompress()); + transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version, features, + circuitBreakerService, 0, message.isCompress()); } try { transportChannel.sendResponse(e); @@ -1184,47 +1148,6 @@ public ProfileSettings(Settings settings, String profileName) { } } - private static final class DelegatingTransportMessageListener implements TransportMessageListener { - - private final List listeners = new CopyOnWriteArrayList<>(); - - @Override - public void onRequestReceived(long requestId, String action) { - for (TransportMessageListener listener : listeners) { - listener.onRequestReceived(requestId, action); - } - } - - @Override - public void onResponseSent(long requestId, String action, TransportResponse response) { - for (TransportMessageListener listener : listeners) { - listener.onResponseSent(requestId, action, response); - } - } - - @Override - public void onResponseSent(long requestId, String action, Exception error) { - for (TransportMessageListener listener : listeners) { - listener.onResponseSent(requestId, action, error); - } - } - - @Override - public void onRequestSent(DiscoveryNode node, long requestId, String action, TransportRequest request, - TransportRequestOptions finalOptions) { - for (TransportMessageListener listener : listeners) { - listener.onRequestSent(node, requestId, action, request, finalOptions); - } - } - - @Override - public void onResponseReceived(long requestId, ResponseContext holder) { - for (TransportMessageListener listener : listeners) { - listener.onResponseReceived(requestId, holder); - } - } - } - @Override public final ResponseHandlers getResponseHandlers() { return responseHandlers; diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java b/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java index b45fc19c762e9..a52fd7a8a1eae 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java @@ -20,6 +20,8 @@ package org.elasticsearch.transport; import org.elasticsearch.Version; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.indices.breaker.CircuitBreakerService; import java.io.IOException; import java.util.Set; @@ -28,38 +30,38 @@ public final class TcpTransportChannel implements TransportChannel { private final AtomicBoolean released = new AtomicBoolean(); - private final TcpTransport transport; - private final Version version; - private final Set features; + private final OutboundHandler outboundHandler; + private final TcpChannel channel; private final String action; private final long requestId; - private final String profileName; + private final Version version; + private final Set features; + private final CircuitBreakerService breakerService; private final long reservedBytes; - private final TcpChannel channel; private final boolean compressResponse; - TcpTransportChannel(TcpTransport transport, TcpChannel channel, String action, long requestId, Version version, Set features, - String profileName, long reservedBytes, boolean compressResponse) { + TcpTransportChannel(OutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, Version version, + Set features, CircuitBreakerService breakerService, long reservedBytes, boolean compressResponse) { this.version = version; this.features = features; this.channel = channel; - this.transport = transport; + this.outboundHandler = outboundHandler; this.action = action; this.requestId = requestId; - this.profileName = profileName; + this.breakerService = breakerService; this.reservedBytes = reservedBytes; this.compressResponse = compressResponse; } @Override public String getProfileName() { - return profileName; + return channel.getProfile(); } @Override public void sendResponse(TransportResponse response) throws IOException { try { - transport.sendResponse(version, features, channel, response, requestId, action, compressResponse); + outboundHandler.sendResponse(version, features, channel, response, requestId, action, compressResponse, false); } finally { release(false); } @@ -68,7 +70,7 @@ public void sendResponse(TransportResponse response) throws IOException { @Override public void sendResponse(Exception exception) throws IOException { try { - transport.sendErrorResponse(version, features, channel, exception, requestId, action); + outboundHandler.sendErrorResponse(version, features, channel, exception, requestId, action); } finally { release(true); } @@ -79,7 +81,7 @@ public void sendResponse(Exception exception) throws IOException { private void release(boolean isExceptionResponse) { if (released.compareAndSet(false, true)) { assert (releaseBy = new Exception()) != null; // easier to debug if it's already closed - transport.getInFlightRequestBreaker().addWithoutBreaking(-reservedBytes); + breakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS).addWithoutBreaking(-reservedBytes); } else if (isExceptionResponse == false) { // only fail if we are not sending an error - we might send the error triggered by the previous // sendResponse call diff --git a/server/src/main/java/org/elasticsearch/transport/Transport.java b/server/src/main/java/org/elasticsearch/transport/Transport.java index 4a8a061602a52..0cd860d73c589 100644 --- a/server/src/main/java/org/elasticsearch/transport/Transport.java +++ b/server/src/main/java/org/elasticsearch/transport/Transport.java @@ -55,7 +55,7 @@ public interface Transport extends LifecycleComponent { void addMessageListener(TransportMessageListener listener); - boolean removeMessageListener(TransportMessageListener listener); + void removeMessageListener(TransportMessageListener listener); /** * The address the transport is bound on. diff --git a/server/src/test/java/org/elasticsearch/client/transport/FailAndRetryMockTransport.java b/server/src/test/java/org/elasticsearch/client/transport/FailAndRetryMockTransport.java index 7ae8156088db1..3e8e11dd5f043 100644 --- a/server/src/test/java/org/elasticsearch/client/transport/FailAndRetryMockTransport.java +++ b/server/src/test/java/org/elasticsearch/client/transport/FailAndRetryMockTransport.java @@ -234,7 +234,7 @@ public void addMessageListener(TransportMessageListener listener) { } @Override - public boolean removeMessageListener(TransportMessageListener listener) { + public void removeMessageListener(TransportMessageListener listener) { throw new UnsupportedOperationException(); } diff --git a/server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java b/server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java index 0a4ef759cb68f..5728622d05036 100644 --- a/server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/NodeConnectionsServiceTests.java @@ -391,7 +391,7 @@ public void addMessageListener(TransportMessageListener listener) { } @Override - public boolean removeMessageListener(TransportMessageListener listener) { + public void removeMessageListener(TransportMessageListener listener) { throw new UnsupportedOperationException(); } diff --git a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java index 01e391a30a732..55246c92e8ade 100644 --- a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java @@ -55,7 +55,8 @@ public void setUp() throws Exception { super.setUp(); TransportLogger transportLogger = new TransportLogger(); fakeTcpChannel = new FakeTcpChannel(randomBoolean()); - handler = new OutboundHandler(threadPool, BigArrays.NON_RECYCLING_INSTANCE, transportLogger); + handler = new OutboundHandler("node", Version.CURRENT, new String[0], threadPool, BigArrays.NON_RECYCLING_INSTANCE, + transportLogger); } @After diff --git a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransport.java b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransport.java index 52ba7efa3fa41..619c1ab1a9eea 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/test/transport/MockTransport.java @@ -272,12 +272,10 @@ public void addMessageListener(TransportMessageListener listener) { } @Override - public boolean removeMessageListener(TransportMessageListener listener) { + public void removeMessageListener(TransportMessageListener listener) { if (listener == this.listener) { this.listener = null; - return true; } - return false; } protected NamedWriteableRegistry writeableRegistry() { diff --git a/test/framework/src/main/java/org/elasticsearch/test/transport/StubbableTransport.java b/test/framework/src/main/java/org/elasticsearch/test/transport/StubbableTransport.java index 673ed49387570..c7796d032b583 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/transport/StubbableTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/test/transport/StubbableTransport.java @@ -100,8 +100,8 @@ public void addMessageListener(TransportMessageListener listener) { } @Override - public boolean removeMessageListener(TransportMessageListener listener) { - return delegate.removeMessageListener(listener); + public void removeMessageListener(TransportMessageListener listener) { + delegate.removeMessageListener(listener); } @Override From b6dea05c0450947a07c0d4d1bd16d310a358a463 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Tue, 19 Mar 2019 18:15:45 -0600 Subject: [PATCH 2/5] WIP --- .../transport/OutboundHandler.java | 18 +++--- .../elasticsearch/transport/TcpTransport.java | 53 ++-------------- .../transport/OutboundHandlerTests.java | 63 +++++++++++++++++++ 3 files changed, 80 insertions(+), 54 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java index c1532900a42e7..460f8c0f051c1 100644 --- a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java +++ b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java @@ -77,6 +77,10 @@ void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener li } } + /** + * Sends the request to the given channel. This method should be used to send {@link TransportRequest} + * objects back to the caller. + */ public void sendRequest(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action, final TransportRequest request, final TransportRequestOptions options, final Version channelVersion, final boolean compressRequest, final boolean isHandshake) throws IOException, TransportException { @@ -94,9 +98,9 @@ public void sendRequest(final DiscoveryNode node, final TcpChannel channel, fina * * @see #sendErrorResponse(Version, Set, TcpChannel, Exception, long, String) for sending error responses */ - public void sendResponse(final Version nodeVersion, final Set features, final TcpChannel channel, - final TransportResponse response, final long requestId, final String action, - final boolean compress, final boolean isHandshake) throws IOException { + void sendResponse(final Version nodeVersion, final Set features, final TcpChannel channel, + final TransportResponse response, final long requestId, final String action, + final boolean compress, final boolean isHandshake) throws IOException { Version version = Version.min(this.version, nodeVersion); OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, response, version, requestId, isHandshake, compress); @@ -107,8 +111,8 @@ public void sendResponse(final Version nodeVersion, final Set features, /** * Sends back an error response to the caller via the given channel */ - public void sendErrorResponse(final Version nodeVersion, final Set features, final TcpChannel channel, final Exception error, - final long requestId, final String action) throws IOException { + void sendErrorResponse(final Version nodeVersion, final Set features, final TcpChannel channel, final Exception error, + final long requestId, final String action) throws IOException { Version version = Version.min(this.version, nodeVersion); TransportAddress address = new TransportAddress(channel.getLocalAddress()); RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error); @@ -145,11 +149,11 @@ MeanMetric getTransmittedBytes() { return transmittedBytesMetric; } - public void addMessageListener(TransportMessageListener listener) { + void addMessageListener(TransportMessageListener listener) { messageListener.addListener(listener); } - public void removeMessageListener(TransportMessageListener listener) { + void removeMessageListener(TransportMessageListener listener) { messageListener.removeListener(listener); } diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java index 804d5a6089cbf..dd1c0d32015e5 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -101,17 +101,13 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements public static final String TRANSPORT_WORKER_THREAD_NAME_PREFIX = "transport_worker"; - // This is the number of bytes necessary to read the message size private static final int BYTES_NEEDED_FOR_MESSAGE_SIZE = TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE; private static final long NINETY_PER_HEAP_SIZE = (long) (JvmInfo.jvmInfo().getMem().getHeapMax().getBytes() * 0.9); private static final BytesReference EMPTY_BYTES_REFERENCE = new BytesArray(new byte[0]); - private final String[] features; - protected final Settings settings; private final CircuitBreakerService circuitBreakerService; - private final Version version; protected final ThreadPool threadPool; protected final BigArrays bigArrays; protected final PageCacheRecycler pageCacheRecycler; @@ -137,24 +133,23 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements private final TransportKeepAlive keepAlive; private final InboundMessage.Reader reader; private final OutboundHandler outboundHandler; - private final String nodeName; public TcpTransport(Settings settings, Version version, ThreadPool threadPool, PageCacheRecycler pageCacheRecycler, CircuitBreakerService circuitBreakerService, NamedWriteableRegistry namedWriteableRegistry, NetworkService networkService) { this.settings = settings; this.profileSettings = getProfileSettings(settings); - this.version = version; this.threadPool = threadPool; this.bigArrays = new BigArrays(pageCacheRecycler, circuitBreakerService, CircuitBreaker.IN_FLIGHT_REQUESTS); this.pageCacheRecycler = pageCacheRecycler; this.circuitBreakerService = circuitBreakerService; this.networkService = networkService; this.transportLogger = new TransportLogger(); - this.nodeName = Node.NODE_NAME_SETTING.get(settings); + String nodeName = Node.NODE_NAME_SETTING.get(settings); final Settings defaultFeatures = TransportSettings.DEFAULT_FEATURES_SETTING.get(settings); + String[] features; if (defaultFeatures == null) { - this.features = new String[0]; + features = new String[0]; } else { defaultFeatures.names().forEach(key -> { if (Booleans.parseBoolean(defaultFeatures.get(key)) == false) { @@ -162,15 +157,15 @@ public TcpTransport(Settings settings, Version version, ThreadPool threadPool, P } }); // use a sorted set to present the features in a consistent order - this.features = new TreeSet<>(defaultFeatures.names()).toArray(new String[defaultFeatures.names().size()]); + features = new TreeSet<>(defaultFeatures.names()).toArray(new String[defaultFeatures.names().size()]); } - this.outboundHandler = new OutboundHandler(nodeName, version, this.features, threadPool, bigArrays, transportLogger); + this.outboundHandler = new OutboundHandler(nodeName, version, features, threadPool, bigArrays, transportLogger); this.handshaker = new TransportHandshaker(version, threadPool, (node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId, TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version), TransportRequestOptions.EMPTY, v, false, true), - (v, features, channel, response, requestId) -> outboundHandler.sendResponse(v, features, channel, response, requestId, + (v, features1, channel, response, requestId) -> outboundHandler.sendResponse(v, features1, channel, response, requestId, TransportHandshaker.HANDSHAKE_ACTION_NAME, false, true)); this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes); this.reader = new InboundMessage.Reader(version, namedWriteableRegistry, threadPool.getThreadContext()); @@ -664,42 +659,6 @@ protected void serverAcceptedChannel(TcpChannel channel) { */ protected abstract void stopInternal(); - /** - * Sends back an error response to the caller via the given channel - * - * @param nodeVersion the caller node version - * @param features the caller features - * @param channel the channel to send the response to - * @param error the error to return - * @param requestId the request ID this response replies to - * @param action the action this response replies to - */ - public void sendErrorResponse( - final Version nodeVersion, - final Set features, - final TcpChannel channel, - final Exception error, - final long requestId, - final String action) throws IOException { - outboundHandler.sendErrorResponse(nodeVersion, features, channel, error, requestId, action); - } - - /** - * Sends the response to the given channel. This method should be used to send {@link TransportResponse} objects back to the caller. - * - * @see #sendErrorResponse(Version, Set, TcpChannel, Exception, long, String) for sending back errors to the caller - */ - public void sendResponse( - final Version nodeVersion, - final Set features, - final TcpChannel channel, - final TransportResponse response, - final long requestId, - final String action, - final boolean compress) throws IOException { - outboundHandler.sendResponse(nodeVersion, features, channel, response, requestId, action, compress, false); - } - /** * Handles inbound message that has been decoded. * diff --git a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java index 55246c92e8ade..c96f55ac2798b 100644 --- a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java @@ -89,6 +89,69 @@ public void testSendRawBytes() { assertEquals(bytesArray, reference); } + public void testSendRequest() throws IOException { + ThreadContext threadContext = threadPool.getThreadContext(); + Version version = Version.CURRENT; + String actionName = "handshake"; + long requestId = randomLongBetween(0, 300); + boolean isHandshake = randomBoolean(); + boolean compress = randomBoolean(); + String value = "message"; + threadContext.putHeader("header", "header_value"); + Writeable writeable = new Message(value); + + OutboundMessage message = new OutboundMessage.Request(threadContext, new String[0], writeable, version, actionName, requestId, isHandshake, + compress); + + AtomicBoolean isSuccess = new AtomicBoolean(false); + AtomicReference exception = new AtomicReference<>(); + ActionListener listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set); + handler.sendMessage(fakeTcpChannel, message, listener); + + BytesReference reference = fakeTcpChannel.getMessageCaptor().get(); + ActionListener sendListener = fakeTcpChannel.getListenerCaptor().get(); + if (randomBoolean()) { + sendListener.onResponse(null); + assertTrue(isSuccess.get()); + assertNull(exception.get()); + } else { + IOException e = new IOException("failed"); + sendListener.onFailure(e); + assertFalse(isSuccess.get()); + assertSame(e, exception.get()); + } + + InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext()); + try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) { + assertEquals(version, inboundMessage.getVersion()); + assertEquals(requestId, inboundMessage.getRequestId()); + assertTrue(inboundMessage.isRequest()); + assertFalse(inboundMessage.isResponse()); + if (isHandshake) { + assertTrue(inboundMessage.isHandshake()); + } else { + assertFalse(inboundMessage.isHandshake()); + } + if (compress) { + assertTrue(inboundMessage.isCompress()); + } else { + assertFalse(inboundMessage.isCompress()); + } + Message readMessage = new Message(); + readMessage.readFrom(inboundMessage.getStreamInput()); + assertEquals(value, readMessage.value); + + try (ThreadContext.StoredContext existing = threadContext.stashContext()) { + ThreadContext.StoredContext storedContext = inboundMessage.getStoredContext(); + assertNull(threadContext.getHeader("header")); + storedContext.restore(); + assertEquals("header_value", threadContext.getHeader("header")); + } + } + } + + + public void testSendMessage() throws IOException { OutboundMessage message; ThreadContext threadContext = threadPool.getThreadContext(); From da86c7b902fe7075993f1fcf4d9328a34e8541b1 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Thu, 21 Mar 2019 16:39:36 -0600 Subject: [PATCH 3/5] Changes --- .../transport/netty4/Netty4TransportIT.java | 2 +- .../transport/nio/NioTransportIT.java | 2 +- .../transport/InboundMessage.java | 14 +- .../transport/OutboundHandler.java | 41 ++-- .../elasticsearch/transport/TcpTransport.java | 15 +- .../transport/TcpTransportChannel.java | 4 +- .../transport/TransportMessageListener.java | 2 + .../transport/InboundMessageTests.java | 6 +- .../transport/OutboundHandlerTests.java | 231 +++++++++++++----- .../AbstractSimpleTransportTestCase.java | 4 +- .../transport/FakeTcpChannel.java | 4 + 11 files changed, 212 insertions(+), 113 deletions(-) diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java index bc24789341e04..cf9791ce85a4f 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java @@ -111,7 +111,7 @@ public ExceptionThrowingNetty4Transport( } @Override - protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException { + protected void handleRequest(TcpChannel channel, InboundMessage.Request request, int messageLengthBytes) throws IOException { super.handleRequest(channel, request, messageLengthBytes); channelProfileName = TransportSettings.DEFAULT_PROFILE; } diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java b/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java index 087c3758bb98b..d02be2cff9e7c 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/transport/nio/NioTransportIT.java @@ -113,7 +113,7 @@ public Map> getTransports(Settings settings, ThreadP } @Override - protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) throws IOException { + protected void handleRequest(TcpChannel channel, InboundMessage.Request request, int messageLengthBytes) throws IOException { super.handleRequest(channel, request, messageLengthBytes); channelProfileName = TransportSettings.DEFAULT_PROFILE; } diff --git a/server/src/main/java/org/elasticsearch/transport/InboundMessage.java b/server/src/main/java/org/elasticsearch/transport/InboundMessage.java index 777073613798a..318e22701627d 100644 --- a/server/src/main/java/org/elasticsearch/transport/InboundMessage.java +++ b/server/src/main/java/org/elasticsearch/transport/InboundMessage.java @@ -101,9 +101,9 @@ InboundMessage deserialize(BytesReference reference) throws IOException { if (TransportStatus.isRequest(status)) { final Set features = Collections.unmodifiableSet(new TreeSet<>(Arrays.asList(streamInput.readStringArray()))); final String action = streamInput.readString(); - message = new RequestMessage(threadContext, remoteVersion, status, requestId, action, features, streamInput); + message = new Request(threadContext, remoteVersion, status, requestId, action, features, streamInput); } else { - message = new ResponseMessage(threadContext, remoteVersion, status, requestId, streamInput); + message = new Response(threadContext, remoteVersion, status, requestId, streamInput); } success = true; return message; @@ -133,13 +133,13 @@ private static void ensureVersionCompatibility(Version version, Version currentV } } - public static class RequestMessage extends InboundMessage { + public static class Request extends InboundMessage { private final String actionName; private final Set features; - RequestMessage(ThreadContext threadContext, Version version, byte status, long requestId, String actionName, Set features, - StreamInput streamInput) { + Request(ThreadContext threadContext, Version version, byte status, long requestId, String actionName, Set features, + StreamInput streamInput) { super(threadContext, version, status, requestId, streamInput); this.actionName = actionName; this.features = features; @@ -154,9 +154,9 @@ Set getFeatures() { } } - public static class ResponseMessage extends InboundMessage { + public static class Response extends InboundMessage { - ResponseMessage(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) { + Response(ThreadContext threadContext, Version version, byte status, long requestId, StreamInput streamInput) { super(threadContext, version, status, requestId, streamInput); } } diff --git a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java index 460f8c0f051c1..4b816c6a065e5 100644 --- a/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java +++ b/server/src/main/java/org/elasticsearch/transport/OutboundHandler.java @@ -54,7 +54,7 @@ final class OutboundHandler { private final ThreadPool threadPool; private final BigArrays bigArrays; private final TransportLogger transportLogger; - private final DelegatingTransportMessageListener messageListener = new DelegatingTransportMessageListener(); + private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER; OutboundHandler(String nodeName, Version version, String[] features, ThreadPool threadPool, BigArrays bigArrays, TransportLogger transportLogger) { @@ -67,10 +67,9 @@ final class OutboundHandler { } void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener listener) { - channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); SendContext sendContext = new SendContext(channel, () -> bytes, listener); try { - internalSendMessage(channel, sendContext); + internalSend(channel, sendContext); } catch (IOException e) { // This should not happen as the bytes are already serialized throw new AssertionError(e); @@ -81,9 +80,9 @@ void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener li * Sends the request to the given channel. This method should be used to send {@link TransportRequest} * objects back to the caller. */ - public void sendRequest(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action, - final TransportRequest request, final TransportRequestOptions options, final Version channelVersion, - final boolean compressRequest, final boolean isHandshake) throws IOException, TransportException { + void sendRequest(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action, + final TransportRequest request, final TransportRequestOptions options, final Version channelVersion, + final boolean compressRequest, final boolean isHandshake) throws IOException, TransportException { Version version = Version.min(this.version, channelVersion); OutboundMessage.Request message = new OutboundMessage.Request(threadPool.getThreadContext(), features, request, version, action, requestId, isHandshake, compressRequest); @@ -96,10 +95,10 @@ public void sendRequest(final DiscoveryNode node, final TcpChannel channel, fina * Sends the response to the given channel. This method should be used to send {@link TransportResponse} * objects back to the caller. * - * @see #sendErrorResponse(Version, Set, TcpChannel, Exception, long, String) for sending error responses + * @see #sendErrorResponse(Version, Set, TcpChannel, long, String, Exception) for sending error responses */ void sendResponse(final Version nodeVersion, final Set features, final TcpChannel channel, - final TransportResponse response, final long requestId, final String action, + final long requestId, final String action, final TransportResponse response, final boolean compress, final boolean isHandshake) throws IOException { Version version = Version.min(this.version, nodeVersion); OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, response, version, @@ -111,8 +110,8 @@ void sendResponse(final Version nodeVersion, final Set features, final T /** * Sends back an error response to the caller via the given channel */ - void sendErrorResponse(final Version nodeVersion, final Set features, final TcpChannel channel, final Exception error, - final long requestId, final String action) throws IOException { + void sendErrorResponse(final Version nodeVersion, final Set features, final TcpChannel channel, final long requestId, + final String action, final Exception error) throws IOException { Version version = Version.min(this.version, nodeVersion); TransportAddress address = new TransportAddress(channel.getLocalAddress()); RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error); @@ -122,17 +121,13 @@ void sendErrorResponse(final Version nodeVersion, final Set features, fi sendMessage(channel, message, listener); } - void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener listener) throws IOException { - channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); + private void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener listener) throws IOException { MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays); SendContext sendContext = new SendContext(channel, serializer, listener, serializer); - internalSendMessage(channel, sendContext); + internalSend(channel, sendContext); } - /** - * sends a message to the given channel, using the given callbacks. - */ - private void internalSendMessage(TcpChannel channel, SendContext sendContext) throws IOException { + private void internalSend(TcpChannel channel, SendContext sendContext) throws IOException { channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis()); BytesReference reference = sendContext.get(); try { @@ -149,12 +144,12 @@ MeanMetric getTransmittedBytes() { return transmittedBytesMetric; } - void addMessageListener(TransportMessageListener listener) { - messageListener.addListener(listener); - } - - void removeMessageListener(TransportMessageListener listener) { - messageListener.removeListener(listener); + void setMessageListener(TransportMessageListener listener) { + if (messageListener == TransportMessageListener.NOOP_LISTENER) { + messageListener = listener; + } else { + throw new IllegalStateException("Cannot set message listener twice"); + } } private static class MessageSerializer implements CheckedSupplier, Releasable { diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java index b86bb90e85146..8125d5bcb12f6 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -114,8 +114,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent implements protected final NetworkService networkService; protected final Set profileSettings; - private static final TransportMessageListener NOOP_LISTENER = new TransportMessageListener() {}; - private volatile TransportMessageListener messageListener = NOOP_LISTENER; + private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER; private final ConcurrentMap profileBoundAddresses = newConcurrentMap(); private final Map> serverChannels = newConcurrentMap(); @@ -166,8 +165,8 @@ public TcpTransport(Settings settings, Version version, ThreadPool threadPool, P (node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId, TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version), TransportRequestOptions.EMPTY, v, false, true), - (v, features1, channel, response, requestId) -> outboundHandler.sendResponse(v, features1, channel, response, requestId, - TransportHandshaker.HANDSHAKE_ACTION_NAME, false, true)); + (v, features1, channel, response, requestId) -> outboundHandler.sendResponse(v, features1, channel, requestId, + TransportHandshaker.HANDSHAKE_ACTION_NAME, response, false, true)); this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes); this.reader = new InboundMessage.Reader(version, namedWriteableRegistry, threadPool.getThreadContext()); } @@ -178,9 +177,9 @@ protected void doStart() { @Override public synchronized void setMessageListener(TransportMessageListener listener) { - // TODO - if (messageListener == NOOP_LISTENER) { + if (messageListener == TransportMessageListener.NOOP_LISTENER) { messageListener = listener; + outboundHandler.setMessageListener(listener); } else { throw new IllegalStateException("Cannot set message listener twice"); } @@ -835,7 +834,7 @@ public final void messageReceived(BytesReference reference, TcpChannel channel) message.getStoredContext().restore(); threadContext.putTransient("_remote_address", remoteAddress); if (message.isRequest()) { - handleRequest(channel, (InboundMessage.RequestMessage) message, reference.length()); + handleRequest(channel, (InboundMessage.Request) message, reference.length()); } else { final TransportResponseHandler handler; long requestId = message.getRequestId(); @@ -921,7 +920,7 @@ private void handleException(final TransportResponseHandler handler, Throwable e }); } - protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage message, int messageLengthBytes) throws IOException { + protected void handleRequest(TcpChannel channel, InboundMessage.Request message, int messageLengthBytes) throws IOException { final Set features = message.getFeatures(); final String profileName = channel.getProfile(); final String action = message.getActionName(); diff --git a/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java b/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java index a52fd7a8a1eae..aab6e25001ddc 100644 --- a/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java +++ b/server/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java @@ -61,7 +61,7 @@ public String getProfileName() { @Override public void sendResponse(TransportResponse response) throws IOException { try { - outboundHandler.sendResponse(version, features, channel, response, requestId, action, compressResponse, false); + outboundHandler.sendResponse(version, features, channel, requestId, action, response, compressResponse, false); } finally { release(false); } @@ -70,7 +70,7 @@ public void sendResponse(TransportResponse response) throws IOException { @Override public void sendResponse(Exception exception) throws IOException { try { - outboundHandler.sendErrorResponse(version, features, channel, exception, requestId, action); + outboundHandler.sendErrorResponse(version, features, channel, requestId, action, exception); } finally { release(true); } diff --git a/server/src/main/java/org/elasticsearch/transport/TransportMessageListener.java b/server/src/main/java/org/elasticsearch/transport/TransportMessageListener.java index bc57c62ca8d70..62ff3d8fa4302 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportMessageListener.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportMessageListener.java @@ -22,6 +22,8 @@ public interface TransportMessageListener { + TransportMessageListener NOOP_LISTENER = new TransportMessageListener() {}; + /** * Called once a request is received * @param requestId the internal request ID diff --git a/server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java b/server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java index 499b6586543ed..2615a3fdc35a9 100644 --- a/server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java +++ b/server/src/test/java/org/elasticsearch/transport/InboundMessageTests.java @@ -63,7 +63,7 @@ public void testReadRequest() throws IOException { InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext); BytesReference sliced = reference.slice(6, reference.length() - 6); - InboundMessage.RequestMessage inboundMessage = (InboundMessage.RequestMessage) reader.deserialize(sliced); + InboundMessage.Request inboundMessage = (InboundMessage.Request) reader.deserialize(sliced); // Check that deserialize does not overwrite current thread context. assertEquals("header_value2", threadContext.getHeader("header")); inboundMessage.getStoredContext().restore(); @@ -102,7 +102,7 @@ public void testReadResponse() throws IOException { InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext); BytesReference sliced = reference.slice(6, reference.length() - 6); - InboundMessage.ResponseMessage inboundMessage = (InboundMessage.ResponseMessage) reader.deserialize(sliced); + InboundMessage.Response inboundMessage = (InboundMessage.Response) reader.deserialize(sliced); // Check that deserialize does not overwrite current thread context. assertEquals("header_value2", threadContext.getHeader("header")); inboundMessage.getStoredContext().restore(); @@ -138,7 +138,7 @@ public void testReadErrorResponse() throws IOException { InboundMessage.Reader reader = new InboundMessage.Reader(version, registry, threadContext); BytesReference sliced = reference.slice(6, reference.length() - 6); - InboundMessage.ResponseMessage inboundMessage = (InboundMessage.ResponseMessage) reader.deserialize(sliced); + InboundMessage.Response inboundMessage = (InboundMessage.Response) reader.deserialize(sliced); // Check that deserialize does not overwrite current thread context. assertEquals("header_value2", threadContext.getHeader("header")); inboundMessage.getStoredContext().restore(); diff --git a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java index f82d1d38f55d5..578474048fd8a 100644 --- a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java @@ -19,14 +19,16 @@ package org.elasticsearch.transport; +import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.test.ESTestCase; @@ -36,27 +38,39 @@ import org.junit.Before; import java.io.IOException; +import java.net.InetSocketAddress; import java.nio.charset.StandardCharsets; import java.util.Collections; -import java.util.HashSet; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.instanceOf; + public class OutboundHandlerTests extends ESTestCase { + private final String feature1 = "feature1"; + private final String feature2 = "feature2"; private final TestThreadPool threadPool = new TestThreadPool(getClass().getName()); private final NamedWriteableRegistry namedWriteableRegistry = new NamedWriteableRegistry(Collections.emptyList()); + private final TransportRequestOptions options = TransportRequestOptions.EMPTY; private OutboundHandler handler; - private FakeTcpChannel fakeTcpChannel; + private FakeTcpChannel channel; + private DiscoveryNode node; @Before public void setUp() throws Exception { super.setUp(); TransportLogger transportLogger = new TransportLogger(); - fakeTcpChannel = new FakeTcpChannel(randomBoolean()); - handler = new OutboundHandler("node", Version.CURRENT, new String[0], threadPool, BigArrays.NON_RECYCLING_INSTANCE, - transportLogger); + InetSocketAddress localAddress = buildNewFakeTransportAddress().address(); + InetSocketAddress remoteAddress = buildNewFakeTransportAddress().address(); + channel = new FakeTcpChannel(randomBoolean(), localAddress, remoteAddress); + TransportAddress transportAddress = buildNewFakeTransportAddress(); + node = new DiscoveryNode("", transportAddress, Version.CURRENT); + String[] features = {feature1, feature2}; + handler = new OutboundHandler("node", Version.CURRENT, features, threadPool, BigArrays.NON_RECYCLING_INSTANCE, transportLogger); } @After @@ -71,10 +85,10 @@ public void testSendRawBytes() { AtomicBoolean isSuccess = new AtomicBoolean(false); AtomicReference exception = new AtomicReference<>(); ActionListener listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set); - handler.sendBytes(fakeTcpChannel, bytesArray, listener); + handler.sendBytes(channel, bytesArray, listener); - BytesReference reference = fakeTcpChannel.getMessageCaptor().get(); - ActionListener sendListener = fakeTcpChannel.getListenerCaptor().get(); + BytesReference reference = channel.getMessageCaptor().get(); + ActionListener sendListener = channel.getListenerCaptor().get(); if (randomBoolean()) { sendListener.onResponse(null); assertTrue(isSuccess.get()); @@ -91,35 +105,42 @@ public void testSendRawBytes() { public void testSendRequest() throws IOException { ThreadContext threadContext = threadPool.getThreadContext(); - Version version = Version.CURRENT; - String actionName = "handshake"; + Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); + String action = "handshake"; long requestId = randomLongBetween(0, 300); boolean isHandshake = randomBoolean(); boolean compress = randomBoolean(); String value = "message"; threadContext.putHeader("header", "header_value"); - Writeable writeable = new Message(value); - - OutboundMessage message = new OutboundMessage.Request(threadContext, new String[0], writeable, version, actionName, requestId, - isHandshake, compress); + Request request = new Request(value); - AtomicBoolean isSuccess = new AtomicBoolean(false); - AtomicReference exception = new AtomicReference<>(); - ActionListener listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set); - handler.sendMessage(fakeTcpChannel, message, listener); + AtomicReference nodeRef = new AtomicReference<>(); + AtomicLong requestIdRef = new AtomicLong(); + AtomicReference actionRef = new AtomicReference<>(); + AtomicReference requestRef = new AtomicReference<>(); + handler.setMessageListener(new TransportMessageListener() { + @Override + public void onRequestSent(DiscoveryNode node, long requestId, String action, TransportRequest request, + TransportRequestOptions options) { + nodeRef.set(node); + requestIdRef.set(requestId); + actionRef.set(action); + requestRef.set(request); + } + }); + handler.sendRequest(node, channel, requestId, action, request, options, version, compress, isHandshake); - BytesReference reference = fakeTcpChannel.getMessageCaptor().get(); - ActionListener sendListener = fakeTcpChannel.getListenerCaptor().get(); + BytesReference reference = channel.getMessageCaptor().get(); + ActionListener sendListener = channel.getListenerCaptor().get(); if (randomBoolean()) { sendListener.onResponse(null); - assertTrue(isSuccess.get()); - assertNull(exception.get()); } else { - IOException e = new IOException("failed"); - sendListener.onFailure(e); - assertFalse(isSuccess.get()); - assertSame(e, exception.get()); + sendListener.onFailure(new IOException("failed")); } + assertEquals(node, nodeRef.get()); + assertEquals(requestId, requestIdRef.get()); + assertEquals(action, actionRef.get()); + assertEquals(request, requestRef.get()); InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext()); try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) { @@ -137,7 +158,10 @@ public void testSendRequest() throws IOException { } else { assertFalse(inboundMessage.isCompress()); } - Message readMessage = new Message(); + InboundMessage.Request inboundRequest = (InboundMessage.Request) inboundMessage; + assertThat(inboundRequest.getFeatures(), contains(feature1, feature2)); + + Request readMessage = new Request(); readMessage.readFrom(inboundMessage.getStreamInput()); assertEquals(value, readMessage.value); @@ -150,57 +174,47 @@ public void testSendRequest() throws IOException { } } - - - public void testSendMessage() throws IOException { - OutboundMessage message; + public void testSendResponse() throws IOException { ThreadContext threadContext = threadPool.getThreadContext(); - Version version = Version.CURRENT; - String actionName = "handshake"; + Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); + String action = "handshake"; long requestId = randomLongBetween(0, 300); boolean isHandshake = randomBoolean(); boolean compress = randomBoolean(); String value = "message"; threadContext.putHeader("header", "header_value"); - Writeable writeable = new Message(value); + Response response = new Response(value); - boolean isRequest = randomBoolean(); - if (isRequest) { - message = new OutboundMessage.Request(threadContext, new String[0], writeable, version, actionName, requestId, isHandshake, - compress); - } else { - message = new OutboundMessage.Response(threadContext, new HashSet<>(), writeable, version, requestId, isHandshake, compress); - } - - AtomicBoolean isSuccess = new AtomicBoolean(false); - AtomicReference exception = new AtomicReference<>(); - ActionListener listener = ActionListener.wrap((v) -> isSuccess.set(true), exception::set); - handler.sendMessage(fakeTcpChannel, message, listener); + AtomicLong requestIdRef = new AtomicLong(); + AtomicReference actionRef = new AtomicReference<>(); + AtomicReference responseRef = new AtomicReference<>(); + handler.setMessageListener(new TransportMessageListener() { + @Override + public void onResponseSent(long requestId, String action, TransportResponse response) { + requestIdRef.set(requestId); + actionRef.set(action); + responseRef.set(response); + } + }); + handler.sendResponse(version, Collections.emptySet(), channel, requestId, action, response, compress, isHandshake); - BytesReference reference = fakeTcpChannel.getMessageCaptor().get(); - ActionListener sendListener = fakeTcpChannel.getListenerCaptor().get(); + BytesReference reference = channel.getMessageCaptor().get(); + ActionListener sendListener = channel.getListenerCaptor().get(); if (randomBoolean()) { sendListener.onResponse(null); - assertTrue(isSuccess.get()); - assertNull(exception.get()); } else { - IOException e = new IOException("failed"); - sendListener.onFailure(e); - assertFalse(isSuccess.get()); - assertSame(e, exception.get()); + sendListener.onFailure(new IOException("failed")); } + assertEquals(requestId, requestIdRef.get()); + assertEquals(action, actionRef.get()); + assertEquals(response, responseRef.get()); InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext()); try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) { assertEquals(version, inboundMessage.getVersion()); assertEquals(requestId, inboundMessage.getRequestId()); - if (isRequest) { - assertTrue(inboundMessage.isRequest()); - assertFalse(inboundMessage.isResponse()); - } else { - assertTrue(inboundMessage.isResponse()); - assertFalse(inboundMessage.isRequest()); - } + assertFalse(inboundMessage.isRequest()); + assertTrue(inboundMessage.isResponse()); if (isHandshake) { assertTrue(inboundMessage.isHandshake()); } else { @@ -211,7 +225,11 @@ public void testSendMessage() throws IOException { } else { assertFalse(inboundMessage.isCompress()); } - Message readMessage = new Message(); + + InboundMessage.Response inboundResponse = (InboundMessage.Response) inboundMessage; + assertFalse(inboundResponse.isError()); + + Response readMessage = new Response(); readMessage.readFrom(inboundMessage.getStreamInput()); assertEquals(value, readMessage.value); @@ -224,14 +242,95 @@ public void testSendMessage() throws IOException { } } - private static final class Message extends TransportMessage { + public void testErrorResponse() throws IOException { + ThreadContext threadContext = threadPool.getThreadContext(); + Version version = randomFrom(Version.CURRENT, Version.CURRENT.minimumCompatibilityVersion()); + String action = "handshake"; + long requestId = randomLongBetween(0, 300); + threadContext.putHeader("header", "header_value"); + ElasticsearchException error = new ElasticsearchException("boom"); + + AtomicLong requestIdRef = new AtomicLong(); + AtomicReference actionRef = new AtomicReference<>(); + AtomicReference responseRef = new AtomicReference<>(); + handler.setMessageListener(new TransportMessageListener() { + @Override + public void onResponseSent(long requestId, String action, Exception error) { + requestIdRef.set(requestId); + actionRef.set(action); + responseRef.set(error); + } + }); + handler.sendErrorResponse(version, Collections.emptySet(), channel, requestId, action, error); + + BytesReference reference = channel.getMessageCaptor().get(); + ActionListener sendListener = channel.getListenerCaptor().get(); + if (randomBoolean()) { + sendListener.onResponse(null); + } else { + sendListener.onFailure(new IOException("failed")); + } + assertEquals(requestId, requestIdRef.get()); + assertEquals(action, actionRef.get()); + assertEquals(error, responseRef.get()); + + InboundMessage.Reader reader = new InboundMessage.Reader(Version.CURRENT, namedWriteableRegistry, threadPool.getThreadContext()); + try (InboundMessage inboundMessage = reader.deserialize(reference.slice(6, reference.length() - 6))) { + assertEquals(version, inboundMessage.getVersion()); + assertEquals(requestId, inboundMessage.getRequestId()); + assertFalse(inboundMessage.isRequest()); + assertTrue(inboundMessage.isResponse()); + assertFalse(inboundMessage.isCompress()); + assertFalse(inboundMessage.isHandshake()); + + InboundMessage.Response inboundResponse = (InboundMessage.Response) inboundMessage; + assertTrue(inboundResponse.isError()); + + RemoteTransportException remoteException = inboundMessage.getStreamInput().readException(); + assertThat(remoteException.getCause(), instanceOf(ElasticsearchException.class)); + assertEquals(remoteException.getCause().getMessage(), "boom"); + assertEquals(action, remoteException.action()); + assertEquals(channel.getLocalAddress(), remoteException.address().address()); + + try (ThreadContext.StoredContext existing = threadContext.stashContext()) { + ThreadContext.StoredContext storedContext = inboundMessage.getStoredContext(); + assertNull(threadContext.getHeader("header")); + storedContext.restore(); + assertEquals("header_value", threadContext.getHeader("header")); + } + } + } + + private static final class Request extends TransportRequest { + + public String value; + + private Request() { + } + + private Request(String value) { + this.value = value; + } + + @Override + public void readFrom(StreamInput in) throws IOException { + value = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(value); + } + } + + private static final class Response extends TransportResponse { public String value; - private Message() { + private Response() { } - private Message(String value) { + private Response(String value) { this.value = value; } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java index ad05e6e3d2288..b51f2f78c9434 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/AbstractSimpleTransportTestCase.java @@ -2008,12 +2008,12 @@ public void testTcpHandshake() { new NetworkService(Collections.emptyList()), PageCacheRecycler.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) { @Override - protected void handleRequest(TcpChannel channel, InboundMessage.RequestMessage request, int messageLengthBytes) + protected void handleRequest(TcpChannel channel, InboundMessage.Request request, int messageLengthBytes) throws IOException { // we flip the isHandshake bit back and act like the handler is not found byte status = (byte) (request.status & ~(1 << 3)); Version version = request.getVersion(); - InboundMessage.RequestMessage nonHandshakeRequest = new InboundMessage.RequestMessage(request.threadContext, version, + InboundMessage.Request nonHandshakeRequest = new InboundMessage.Request(request.threadContext, version, status, request.getRequestId(), request.getActionName(), request.getFeatures(), request.getStreamInput()); super.handleRequest(channel, nonHandshakeRequest, messageLengthBytes); } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java index bb392554305c0..e9593fc662257 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/FakeTcpChannel.java @@ -44,6 +44,10 @@ public FakeTcpChannel(boolean isServer) { this(isServer, "profile", new AtomicReference<>()); } + public FakeTcpChannel(boolean isServer, InetSocketAddress localAddress, InetSocketAddress remoteAddress) { + this(isServer, localAddress, remoteAddress, "profile", new AtomicReference<>()); + } + public FakeTcpChannel(boolean isServer, AtomicReference messageCaptor) { this(isServer, "profile", messageCaptor); } From f4cf7c721b0e3a184859ee29504030539b2fcaa2 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Fri, 22 Mar 2019 09:10:32 -0600 Subject: [PATCH 4/5] Delete --- .../DelegatingTransportMessageListener.java | 74 ------------------- 1 file changed, 74 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/transport/DelegatingTransportMessageListener.java diff --git a/server/src/main/java/org/elasticsearch/transport/DelegatingTransportMessageListener.java b/server/src/main/java/org/elasticsearch/transport/DelegatingTransportMessageListener.java deleted file mode 100644 index df1e2979ee88c..0000000000000 --- a/server/src/main/java/org/elasticsearch/transport/DelegatingTransportMessageListener.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.transport; - -import org.elasticsearch.cluster.node.DiscoveryNode; - -import java.util.List; -import java.util.concurrent.CopyOnWriteArrayList; - -final class DelegatingTransportMessageListener implements TransportMessageListener { - - private final List listeners = new CopyOnWriteArrayList<>(); - - @Override - public void onRequestReceived(long requestId, String action) { - for (TransportMessageListener listener : listeners) { - listener.onRequestReceived(requestId, action); - } - } - - @Override - public void onResponseSent(long requestId, String action, TransportResponse response) { - for (TransportMessageListener listener : listeners) { - listener.onResponseSent(requestId, action, response); - } - } - - @Override - public void onResponseSent(long requestId, String action, Exception error) { - for (TransportMessageListener listener : listeners) { - listener.onResponseSent(requestId, action, error); - } - } - - @Override - public void onRequestSent(DiscoveryNode node, long requestId, String action, TransportRequest request, - TransportRequestOptions finalOptions) { - for (TransportMessageListener listener : listeners) { - listener.onRequestSent(node, requestId, action, request, finalOptions); - } - } - - @Override - public void onResponseReceived(long requestId, Transport.ResponseContext holder) { - for (TransportMessageListener listener : listeners) { - listener.onResponseReceived(requestId, holder); - } - } - - public void addListener(TransportMessageListener listener) { - listeners.add(listener); - } - - public boolean removeListener(TransportMessageListener listener) { - return listeners.remove(listener); - } -} From 4c28044519bcd47eaee835d5d1635f8f06f43214 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Fri, 22 Mar 2019 11:15:45 -0600 Subject: [PATCH 5/5] Cleanup --- .../org/elasticsearch/transport/OutboundHandlerTests.java | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java index 578474048fd8a..baab504e61fa4 100644 --- a/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java +++ b/server/src/test/java/org/elasticsearch/transport/OutboundHandlerTests.java @@ -38,7 +38,6 @@ import org.junit.Before; import java.io.IOException; -import java.net.InetSocketAddress; import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.concurrent.TimeUnit; @@ -64,9 +63,7 @@ public class OutboundHandlerTests extends ESTestCase { public void setUp() throws Exception { super.setUp(); TransportLogger transportLogger = new TransportLogger(); - InetSocketAddress localAddress = buildNewFakeTransportAddress().address(); - InetSocketAddress remoteAddress = buildNewFakeTransportAddress().address(); - channel = new FakeTcpChannel(randomBoolean(), localAddress, remoteAddress); + channel = new FakeTcpChannel(randomBoolean(), buildNewFakeTransportAddress().address(), buildNewFakeTransportAddress().address()); TransportAddress transportAddress = buildNewFakeTransportAddress(); node = new DiscoveryNode("", transportAddress, Version.CURRENT); String[] features = {feature1, feature2};