From 5c403f86aad6e06680aad6c99450536ec0f22e5c Mon Sep 17 00:00:00 2001 From: jaymode Date: Tue, 19 Feb 2019 14:15:22 -0700 Subject: [PATCH 1/2] Fix failures in SessionFactoryLoadBalancingTests This change aims to fix failures in the session factory load balancing tests that mock failure scenarios. For these tests, we randomly shut down ldap servers and bind a client socket to the port they were listening on. Unfortunately, we would occasionally encounter failures in these tests where a socket was already in use and/or the port we expected to connect to was wrong and in fact was to one of the ldap instances that should have been shut down. The failures are caused by the behavior of certain operating systems when it comes to binding ports and wildcard addresses. It is possible for a separate application to be bound to a wildcard address and still allow our code to bind to that port on a specific address. So when we close the server socket and open the client socket, we are still able to establish a connection since the other application is already listening on that port on a wildcard address. Another variant is that the os will allow a wildcard bind of a server socket when there is already an application listening on that port for a specific address. In order to do our best to prevent failures in these scenarios, this change does the following: 1. Binds a client socket to all addresses in an awaitBusy 2. Adds assumption that we could bind all valid addresses 3. In the case that we still establish a connection to an address that we should not be able to, try to bind and expect a failure of not being connected Closes #32190 --- .../authc/ldap/LdapSessionFactoryTests.java | 16 +- .../authc/ldap/support/LdapTestCase.java | 18 +- .../SessionFactoryLoadBalancingTests.java | 301 +++++++++++++----- 3 files changed, 256 insertions(+), 79 deletions(-) diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/LdapSessionFactoryTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/LdapSessionFactoryTests.java index 106eaa8932629..9867cc29fd3da 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/LdapSessionFactoryTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/LdapSessionFactoryTests.java @@ -9,6 +9,7 @@ import com.unboundid.ldap.sdk.LDAPException; import com.unboundid.ldap.sdk.LDAPURL; import com.unboundid.ldap.sdk.SimpleBindRequest; +import org.elasticsearch.common.network.NetworkAddress; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; @@ -29,6 +30,7 @@ import org.junit.After; import org.junit.Before; +import java.net.InetAddress; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardCopyOption; @@ -73,7 +75,12 @@ public void shutdown() throws InterruptedException { public void testBindWithReadTimeout() throws Exception { InMemoryDirectoryServer ldapServer = randomFrom(ldapServers); String protocol = randomFrom("ldap", "ldaps"); - String ldapUrl = new LDAPURL(protocol, "localhost", ldapServer.getListenPort(protocol), null, null, null, null).toString(); + InetAddress listenAddress = ldapServer.getListenAddress(protocol); + if (listenAddress == null) { + listenAddress = InetAddress.getLoopbackAddress(); + } + String ldapUrl = new LDAPURL(protocol, NetworkAddress.format(listenAddress), ldapServer.getListenPort(protocol), + null, null, null, null).toString(); String groupSearchBase = "o=sevenSeas"; String userTemplates = "cn={0},ou=people,o=sevenSeas"; @@ -233,7 +240,12 @@ public void testGroupLookupBase() throws Exception { */ public void testSslTrustIsReloaded() throws Exception { InMemoryDirectoryServer ldapServer = randomFrom(ldapServers); - String ldapUrl = new LDAPURL("ldaps", "localhost", ldapServer.getListenPort("ldaps"), null, null, null, null).toString(); + InetAddress listenAddress = ldapServer.getListenAddress("ldaps"); + if (listenAddress == null) { + listenAddress = InetAddress.getLoopbackAddress(); + } + String ldapUrl = new LDAPURL("ldaps", NetworkAddress.format(listenAddress), ldapServer.getListenPort("ldaps"), + null, null, null, null).toString(); String groupSearchBase = "o=sevenSeas"; String userTemplates = "cn={0},ou=people,o=sevenSeas"; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/LdapTestCase.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/LdapTestCase.java index 2c0b2f7716650..a76ee71d114b3 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/LdapTestCase.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/LdapTestCase.java @@ -18,6 +18,7 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.network.NetworkAddress; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; @@ -25,6 +26,7 @@ import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.watcher.ResourceWatcherService; +import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.ldap.LdapSessionFactorySettings; import org.elasticsearch.xpack.core.security.authc.ldap.SearchGroupsResolverSettings; @@ -46,6 +48,7 @@ import javax.net.ssl.SSLServerSocketFactory; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.X509ExtendedKeyManager; +import java.net.InetAddress; import java.security.AccessController; import java.security.KeyStore; import java.security.PrivilegedAction; @@ -76,7 +79,7 @@ public void startLdap() throws Exception { for (int i = 0; i < numberOfLdapServers; i++) { InMemoryDirectoryServerConfig serverConfig = new InMemoryDirectoryServerConfig("o=sevenSeas"); List listeners = new ArrayList<>(2); - listeners.add(InMemoryListenerConfig.createLDAPConfig("ldap")); + listeners.add(InMemoryListenerConfig.createLDAPConfig("ldap", null, 0, null)); if (openLdapsPort()) { final char[] ldapPassword = "ldap-password".toCharArray(); final KeyStore ks = CertParsingUtils.getKeyStoreFromPEM( @@ -85,11 +88,12 @@ public void startLdap() throws Exception { ldapPassword ); X509ExtendedKeyManager keyManager = CertParsingUtils.keyManager(ks, ldapPassword, KeyManagerFactory.getDefaultAlgorithm()); - final SSLContext context = SSLContext.getInstance("TLSv1.2"); + final SSLContext context = SSLContext.getInstance(XPackSettings.DEFAULT_SUPPORTED_PROTOCOLS.get(0)); context.init(new KeyManager[] { keyManager }, null, null); SSLServerSocketFactory serverSocketFactory = context.getServerSocketFactory(); SSLSocketFactory clientSocketFactory = context.getSocketFactory(); - listeners.add(InMemoryListenerConfig.createLDAPSConfig("ldaps", null, 0, serverSocketFactory, clientSocketFactory)); + listeners.add(InMemoryListenerConfig.createLDAPSConfig("ldaps", InetAddress.getLoopbackAddress(), 0, + serverSocketFactory, clientSocketFactory)); } serverConfig.setListenerConfigs(listeners); InMemoryDirectoryServer ldapServer = new InMemoryDirectoryServer(serverConfig); @@ -111,7 +115,7 @@ protected boolean openLdapsPort() { } @After - public void stopLdap() throws Exception { + public void stopLdap() { for (int i = 0; i < numberOfLdapServers; i++) { ldapServers[i].shutDown(true); } @@ -120,7 +124,11 @@ public void stopLdap() throws Exception { protected String[] ldapUrls() throws LDAPException { List urls = new ArrayList<>(numberOfLdapServers); for (int i = 0; i < numberOfLdapServers; i++) { - LDAPURL url = new LDAPURL("ldap", "localhost", ldapServers[i].getListenPort(), null, null, null, null); + InetAddress listenAddress = ldapServers[i].getListenAddress(); + if (listenAddress == null) { + listenAddress = InetAddress.getLoopbackAddress(); + } + LDAPURL url = new LDAPURL("ldap", NetworkAddress.format(listenAddress), ldapServers[i].getListenPort(), null, null, null, null); urls.add(url.toString()); } return urls.toArray(Strings.EMPTY_ARRAY); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/SessionFactoryLoadBalancingTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/SessionFactoryLoadBalancingTests.java index 1483b2f474bf0..bef06d4c9e203 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/SessionFactoryLoadBalancingTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/SessionFactoryLoadBalancingTests.java @@ -7,11 +7,16 @@ import com.unboundid.ldap.listener.InMemoryDirectoryServer; import com.unboundid.ldap.sdk.LDAPConnection; +import com.unboundid.ldap.sdk.LDAPException; +import com.unboundid.ldap.sdk.SimpleBindRequest; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.SuppressForbidden; +import org.elasticsearch.common.network.InetAddressHelper; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.mocksocket.MockServerSocket; import org.elasticsearch.mocksocket.MockSocket; @@ -28,12 +33,17 @@ import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; +import java.net.NoRouteToHostException; import java.net.Socket; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -52,7 +62,7 @@ public void init() throws Exception { } @After - public void shutdown() throws InterruptedException { + public void shutdown() { terminate(threadPool); } @@ -62,29 +72,22 @@ public void testRoundRobin() throws Exception { final int numberOfIterations = randomIntBetween(1, 5); for (int iteration = 0; iteration < numberOfIterations; iteration++) { for (int i = 0; i < numberOfLdapServers; i++) { - LDAPConnection connection = null; - try { - connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection); + try (LDAPConnection connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection)) { assertThat(connection.getConnectedPort(), is(ldapServers[i].getListenPort())); - } finally { - if (connection != null) { - connection.close(); - } } } } } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/32190") public void testRoundRobinWithFailures() throws Exception { - assumeTrue("at least one ldap server should be present for this test", ldapServers.length > 1); + assumeTrue("at least two ldap servers should be present for this test", ldapServers.length > 1); logger.debug("using [{}] ldap servers, urls {}", ldapServers.length, ldapUrls()); TestSessionFactory testSessionFactory = createSessionFactory(LdapLoadBalancing.ROUND_ROBIN); // create a list of ports List ports = new ArrayList<>(numberOfLdapServers); - for (int i = 0; i < ldapServers.length; i++) { - ports.add(ldapServers[i].getListenPort()); + for (InMemoryDirectoryServer ldapServer : ldapServers) { + ports.add(ldapServer.getListenPort()); } logger.debug("list of all ports {}", ports); @@ -94,18 +97,19 @@ public void testRoundRobinWithFailures() throws Exception { // get a subset to kill final List ldapServersToKill = randomSubsetOf(numberToKill, ldapServers); final List ldapServersList = Arrays.asList(ldapServers); - final InetAddress local = InetAddress.getByName("localhost"); - final MockServerSocket mockServerSocket = new MockServerSocket(0, 0, local); + final MockServerSocket mockServerSocket = new MockServerSocket(0, 0); final List listenThreads = new ArrayList<>(); final CountDownLatch latch = new CountDownLatch(ldapServersToKill.size()); final CountDownLatch closeLatch = new CountDownLatch(1); try { + final AtomicBoolean success = new AtomicBoolean(true); + final List openMockSockets = Collections.synchronizedList(new ArrayList<>()); for (InMemoryDirectoryServer ldapServerToKill : ldapServersToKill) { final int index = ldapServersList.indexOf(ldapServerToKill); assertThat(index, greaterThanOrEqualTo(0)); - final Integer port = Integer.valueOf(ldapServers[index].getListenPort()); + final int port = ldapServers[index].getListenPort(); logger.debug("shutting down server index [{}] listening on [{}]", index, port); - assertTrue(ports.remove(port)); + assertTrue(ports.remove(Integer.valueOf(port))); ldapServers[index].shutDown(true); // when running multiple test jvms, there is a chance that something else could @@ -114,15 +118,58 @@ public void testRoundRobinWithFailures() throws Exception { // a mock server socket. // NOTE: this is not perfect as there is a small amount of time between the shutdown // of the ldap server and the opening of the socket - logger.debug("opening mock server socket listening on [{}]", port); + logger.debug("opening mock client sockets bound to [{}]", port); Runnable runnable = () -> { - try (Socket socket = openMockSocket(local, mockServerSocket.getLocalPort(), local, port)) { - logger.debug("opened socket [{}]", socket); - latch.countDown(); + final List openedSockets = new ArrayList<>(); + final List blacklistedAddress = new ArrayList<>(); + try { + final boolean allSocketsOpened = awaitBusy(() -> { + try { + final List inetAddressesToBind = Arrays.stream(InetAddressHelper.getAllAddresses()) + .filter(addr -> openedSockets.stream().noneMatch(s -> addr.equals(s.getLocalAddress()))) + .filter(addr -> blacklistedAddress.contains(addr) == false) + .collect(Collectors.toList()); + for (InetAddress localAddress : inetAddressesToBind) { + try { + Socket socket = openMockSocket(mockServerSocket.getInetAddress(), mockServerSocket.getLocalPort(), + localAddress, port); + openedSockets.add(socket); + openMockSockets.add(socket); + logger.debug("opened socket [{}]", socket); + } catch (NoRouteToHostException e) { + logger.debug(new ParameterizedMessage("blacklisting address [{}] due to:", localAddress), e); + blacklistedAddress.add(localAddress); + } + } + return true; + } catch (IOException e) { + logger.debug(new ParameterizedMessage("caught exception while opening socket on [{}]", port), e); + return false; + } + }); + + if (allSocketsOpened) { + latch.countDown(); + } else { + success.set(false); + IOUtils.closeWhileHandlingException(openedSockets); + openedSockets.clear(); + latch.countDown(); + return; + } + } catch (InterruptedException e) { + logger.debug(new ParameterizedMessage("interrupted while trying to open sockets on [{}]", port), e); + Thread.currentThread().interrupt(); + } + + try { closeLatch.await(); - logger.debug("closing socket [{}]", socket); - } catch (IOException | InterruptedException e) { - logger.debug("caught exception", e); + } catch (InterruptedException e) { + logger.debug("caught exception while waiting for close latch", e); + Thread.currentThread().interrupt(); + } finally { + logger.debug("closing sockets on [{}]", port); + IOUtils.closeWhileHandlingException(openedSockets); } }; Thread thread = new Thread(runnable); @@ -133,14 +180,37 @@ public void testRoundRobinWithFailures() throws Exception { } latch.await(); + + assumeTrue("Failed to open sockets on all addresses with the port that an LDAP server was bound to. Some operating systems " + + "allow binding to an address and port combination even if an application is bound to the port on a wildcard address", + success.get()); final int numberOfIterations = randomIntBetween(1, 5); + logger.debug("list of all open ports {}", ports); // go one iteration through and attempt a bind for (int iteration = 0; iteration < numberOfIterations; iteration++) { logger.debug("iteration [{}]", iteration); for (Integer port : ports) { logger.debug("attempting connection with expected port [{}]", port); - try (LDAPConnection connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection)) { + LDAPConnection connection = null; + try { + do { + final LDAPConnection finalConnection = + LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection); + connection = finalConnection; + logger.debug("established connection with port [{}] expected port [{}]", + finalConnection.getConnectedPort(), port); + if (finalConnection.getConnectedPort() != port) { + LDAPException e = expectThrows(LDAPException.class, () -> finalConnection.bind(new SimpleBindRequest())); + assertThat(e.getMessage(), containsString("not connected")); + finalConnection.close(); + } + } while (connection.getConnectedPort() != port); + assertThat(connection.getConnectedPort(), is(port)); + } finally { + if (connection != null) { + connection.close(); + } } } } @@ -160,76 +230,163 @@ private MockSocket openMockSocket(InetAddress remoteAddress, int remotePort, Ine socket.setReuseAddress(true); // allow binding even if the previous socket is in timed wait state. socket.setSoLinger(true, 0); // close immediately as we are not writing anything here. socket.bind(new InetSocketAddress(localAddress, localPort)); - SocketAccess.doPrivileged(() -> socket.connect(new InetSocketAddress(localAddress, remotePort))); + SocketAccess.doPrivileged(() -> socket.connect(new InetSocketAddress(remoteAddress, remotePort))); return socket; } public void testFailover() throws Exception { - assumeTrue("at least one ldap server should be present for this test", ldapServers.length > 1); + assumeTrue("at least two ldap servers should be present for this test", ldapServers.length > 1); logger.debug("using [{}] ldap servers, urls {}", ldapServers.length, ldapUrls()); TestSessionFactory testSessionFactory = createSessionFactory(LdapLoadBalancing.FAILOVER); // first test that there is no round robin stuff going on final int firstPort = ldapServers[0].getListenPort(); for (int i = 0; i < numberOfLdapServers; i++) { - LDAPConnection connection = null; - try { - connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection); + try (LDAPConnection connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection)) { assertThat(connection.getConnectedPort(), is(firstPort)); - } finally { - if (connection != null) { - connection.close(); - } } } - logger.debug("shutting down server index [0] listening on [{}]", ldapServers[0].getListenPort()); - // always kill the first one - ldapServers[0].shutDown(true); - assertThat(ldapServers[0].getListenPort(), is(-1)); - - // now randomly shutdown some others + // we need at least one good server. Hence the upper bound is number - 2 since we need at least + // one server to use! + InMemoryDirectoryServer[] allButFirstServer = Arrays.copyOfRange(ldapServers, 1, ldapServers.length); + final List ldapServersToKill; + int numberToKill = 1; if (ldapServers.length > 2) { - // kill at least one other server, but we need at least one good one. Hence the upper bound is number - 2 since we need at least - // one server to use! - final int numberToKill = randomIntBetween(1, numberOfLdapServers - 2); - InMemoryDirectoryServer[] allButFirstServer = Arrays.copyOfRange(ldapServers, 1, ldapServers.length); - // get a subset to kil - final List ldapServersToKill = randomSubsetOf(numberToKill, allButFirstServer); - final List ldapServersList = Arrays.asList(ldapServers); - for (InMemoryDirectoryServer ldapServerToKill : ldapServersToKill) { - final int index = ldapServersList.indexOf(ldapServerToKill); - assertThat(index, greaterThanOrEqualTo(1)); - final Integer port = Integer.valueOf(ldapServers[index].getListenPort()); - logger.debug("shutting down server index [{}] listening on [{}]", index, port); - ldapServers[index].shutDown(true); - assertThat(ldapServers[index].getListenPort(), is(-1)); - } + numberToKill = randomIntBetween(1, numberOfLdapServers - 2); + ldapServersToKill = randomSubsetOf(numberToKill, allButFirstServer); + numberToKill++; // since we kill the first automatically + ldapServersToKill.add(ldapServers[0]); // always kill the first one + } else { + ldapServersToKill = Collections.singletonList(ldapServers[0]); } + final List ldapServersList = Arrays.asList(ldapServers); + final MockServerSocket mockServerSocket = new MockServerSocket(0, 0); + final List listenThreads = new ArrayList<>(); + final CountDownLatch latch = new CountDownLatch(ldapServersToKill.size()); + final CountDownLatch closeLatch = new CountDownLatch(1); + final List openMockSockets = Collections.synchronizedList(new ArrayList<>()); + final AtomicBoolean success = new AtomicBoolean(true); + for (InMemoryDirectoryServer ldapServerToKill : ldapServersToKill) { + final int index = ldapServersList.indexOf(ldapServerToKill); + final int port = ldapServers[index].getListenPort(); + logger.debug("shutting down server index [{}] listening on [{}]", index, port); + ldapServers[index].shutDown(true); - int firstNonStoppedPort = -1; - // now we find the first that isn't stopped - for (int i = 0; i < numberOfLdapServers; i++) { - if (ldapServers[i].getListenPort() != -1) { - firstNonStoppedPort = ldapServers[i].getListenPort(); - break; - } + // when running multiple test jvms, there is a chance that something else could + // start listening on this port so we try to avoid this by creating a local socket + // that will be bound to the port the ldap server was running on and connecting to + // a mock server socket. + // NOTE: this is not perfect as there is a small amount of time between the shutdown + // of the ldap server and the opening of the socket + logger.debug("opening mock server socket listening on [{}]", port); + Runnable runnable = () -> { + final List openedSockets = new ArrayList<>(); + final List blacklistedAddress = new ArrayList<>(); + try { + final boolean allSocketsOpened = awaitBusy(() -> { + try { + final List inetAddressesToBind = Arrays.stream(InetAddressHelper.getAllAddresses()) + .filter(addr -> openedSockets.stream().noneMatch(s -> addr.equals(s.getLocalAddress()))) + .filter(addr -> blacklistedAddress.contains(addr) == false) + .collect(Collectors.toList()); + for (InetAddress localAddress : inetAddressesToBind) { + try { + Socket socket = openMockSocket(mockServerSocket.getInetAddress(), mockServerSocket.getLocalPort(), + localAddress, port); + openedSockets.add(socket); + openMockSockets.add(socket); + logger.debug("opened socket [{}]", socket); + } catch (NoRouteToHostException e) { + logger.debug(new ParameterizedMessage("blacklisting address [{}] due to:", localAddress), e); + blacklistedAddress.add(localAddress); + } + } + return true; + } catch (IOException e) { + logger.debug(new ParameterizedMessage("caught exception while opening socket on [{}]", port), e); + return false; + } + }); + + if (allSocketsOpened) { + latch.countDown(); + } else { + success.set(false); + IOUtils.closeWhileHandlingException(openedSockets); + openedSockets.clear(); + latch.countDown(); + return; + } + } catch (InterruptedException e) { + logger.debug(new ParameterizedMessage("interrupted while trying to open sockets on [{}]", port), e); + Thread.currentThread().interrupt(); + } + + try { + closeLatch.await(); + } catch (InterruptedException e) { + logger.debug("caught exception while waiting for close latch", e); + Thread.currentThread().interrupt(); + } finally { + logger.debug("closing sockets on [{}]", port); + IOUtils.closeWhileHandlingException(openedSockets); + } + }; + Thread thread = new Thread(runnable); + thread.start(); + listenThreads.add(thread); + + assertThat(ldapServers[index].getListenPort(), is(-1)); } - logger.debug("first non stopped port [{}]", firstNonStoppedPort); - assertThat(firstNonStoppedPort, not(-1)); - final int numberOfIterations = randomIntBetween(1, 5); - for (int iteration = 0; iteration < numberOfIterations; iteration++) { - LDAPConnection connection = null; - try { + try { + latch.await(); + + assumeTrue("Failed to open sockets on all addresses with the port that an LDAP server was bound to. Some operating systems " + + "allow binding to an address and port combination even if an application is bound to the port on a wildcard address", + success.get()); + int firstNonStoppedPort = -1; + // now we find the first that isn't stopped + for (int i = 0; i < numberOfLdapServers; i++) { + if (ldapServers[i].getListenPort() != -1) { + firstNonStoppedPort = ldapServers[i].getListenPort(); + break; + } + } + logger.debug("first non stopped port [{}]", firstNonStoppedPort); + assertThat(firstNonStoppedPort, not(-1)); + final int numberOfIterations = randomIntBetween(1, 5); + for (int iteration = 0; iteration < numberOfIterations; iteration++) { logger.debug("attempting connection with expected port [{}] iteration [{}]", firstNonStoppedPort, iteration); - connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection); - assertThat(connection.getConnectedPort(), is(firstNonStoppedPort)); - } finally { - if (connection != null) { - connection.close(); + LDAPConnection connection = null; + try { + do { + final LDAPConnection finalConnection = + LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection); + connection = finalConnection; + logger.debug("established connection with port [{}] expected port [{}]", + finalConnection.getConnectedPort(), firstNonStoppedPort); + if (finalConnection.getConnectedPort() != firstNonStoppedPort) { + LDAPException e = expectThrows(LDAPException.class, () -> finalConnection.bind(new SimpleBindRequest())); + assertThat(e.getMessage(), containsString("not connected")); + finalConnection.close(); + } + } while (connection.getConnectedPort() != firstNonStoppedPort); + + assertThat(connection.getConnectedPort(), is(firstNonStoppedPort)); + } finally { + if (connection != null) { + connection.close(); + } } } + } finally { + closeLatch.countDown(); + mockServerSocket.close(); + for (Thread t : listenThreads) { + t.join(); + } } } From 03cd8f969ee6b9e5c9fe8b3a60f1546fd9238d7f Mon Sep 17 00:00:00 2001 From: jaymode Date: Wed, 20 Feb 2019 09:34:49 -0700 Subject: [PATCH 2/2] address review feedback --- .../authc/ldap/support/LdapTestCase.java | 3 +- .../SessionFactoryLoadBalancingTests.java | 189 ++++++++---------- 2 files changed, 79 insertions(+), 113 deletions(-) diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/LdapTestCase.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/LdapTestCase.java index a76ee71d114b3..957167e60d281 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/LdapTestCase.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/LdapTestCase.java @@ -92,8 +92,7 @@ public void startLdap() throws Exception { context.init(new KeyManager[] { keyManager }, null, null); SSLServerSocketFactory serverSocketFactory = context.getServerSocketFactory(); SSLSocketFactory clientSocketFactory = context.getSocketFactory(); - listeners.add(InMemoryListenerConfig.createLDAPSConfig("ldaps", InetAddress.getLoopbackAddress(), 0, - serverSocketFactory, clientSocketFactory)); + listeners.add(InMemoryListenerConfig.createLDAPSConfig("ldaps", null, 0, serverSocketFactory, clientSocketFactory)); } serverConfig.setListenerConfigs(listeners); InMemoryDirectoryServer ldapServer = new InMemoryDirectoryServer(serverConfig); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/SessionFactoryLoadBalancingTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/SessionFactoryLoadBalancingTests.java index bef06d4c9e203..cd159f69c486c 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/SessionFactoryLoadBalancingTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/SessionFactoryLoadBalancingTests.java @@ -103,7 +103,6 @@ public void testRoundRobinWithFailures() throws Exception { final CountDownLatch closeLatch = new CountDownLatch(1); try { final AtomicBoolean success = new AtomicBoolean(true); - final List openMockSockets = Collections.synchronizedList(new ArrayList<>()); for (InMemoryDirectoryServer ldapServerToKill : ldapServersToKill) { final int index = ldapServersList.indexOf(ldapServerToKill); assertThat(index, greaterThanOrEqualTo(0)); @@ -119,59 +118,8 @@ public void testRoundRobinWithFailures() throws Exception { // NOTE: this is not perfect as there is a small amount of time between the shutdown // of the ldap server and the opening of the socket logger.debug("opening mock client sockets bound to [{}]", port); - Runnable runnable = () -> { - final List openedSockets = new ArrayList<>(); - final List blacklistedAddress = new ArrayList<>(); - try { - final boolean allSocketsOpened = awaitBusy(() -> { - try { - final List inetAddressesToBind = Arrays.stream(InetAddressHelper.getAllAddresses()) - .filter(addr -> openedSockets.stream().noneMatch(s -> addr.equals(s.getLocalAddress()))) - .filter(addr -> blacklistedAddress.contains(addr) == false) - .collect(Collectors.toList()); - for (InetAddress localAddress : inetAddressesToBind) { - try { - Socket socket = openMockSocket(mockServerSocket.getInetAddress(), mockServerSocket.getLocalPort(), - localAddress, port); - openedSockets.add(socket); - openMockSockets.add(socket); - logger.debug("opened socket [{}]", socket); - } catch (NoRouteToHostException e) { - logger.debug(new ParameterizedMessage("blacklisting address [{}] due to:", localAddress), e); - blacklistedAddress.add(localAddress); - } - } - return true; - } catch (IOException e) { - logger.debug(new ParameterizedMessage("caught exception while opening socket on [{}]", port), e); - return false; - } - }); - - if (allSocketsOpened) { - latch.countDown(); - } else { - success.set(false); - IOUtils.closeWhileHandlingException(openedSockets); - openedSockets.clear(); - latch.countDown(); - return; - } - } catch (InterruptedException e) { - logger.debug(new ParameterizedMessage("interrupted while trying to open sockets on [{}]", port), e); - Thread.currentThread().interrupt(); - } - - try { - closeLatch.await(); - } catch (InterruptedException e) { - logger.debug("caught exception while waiting for close latch", e); - Thread.currentThread().interrupt(); - } finally { - logger.debug("closing sockets on [{}]", port); - IOUtils.closeWhileHandlingException(openedSockets); - } - }; + Runnable runnable = new PortBlockingRunnable(mockServerSocket.getInetAddress(), mockServerSocket.getLocalPort(), port, + latch, closeLatch, success); Thread thread = new Thread(runnable); thread.start(); listenThreads.add(thread); @@ -251,11 +199,9 @@ public void testFailover() throws Exception { // one server to use! InMemoryDirectoryServer[] allButFirstServer = Arrays.copyOfRange(ldapServers, 1, ldapServers.length); final List ldapServersToKill; - int numberToKill = 1; if (ldapServers.length > 2) { - numberToKill = randomIntBetween(1, numberOfLdapServers - 2); + final int numberToKill = randomIntBetween(1, numberOfLdapServers - 2); ldapServersToKill = randomSubsetOf(numberToKill, allButFirstServer); - numberToKill++; // since we kill the first automatically ldapServersToKill.add(ldapServers[0]); // always kill the first one } else { ldapServersToKill = Collections.singletonList(ldapServers[0]); @@ -265,7 +211,6 @@ public void testFailover() throws Exception { final List listenThreads = new ArrayList<>(); final CountDownLatch latch = new CountDownLatch(ldapServersToKill.size()); final CountDownLatch closeLatch = new CountDownLatch(1); - final List openMockSockets = Collections.synchronizedList(new ArrayList<>()); final AtomicBoolean success = new AtomicBoolean(true); for (InMemoryDirectoryServer ldapServerToKill : ldapServersToKill) { final int index = ldapServersList.indexOf(ldapServerToKill); @@ -280,59 +225,8 @@ public void testFailover() throws Exception { // NOTE: this is not perfect as there is a small amount of time between the shutdown // of the ldap server and the opening of the socket logger.debug("opening mock server socket listening on [{}]", port); - Runnable runnable = () -> { - final List openedSockets = new ArrayList<>(); - final List blacklistedAddress = new ArrayList<>(); - try { - final boolean allSocketsOpened = awaitBusy(() -> { - try { - final List inetAddressesToBind = Arrays.stream(InetAddressHelper.getAllAddresses()) - .filter(addr -> openedSockets.stream().noneMatch(s -> addr.equals(s.getLocalAddress()))) - .filter(addr -> blacklistedAddress.contains(addr) == false) - .collect(Collectors.toList()); - for (InetAddress localAddress : inetAddressesToBind) { - try { - Socket socket = openMockSocket(mockServerSocket.getInetAddress(), mockServerSocket.getLocalPort(), - localAddress, port); - openedSockets.add(socket); - openMockSockets.add(socket); - logger.debug("opened socket [{}]", socket); - } catch (NoRouteToHostException e) { - logger.debug(new ParameterizedMessage("blacklisting address [{}] due to:", localAddress), e); - blacklistedAddress.add(localAddress); - } - } - return true; - } catch (IOException e) { - logger.debug(new ParameterizedMessage("caught exception while opening socket on [{}]", port), e); - return false; - } - }); - - if (allSocketsOpened) { - latch.countDown(); - } else { - success.set(false); - IOUtils.closeWhileHandlingException(openedSockets); - openedSockets.clear(); - latch.countDown(); - return; - } - } catch (InterruptedException e) { - logger.debug(new ParameterizedMessage("interrupted while trying to open sockets on [{}]", port), e); - Thread.currentThread().interrupt(); - } - - try { - closeLatch.await(); - } catch (InterruptedException e) { - logger.debug("caught exception while waiting for close latch", e); - Thread.currentThread().interrupt(); - } finally { - logger.debug("closing sockets on [{}]", port); - IOUtils.closeWhileHandlingException(openedSockets); - } - }; + Runnable runnable = new PortBlockingRunnable(mockServerSocket.getInetAddress(), mockServerSocket.getLocalPort(), port, + latch, closeLatch, success); Thread thread = new Thread(runnable); thread.start(); listenThreads.add(thread); @@ -402,6 +296,79 @@ private TestSessionFactory createSessionFactory(LdapLoadBalancing loadBalancing) threadPool); } + private class PortBlockingRunnable implements Runnable { + + private final InetAddress serverAddress; + private final int serverPort; + private final int portToBind; + private final CountDownLatch latch; + private final CountDownLatch closeLatch; + private final AtomicBoolean success; + + private PortBlockingRunnable(InetAddress serverAddress, int serverPort, int portToBind, CountDownLatch latch, + CountDownLatch closeLatch, AtomicBoolean success) { + this.serverAddress = serverAddress; + this.serverPort = serverPort; + this.portToBind = portToBind; + this.latch = latch; + this.closeLatch = closeLatch; + this.success = success; + } + + @Override + public void run() { + final List openedSockets = new ArrayList<>(); + final List blacklistedAddress = new ArrayList<>(); + try { + final boolean allSocketsOpened = awaitBusy(() -> { + try { + final List inetAddressesToBind = Arrays.stream(InetAddressHelper.getAllAddresses()) + .filter(addr -> openedSockets.stream().noneMatch(s -> addr.equals(s.getLocalAddress()))) + .filter(addr -> blacklistedAddress.contains(addr) == false) + .collect(Collectors.toList()); + for (InetAddress localAddress : inetAddressesToBind) { + try { + final Socket socket = openMockSocket(serverAddress, serverPort, localAddress, portToBind); + openedSockets.add(socket); + logger.debug("opened socket [{}]", socket); + } catch (NoRouteToHostException e) { + logger.debug(new ParameterizedMessage("blacklisting address [{}] due to:", localAddress), e); + blacklistedAddress.add(localAddress); + } + } + return true; + } catch (IOException e) { + logger.debug(new ParameterizedMessage("caught exception while opening socket on [{}]", portToBind), e); + return false; + } + }); + + if (allSocketsOpened) { + latch.countDown(); + } else { + success.set(false); + IOUtils.closeWhileHandlingException(openedSockets); + openedSockets.clear(); + latch.countDown(); + return; + } + } catch (InterruptedException e) { + logger.debug(new ParameterizedMessage("interrupted while trying to open sockets on [{}]", portToBind), e); + Thread.currentThread().interrupt(); + } + + try { + closeLatch.await(); + } catch (InterruptedException e) { + logger.debug("caught exception while waiting for close latch", e); + Thread.currentThread().interrupt(); + } finally { + logger.debug("closing sockets on [{}]", portToBind); + IOUtils.closeWhileHandlingException(openedSockets); + } + } + } + static class TestSessionFactory extends SessionFactory { protected TestSessionFactory(RealmConfig config, SSLService sslService, ThreadPool threadPool) {