diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioGroup.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioGroup.java
new file mode 100644
index 0000000000000..b0e1862c706ca
--- /dev/null
+++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioGroup.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.transport.nio;
+
+import org.apache.logging.log4j.Logger;
+import org.apache.lucene.util.IOUtils;
+import org.elasticsearch.transport.nio.channel.ChannelFactory;
+import org.elasticsearch.transport.nio.channel.NioServerSocketChannel;
+import org.elasticsearch.transport.nio.channel.NioSocketChannel;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.ArrayList;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/**
+ * The NioGroup is a group of selectors for interfacing with java nio. When it is started it will create the
+ * configured number of socket and acceptor selectors. Each selector will be running in a dedicated thread.
+ * Server connections can be bound using the {@link #bindServerChannel(InetSocketAddress, ChannelFactory)}
+ * method. Client connections can be opened using the {@link #openChannel(InetSocketAddress, ChannelFactory)}
+ * method.
+ *
+ * The logic specific to a particular channel is provided by the {@link ChannelFactory} passed to the method
+ * when the channel is created. This is what allows an NioGroup to support different channel types.
+ */
+public class NioGroup implements AutoCloseable {
+
+
+ private final ArrayList acceptors;
+ private final RoundRobinSupplier acceptorSupplier;
+
+ private final ArrayList socketSelectors;
+ private final RoundRobinSupplier socketSelectorSupplier;
+
+ private final AtomicBoolean isOpen = new AtomicBoolean(true);
+
+ public NioGroup(Logger logger, ThreadFactory acceptorThreadFactory, int acceptorCount,
+ BiFunction, AcceptorEventHandler> acceptorEventHandlerFunction,
+ ThreadFactory socketSelectorThreadFactory, int socketSelectorCount,
+ Function socketEventHandlerFunction) throws IOException {
+ acceptors = new ArrayList<>(acceptorCount);
+ socketSelectors = new ArrayList<>(socketSelectorCount);
+
+ try {
+ for (int i = 0; i < socketSelectorCount; ++i) {
+ SocketSelector selector = new SocketSelector(socketEventHandlerFunction.apply(logger));
+ socketSelectors.add(selector);
+ }
+ startSelectors(socketSelectors, socketSelectorThreadFactory);
+
+ for (int i = 0; i < acceptorCount; ++i) {
+ SocketSelector[] childSelectors = this.socketSelectors.toArray(new SocketSelector[this.socketSelectors.size()]);
+ Supplier selectorSupplier = new RoundRobinSupplier<>(childSelectors);
+ AcceptingSelector acceptor = new AcceptingSelector(acceptorEventHandlerFunction.apply(logger, selectorSupplier));
+ acceptors.add(acceptor);
+ }
+ startSelectors(acceptors, acceptorThreadFactory);
+ } catch (Exception e) {
+ try {
+ close();
+ } catch (Exception e1) {
+ e.addSuppressed(e1);
+ }
+ throw e;
+ }
+
+ socketSelectorSupplier = new RoundRobinSupplier<>(socketSelectors.toArray(new SocketSelector[socketSelectors.size()]));
+ acceptorSupplier = new RoundRobinSupplier<>(acceptors.toArray(new AcceptingSelector[acceptors.size()]));
+ }
+
+ public S bindServerChannel(InetSocketAddress address, ChannelFactory factory)
+ throws IOException {
+ ensureOpen();
+ if (acceptors.isEmpty()) {
+ throw new IllegalArgumentException("There are no acceptors configured. Without acceptors, server channels are not supported.");
+ }
+ return factory.openNioServerSocketChannel(address, acceptorSupplier.get());
+ }
+
+ public S openChannel(InetSocketAddress address, ChannelFactory, S> factory) throws IOException {
+ ensureOpen();
+ return factory.openNioChannel(address, socketSelectorSupplier.get());
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (isOpen.compareAndSet(true, false)) {
+ IOUtils.close(Stream.concat(acceptors.stream(), socketSelectors.stream()).collect(Collectors.toList()));
+ }
+ }
+
+ private static void startSelectors(Iterable selectors, ThreadFactory threadFactory) {
+ for (ESSelector acceptor : selectors) {
+ if (acceptor.isRunning() == false) {
+ threadFactory.newThread(acceptor::runLoop).start();
+ acceptor.isRunningFuture().actionGet();
+ }
+ }
+ }
+
+ private void ensureOpen() {
+ if (isOpen.get() == false) {
+ throw new IllegalStateException("NioGroup is closed.");
+ }
+ }
+}
diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioShutdown.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioShutdown.java
deleted file mode 100644
index 3970e69b2c1d6..0000000000000
--- a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioShutdown.java
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to Elasticsearch under one or more contributor
- * license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright
- * ownership. Elasticsearch licenses this file to you under
- * the Apache License, Version 2.0 (the "License"); you may
- * not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.elasticsearch.transport.nio;
-
-import org.apache.logging.log4j.Logger;
-import org.elasticsearch.ElasticsearchException;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.concurrent.CountDownLatch;
-
-public class NioShutdown {
-
- private final Logger logger;
-
- public NioShutdown(Logger logger) {
- this.logger = logger;
- }
-
- void orderlyShutdown(ArrayList acceptors, ArrayList socketSelectors) {
-
- for (AcceptingSelector acceptor : acceptors) {
- shutdownSelector(acceptor);
- }
-
- for (SocketSelector selector : socketSelectors) {
- shutdownSelector(selector);
- }
- }
-
- private void shutdownSelector(ESSelector selector) {
- try {
- selector.close();
- } catch (IOException | ElasticsearchException e) {
- logger.warn("unexpected exception while stopping selector", e);
- }
- }
-}
diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java
index 775030bc6dbb3..bb28d93f85a84 100644
--- a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java
+++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java
@@ -19,6 +19,7 @@
package org.elasticsearch.transport.nio;
+import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
@@ -46,9 +47,7 @@
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
-import java.util.ArrayList;
import java.util.concurrent.ConcurrentMap;
-import java.util.concurrent.ThreadFactory;
import java.util.function.Consumer;
import java.util.function.Supplier;
@@ -71,11 +70,8 @@ public class NioTransport extends TcpTransport {
private final PageCacheRecycler pageCacheRecycler;
private final ConcurrentMap profileToChannelFactory = newConcurrentMap();
- private final ArrayList acceptors = new ArrayList<>();
- private final ArrayList socketSelectors = new ArrayList<>();
- private RoundRobinSelectorSupplier clientSelectorSupplier;
- private TcpChannelFactory clientChannelFactory;
- private int acceptorNumber;
+ private volatile NioGroup nioGroup;
+ private volatile TcpChannelFactory clientChannelFactory;
public NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry,
@@ -87,14 +83,13 @@ public NioTransport(Settings settings, ThreadPool threadPool, NetworkService net
@Override
protected TcpNioServerSocketChannel bind(String name, InetSocketAddress address) throws IOException {
TcpChannelFactory channelFactory = this.profileToChannelFactory.get(name);
- AcceptingSelector selector = acceptors.get(++acceptorNumber % NioTransport.NIO_ACCEPTOR_COUNT.get(settings));
- return channelFactory.openNioServerSocketChannel(address, selector);
+ return nioGroup.bindServerChannel(address, channelFactory);
}
@Override
protected TcpNioSocketChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener connectListener)
throws IOException {
- TcpNioSocketChannel channel = clientChannelFactory.openNioChannel(node.getAddress().address(), clientSelectorSupplier.get());
+ TcpNioSocketChannel channel = nioGroup.openChannel(node.getAddress().address(), clientChannelFactory);
channel.addConnectListener(connectListener);
return channel;
}
@@ -103,42 +98,19 @@ protected TcpNioSocketChannel initiateChannel(DiscoveryNode node, TimeValue conn
protected void doStart() {
boolean success = false;
try {
- int workerCount = NioTransport.NIO_WORKER_COUNT.get(settings);
- for (int i = 0; i < workerCount; ++i) {
- SocketSelector selector = new SocketSelector(getSocketEventHandler());
- socketSelectors.add(selector);
+ int acceptorCount = 0;
+ boolean useNetworkServer = NetworkService.NETWORK_SERVER.get(settings);
+ if (useNetworkServer) {
+ acceptorCount = NioTransport.NIO_ACCEPTOR_COUNT.get(settings);
}
+ nioGroup = new NioGroup(logger, daemonThreadFactory(this.settings, TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX), acceptorCount,
+ AcceptorEventHandler::new, daemonThreadFactory(this.settings, TRANSPORT_WORKER_THREAD_NAME_PREFIX),
+ NioTransport.NIO_WORKER_COUNT.get(settings), this::getSocketEventHandler);
- for (SocketSelector selector : socketSelectors) {
- if (selector.isRunning() == false) {
- ThreadFactory threadFactory = daemonThreadFactory(this.settings, TRANSPORT_WORKER_THREAD_NAME_PREFIX);
- threadFactory.newThread(selector::runLoop).start();
- selector.isRunningFuture().actionGet();
- }
- }
-
- Consumer clientContextSetter = getContextSetter("client-socket");
- clientSelectorSupplier = new RoundRobinSelectorSupplier(socketSelectors);
ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default");
- clientChannelFactory = new TcpChannelFactory(clientProfileSettings, clientContextSetter, getServerContextSetter());
-
- if (NetworkService.NETWORK_SERVER.get(settings)) {
- int acceptorCount = NioTransport.NIO_ACCEPTOR_COUNT.get(settings);
- for (int i = 0; i < acceptorCount; ++i) {
- Supplier selectorSupplier = new RoundRobinSelectorSupplier(socketSelectors);
- AcceptorEventHandler eventHandler = new AcceptorEventHandler(logger, selectorSupplier);
- AcceptingSelector acceptor = new AcceptingSelector(eventHandler);
- acceptors.add(acceptor);
- }
-
- for (AcceptingSelector acceptor : acceptors) {
- if (acceptor.isRunning() == false) {
- ThreadFactory threadFactory = daemonThreadFactory(this.settings, TRANSPORT_ACCEPTOR_THREAD_NAME_PREFIX);
- threadFactory.newThread(acceptor::runLoop).start();
- acceptor.isRunningFuture().actionGet();
- }
- }
+ clientChannelFactory = new TcpChannelFactory(clientProfileSettings, getContextSetter("client"), getServerContextSetter());
+ if (useNetworkServer) {
// loop through all profiles and start them up, special handling for default one
for (ProfileSettings profileSettings : profileSettings) {
String profileName = profileSettings.profileName;
@@ -162,14 +134,15 @@ protected void doStart() {
@Override
protected void stopInternal() {
- NioShutdown nioShutdown = new NioShutdown(logger);
- nioShutdown.orderlyShutdown(acceptors, socketSelectors);
-
+ try {
+ nioGroup.close();
+ } catch (Exception e) {
+ logger.warn("unexpected exception while stopping nio group", e);
+ }
profileToChannelFactory.clear();
- socketSelectors.clear();
}
- protected SocketEventHandler getSocketEventHandler() {
+ protected SocketEventHandler getSocketEventHandler(Logger logger) {
return new SocketEventHandler(logger);
}
@@ -189,8 +162,7 @@ private Consumer getContextSetter(String profileName) {
}
private void acceptChannel(NioSocketChannel channel) {
- TcpNioSocketChannel tcpChannel = (TcpNioSocketChannel) channel;
- serverAcceptedChannel(tcpChannel);
+ serverAcceptedChannel((TcpNioSocketChannel) channel);
}
diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/RoundRobinSelectorSupplier.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/RoundRobinSupplier.java
similarity index 73%
rename from test/framework/src/main/java/org/elasticsearch/transport/nio/RoundRobinSelectorSupplier.java
rename to test/framework/src/main/java/org/elasticsearch/transport/nio/RoundRobinSupplier.java
index 108242b1e0edc..395b955f7ab36 100644
--- a/test/framework/src/main/java/org/elasticsearch/transport/nio/RoundRobinSelectorSupplier.java
+++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/RoundRobinSupplier.java
@@ -19,22 +19,21 @@
package org.elasticsearch.transport.nio;
-import java.util.ArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
-public class RoundRobinSelectorSupplier implements Supplier {
+public class RoundRobinSupplier implements Supplier {
- private final ArrayList selectors;
+ private final S[] selectors;
private final int count;
private AtomicInteger counter = new AtomicInteger(0);
- public RoundRobinSelectorSupplier(ArrayList selectors) {
- this.count = selectors.size();
+ public RoundRobinSupplier(S[] selectors) {
+ this.count = selectors.length;
this.selectors = selectors;
}
- public SocketSelector get() {
- return selectors.get(counter.getAndIncrement() % count);
+ public S get() {
+ return selectors[counter.getAndIncrement() % count];
}
}
diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java
index aedff1721f8d9..48a9e65f00dff 100644
--- a/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java
+++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java
@@ -57,7 +57,7 @@ public void setUpHandler() throws IOException {
acceptedChannelCallback = mock(Consumer.class);
ArrayList selectors = new ArrayList<>();
selectors.add(socketSelector);
- handler = new AcceptorEventHandler(logger, new RoundRobinSelectorSupplier(selectors));
+ handler = new AcceptorEventHandler(logger, new RoundRobinSupplier<>(selectors.toArray(new SocketSelector[selectors.size()])));
AcceptingSelector selector = mock(AcceptingSelector.class);
channel = new DoNotRegisterServerChannel(mock(ServerSocketChannel.class), channelFactory, selector);
diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/NioGroupTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/NioGroupTests.java
new file mode 100644
index 0000000000000..f9b3cbb4e5026
--- /dev/null
+++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/NioGroupTests.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.transport.nio;
+
+import org.elasticsearch.common.CheckedRunnable;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.transport.nio.channel.ChannelFactory;
+
+import java.io.IOException;
+import java.net.InetSocketAddress;
+
+import static org.elasticsearch.common.util.concurrent.EsExecutors.daemonThreadFactory;
+import static org.mockito.Mockito.mock;
+
+public class NioGroupTests extends ESTestCase {
+
+ private NioGroup nioGroup;
+
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ nioGroup = new NioGroup(logger, daemonThreadFactory(Settings.EMPTY, "acceptor"), 1, AcceptorEventHandler::new,
+ daemonThreadFactory(Settings.EMPTY, "selector"), 1, SocketEventHandler::new);
+ }
+
+ @Override
+ public void tearDown() throws Exception {
+ nioGroup.close();
+ super.tearDown();
+ }
+
+ public void testStartAndClose() throws IOException {
+ // ctor starts threads. So we are testing that close() stops the threads. Our thread linger checks
+ // will throw an exception is stop fails
+ nioGroup.close();
+ }
+
+ @SuppressWarnings("unchecked")
+ public void testCannotOperateAfterClose() throws IOException {
+ nioGroup.close();
+
+ IllegalStateException ise = expectThrows(IllegalStateException.class,
+ () -> nioGroup.bindServerChannel(mock(InetSocketAddress.class), mock(ChannelFactory.class)));
+ assertEquals("NioGroup is closed.", ise.getMessage());
+ ise = expectThrows(IllegalStateException.class,
+ () -> nioGroup.openChannel(mock(InetSocketAddress.class), mock(ChannelFactory.class)));
+ assertEquals("NioGroup is closed.", ise.getMessage());
+ }
+
+ public void testCanCloseTwice() throws IOException {
+ nioGroup.close();
+ nioGroup.close();
+ }
+
+ public void testExceptionAtStartIsHandled() throws IOException {
+ RuntimeException ex = new RuntimeException();
+ CheckedRunnable ctor = () -> new NioGroup(logger, r -> {throw ex;}, 1,
+ AcceptorEventHandler::new, daemonThreadFactory(Settings.EMPTY, "selector"), 1, SocketEventHandler::new);
+ RuntimeException runtimeException = expectThrows(RuntimeException.class, ctor::run);
+ assertSame(ex, runtimeException);
+ // ctor starts threads. So we are testing that a failure to construct will stop threads. Our thread
+ // linger checks will throw an exception is stop fails
+ }
+}
diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java
index 1cff80dec793f..1f17c3df54118 100644
--- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java
+++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java
@@ -19,6 +19,7 @@
package org.elasticsearch.transport.nio;
+import org.apache.logging.log4j.Logger;
import org.elasticsearch.Version;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
@@ -29,7 +30,6 @@
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.MockPageCacheRecycler;
-import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.node.Node;
import org.elasticsearch.test.transport.MockTransportService;
@@ -78,7 +78,7 @@ protected Version getCurrentVersion() {
}
@Override
- protected SocketEventHandler getSocketEventHandler() {
+ protected SocketEventHandler getSocketEventHandler(Logger logger) {
return new TestingSocketEventHandler(logger);
}
};