diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/LdapSessionFactoryTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/LdapSessionFactoryTests.java index 106eaa8932629..9867cc29fd3da 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/LdapSessionFactoryTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/LdapSessionFactoryTests.java @@ -9,6 +9,7 @@ import com.unboundid.ldap.sdk.LDAPException; import com.unboundid.ldap.sdk.LDAPURL; import com.unboundid.ldap.sdk.SimpleBindRequest; +import org.elasticsearch.common.network.NetworkAddress; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; @@ -29,6 +30,7 @@ import org.junit.After; import org.junit.Before; +import java.net.InetAddress; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardCopyOption; @@ -73,7 +75,12 @@ public void shutdown() throws InterruptedException { public void testBindWithReadTimeout() throws Exception { InMemoryDirectoryServer ldapServer = randomFrom(ldapServers); String protocol = randomFrom("ldap", "ldaps"); - String ldapUrl = new LDAPURL(protocol, "localhost", ldapServer.getListenPort(protocol), null, null, null, null).toString(); + InetAddress listenAddress = ldapServer.getListenAddress(protocol); + if (listenAddress == null) { + listenAddress = InetAddress.getLoopbackAddress(); + } + String ldapUrl = new LDAPURL(protocol, NetworkAddress.format(listenAddress), ldapServer.getListenPort(protocol), + null, null, null, null).toString(); String groupSearchBase = "o=sevenSeas"; String userTemplates = "cn={0},ou=people,o=sevenSeas"; @@ -233,7 +240,12 @@ public void testGroupLookupBase() throws Exception { */ public void testSslTrustIsReloaded() throws Exception { InMemoryDirectoryServer ldapServer = randomFrom(ldapServers); - String ldapUrl = new LDAPURL("ldaps", "localhost", ldapServer.getListenPort("ldaps"), null, null, null, null).toString(); + InetAddress listenAddress = ldapServer.getListenAddress("ldaps"); + if (listenAddress == null) { + listenAddress = InetAddress.getLoopbackAddress(); + } + String ldapUrl = new LDAPURL("ldaps", NetworkAddress.format(listenAddress), ldapServer.getListenPort("ldaps"), + null, null, null, null).toString(); String groupSearchBase = "o=sevenSeas"; String userTemplates = "cn={0},ou=people,o=sevenSeas"; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/LdapTestCase.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/LdapTestCase.java index 2c0b2f7716650..957167e60d281 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/LdapTestCase.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/LdapTestCase.java @@ -18,6 +18,7 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.network.NetworkAddress; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; @@ -25,6 +26,7 @@ import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.watcher.ResourceWatcherService; +import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.ldap.LdapSessionFactorySettings; import org.elasticsearch.xpack.core.security.authc.ldap.SearchGroupsResolverSettings; @@ -46,6 +48,7 @@ import javax.net.ssl.SSLServerSocketFactory; import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.X509ExtendedKeyManager; +import java.net.InetAddress; import java.security.AccessController; import java.security.KeyStore; import java.security.PrivilegedAction; @@ -76,7 +79,7 @@ public void startLdap() throws Exception { for (int i = 0; i < numberOfLdapServers; i++) { InMemoryDirectoryServerConfig serverConfig = new InMemoryDirectoryServerConfig("o=sevenSeas"); List listeners = new ArrayList<>(2); - listeners.add(InMemoryListenerConfig.createLDAPConfig("ldap")); + listeners.add(InMemoryListenerConfig.createLDAPConfig("ldap", null, 0, null)); if (openLdapsPort()) { final char[] ldapPassword = "ldap-password".toCharArray(); final KeyStore ks = CertParsingUtils.getKeyStoreFromPEM( @@ -85,7 +88,7 @@ public void startLdap() throws Exception { ldapPassword ); X509ExtendedKeyManager keyManager = CertParsingUtils.keyManager(ks, ldapPassword, KeyManagerFactory.getDefaultAlgorithm()); - final SSLContext context = SSLContext.getInstance("TLSv1.2"); + final SSLContext context = SSLContext.getInstance(XPackSettings.DEFAULT_SUPPORTED_PROTOCOLS.get(0)); context.init(new KeyManager[] { keyManager }, null, null); SSLServerSocketFactory serverSocketFactory = context.getServerSocketFactory(); SSLSocketFactory clientSocketFactory = context.getSocketFactory(); @@ -111,7 +114,7 @@ protected boolean openLdapsPort() { } @After - public void stopLdap() throws Exception { + public void stopLdap() { for (int i = 0; i < numberOfLdapServers; i++) { ldapServers[i].shutDown(true); } @@ -120,7 +123,11 @@ public void stopLdap() throws Exception { protected String[] ldapUrls() throws LDAPException { List urls = new ArrayList<>(numberOfLdapServers); for (int i = 0; i < numberOfLdapServers; i++) { - LDAPURL url = new LDAPURL("ldap", "localhost", ldapServers[i].getListenPort(), null, null, null, null); + InetAddress listenAddress = ldapServers[i].getListenAddress(); + if (listenAddress == null) { + listenAddress = InetAddress.getLoopbackAddress(); + } + LDAPURL url = new LDAPURL("ldap", NetworkAddress.format(listenAddress), ldapServers[i].getListenPort(), null, null, null, null); urls.add(url.toString()); } return urls.toArray(Strings.EMPTY_ARRAY); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/SessionFactoryLoadBalancingTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/SessionFactoryLoadBalancingTests.java index 1483b2f474bf0..cd159f69c486c 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/SessionFactoryLoadBalancingTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/ldap/support/SessionFactoryLoadBalancingTests.java @@ -7,11 +7,16 @@ import com.unboundid.ldap.listener.InMemoryDirectoryServer; import com.unboundid.ldap.sdk.LDAPConnection; +import com.unboundid.ldap.sdk.LDAPException; +import com.unboundid.ldap.sdk.SimpleBindRequest; +import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.SuppressForbidden; +import org.elasticsearch.common.network.InetAddressHelper; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.mocksocket.MockServerSocket; import org.elasticsearch.mocksocket.MockSocket; @@ -28,12 +33,17 @@ import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; +import java.net.NoRouteToHostException; import java.net.Socket; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -52,7 +62,7 @@ public void init() throws Exception { } @After - public void shutdown() throws InterruptedException { + public void shutdown() { terminate(threadPool); } @@ -62,29 +72,22 @@ public void testRoundRobin() throws Exception { final int numberOfIterations = randomIntBetween(1, 5); for (int iteration = 0; iteration < numberOfIterations; iteration++) { for (int i = 0; i < numberOfLdapServers; i++) { - LDAPConnection connection = null; - try { - connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection); + try (LDAPConnection connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection)) { assertThat(connection.getConnectedPort(), is(ldapServers[i].getListenPort())); - } finally { - if (connection != null) { - connection.close(); - } } } } } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/32190") public void testRoundRobinWithFailures() throws Exception { - assumeTrue("at least one ldap server should be present for this test", ldapServers.length > 1); + assumeTrue("at least two ldap servers should be present for this test", ldapServers.length > 1); logger.debug("using [{}] ldap servers, urls {}", ldapServers.length, ldapUrls()); TestSessionFactory testSessionFactory = createSessionFactory(LdapLoadBalancing.ROUND_ROBIN); // create a list of ports List ports = new ArrayList<>(numberOfLdapServers); - for (int i = 0; i < ldapServers.length; i++) { - ports.add(ldapServers[i].getListenPort()); + for (InMemoryDirectoryServer ldapServer : ldapServers) { + ports.add(ldapServer.getListenPort()); } logger.debug("list of all ports {}", ports); @@ -94,18 +97,18 @@ public void testRoundRobinWithFailures() throws Exception { // get a subset to kill final List ldapServersToKill = randomSubsetOf(numberToKill, ldapServers); final List ldapServersList = Arrays.asList(ldapServers); - final InetAddress local = InetAddress.getByName("localhost"); - final MockServerSocket mockServerSocket = new MockServerSocket(0, 0, local); + final MockServerSocket mockServerSocket = new MockServerSocket(0, 0); final List listenThreads = new ArrayList<>(); final CountDownLatch latch = new CountDownLatch(ldapServersToKill.size()); final CountDownLatch closeLatch = new CountDownLatch(1); try { + final AtomicBoolean success = new AtomicBoolean(true); for (InMemoryDirectoryServer ldapServerToKill : ldapServersToKill) { final int index = ldapServersList.indexOf(ldapServerToKill); assertThat(index, greaterThanOrEqualTo(0)); - final Integer port = Integer.valueOf(ldapServers[index].getListenPort()); + final int port = ldapServers[index].getListenPort(); logger.debug("shutting down server index [{}] listening on [{}]", index, port); - assertTrue(ports.remove(port)); + assertTrue(ports.remove(Integer.valueOf(port))); ldapServers[index].shutDown(true); // when running multiple test jvms, there is a chance that something else could @@ -114,17 +117,9 @@ public void testRoundRobinWithFailures() throws Exception { // a mock server socket. // NOTE: this is not perfect as there is a small amount of time between the shutdown // of the ldap server and the opening of the socket - logger.debug("opening mock server socket listening on [{}]", port); - Runnable runnable = () -> { - try (Socket socket = openMockSocket(local, mockServerSocket.getLocalPort(), local, port)) { - logger.debug("opened socket [{}]", socket); - latch.countDown(); - closeLatch.await(); - logger.debug("closing socket [{}]", socket); - } catch (IOException | InterruptedException e) { - logger.debug("caught exception", e); - } - }; + logger.debug("opening mock client sockets bound to [{}]", port); + Runnable runnable = new PortBlockingRunnable(mockServerSocket.getInetAddress(), mockServerSocket.getLocalPort(), port, + latch, closeLatch, success); Thread thread = new Thread(runnable); thread.start(); listenThreads.add(thread); @@ -133,14 +128,37 @@ public void testRoundRobinWithFailures() throws Exception { } latch.await(); + + assumeTrue("Failed to open sockets on all addresses with the port that an LDAP server was bound to. Some operating systems " + + "allow binding to an address and port combination even if an application is bound to the port on a wildcard address", + success.get()); final int numberOfIterations = randomIntBetween(1, 5); + logger.debug("list of all open ports {}", ports); // go one iteration through and attempt a bind for (int iteration = 0; iteration < numberOfIterations; iteration++) { logger.debug("iteration [{}]", iteration); for (Integer port : ports) { logger.debug("attempting connection with expected port [{}]", port); - try (LDAPConnection connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection)) { + LDAPConnection connection = null; + try { + do { + final LDAPConnection finalConnection = + LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection); + connection = finalConnection; + logger.debug("established connection with port [{}] expected port [{}]", + finalConnection.getConnectedPort(), port); + if (finalConnection.getConnectedPort() != port) { + LDAPException e = expectThrows(LDAPException.class, () -> finalConnection.bind(new SimpleBindRequest())); + assertThat(e.getMessage(), containsString("not connected")); + finalConnection.close(); + } + } while (connection.getConnectedPort() != port); + assertThat(connection.getConnectedPort(), is(port)); + } finally { + if (connection != null) { + connection.close(); + } } } } @@ -160,76 +178,109 @@ private MockSocket openMockSocket(InetAddress remoteAddress, int remotePort, Ine socket.setReuseAddress(true); // allow binding even if the previous socket is in timed wait state. socket.setSoLinger(true, 0); // close immediately as we are not writing anything here. socket.bind(new InetSocketAddress(localAddress, localPort)); - SocketAccess.doPrivileged(() -> socket.connect(new InetSocketAddress(localAddress, remotePort))); + SocketAccess.doPrivileged(() -> socket.connect(new InetSocketAddress(remoteAddress, remotePort))); return socket; } public void testFailover() throws Exception { - assumeTrue("at least one ldap server should be present for this test", ldapServers.length > 1); + assumeTrue("at least two ldap servers should be present for this test", ldapServers.length > 1); logger.debug("using [{}] ldap servers, urls {}", ldapServers.length, ldapUrls()); TestSessionFactory testSessionFactory = createSessionFactory(LdapLoadBalancing.FAILOVER); // first test that there is no round robin stuff going on final int firstPort = ldapServers[0].getListenPort(); for (int i = 0; i < numberOfLdapServers; i++) { - LDAPConnection connection = null; - try { - connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection); + try (LDAPConnection connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection)) { assertThat(connection.getConnectedPort(), is(firstPort)); - } finally { - if (connection != null) { - connection.close(); - } } } - logger.debug("shutting down server index [0] listening on [{}]", ldapServers[0].getListenPort()); - // always kill the first one - ldapServers[0].shutDown(true); - assertThat(ldapServers[0].getListenPort(), is(-1)); - - // now randomly shutdown some others + // we need at least one good server. Hence the upper bound is number - 2 since we need at least + // one server to use! + InMemoryDirectoryServer[] allButFirstServer = Arrays.copyOfRange(ldapServers, 1, ldapServers.length); + final List ldapServersToKill; if (ldapServers.length > 2) { - // kill at least one other server, but we need at least one good one. Hence the upper bound is number - 2 since we need at least - // one server to use! final int numberToKill = randomIntBetween(1, numberOfLdapServers - 2); - InMemoryDirectoryServer[] allButFirstServer = Arrays.copyOfRange(ldapServers, 1, ldapServers.length); - // get a subset to kil - final List ldapServersToKill = randomSubsetOf(numberToKill, allButFirstServer); - final List ldapServersList = Arrays.asList(ldapServers); - for (InMemoryDirectoryServer ldapServerToKill : ldapServersToKill) { - final int index = ldapServersList.indexOf(ldapServerToKill); - assertThat(index, greaterThanOrEqualTo(1)); - final Integer port = Integer.valueOf(ldapServers[index].getListenPort()); - logger.debug("shutting down server index [{}] listening on [{}]", index, port); - ldapServers[index].shutDown(true); - assertThat(ldapServers[index].getListenPort(), is(-1)); - } + ldapServersToKill = randomSubsetOf(numberToKill, allButFirstServer); + ldapServersToKill.add(ldapServers[0]); // always kill the first one + } else { + ldapServersToKill = Collections.singletonList(ldapServers[0]); } + final List ldapServersList = Arrays.asList(ldapServers); + final MockServerSocket mockServerSocket = new MockServerSocket(0, 0); + final List listenThreads = new ArrayList<>(); + final CountDownLatch latch = new CountDownLatch(ldapServersToKill.size()); + final CountDownLatch closeLatch = new CountDownLatch(1); + final AtomicBoolean success = new AtomicBoolean(true); + for (InMemoryDirectoryServer ldapServerToKill : ldapServersToKill) { + final int index = ldapServersList.indexOf(ldapServerToKill); + final int port = ldapServers[index].getListenPort(); + logger.debug("shutting down server index [{}] listening on [{}]", index, port); + ldapServers[index].shutDown(true); - int firstNonStoppedPort = -1; - // now we find the first that isn't stopped - for (int i = 0; i < numberOfLdapServers; i++) { - if (ldapServers[i].getListenPort() != -1) { - firstNonStoppedPort = ldapServers[i].getListenPort(); - break; - } + // when running multiple test jvms, there is a chance that something else could + // start listening on this port so we try to avoid this by creating a local socket + // that will be bound to the port the ldap server was running on and connecting to + // a mock server socket. + // NOTE: this is not perfect as there is a small amount of time between the shutdown + // of the ldap server and the opening of the socket + logger.debug("opening mock server socket listening on [{}]", port); + Runnable runnable = new PortBlockingRunnable(mockServerSocket.getInetAddress(), mockServerSocket.getLocalPort(), port, + latch, closeLatch, success); + Thread thread = new Thread(runnable); + thread.start(); + listenThreads.add(thread); + + assertThat(ldapServers[index].getListenPort(), is(-1)); } - logger.debug("first non stopped port [{}]", firstNonStoppedPort); - assertThat(firstNonStoppedPort, not(-1)); - final int numberOfIterations = randomIntBetween(1, 5); - for (int iteration = 0; iteration < numberOfIterations; iteration++) { - LDAPConnection connection = null; - try { + try { + latch.await(); + + assumeTrue("Failed to open sockets on all addresses with the port that an LDAP server was bound to. Some operating systems " + + "allow binding to an address and port combination even if an application is bound to the port on a wildcard address", + success.get()); + int firstNonStoppedPort = -1; + // now we find the first that isn't stopped + for (int i = 0; i < numberOfLdapServers; i++) { + if (ldapServers[i].getListenPort() != -1) { + firstNonStoppedPort = ldapServers[i].getListenPort(); + break; + } + } + logger.debug("first non stopped port [{}]", firstNonStoppedPort); + assertThat(firstNonStoppedPort, not(-1)); + final int numberOfIterations = randomIntBetween(1, 5); + for (int iteration = 0; iteration < numberOfIterations; iteration++) { logger.debug("attempting connection with expected port [{}] iteration [{}]", firstNonStoppedPort, iteration); - connection = LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection); - assertThat(connection.getConnectedPort(), is(firstNonStoppedPort)); - } finally { - if (connection != null) { - connection.close(); + LDAPConnection connection = null; + try { + do { + final LDAPConnection finalConnection = + LdapUtils.privilegedConnect(testSessionFactory.getServerSet()::getConnection); + connection = finalConnection; + logger.debug("established connection with port [{}] expected port [{}]", + finalConnection.getConnectedPort(), firstNonStoppedPort); + if (finalConnection.getConnectedPort() != firstNonStoppedPort) { + LDAPException e = expectThrows(LDAPException.class, () -> finalConnection.bind(new SimpleBindRequest())); + assertThat(e.getMessage(), containsString("not connected")); + finalConnection.close(); + } + } while (connection.getConnectedPort() != firstNonStoppedPort); + + assertThat(connection.getConnectedPort(), is(firstNonStoppedPort)); + } finally { + if (connection != null) { + connection.close(); + } } } + } finally { + closeLatch.countDown(); + mockServerSocket.close(); + for (Thread t : listenThreads) { + t.join(); + } } } @@ -245,6 +296,79 @@ private TestSessionFactory createSessionFactory(LdapLoadBalancing loadBalancing) threadPool); } + private class PortBlockingRunnable implements Runnable { + + private final InetAddress serverAddress; + private final int serverPort; + private final int portToBind; + private final CountDownLatch latch; + private final CountDownLatch closeLatch; + private final AtomicBoolean success; + + private PortBlockingRunnable(InetAddress serverAddress, int serverPort, int portToBind, CountDownLatch latch, + CountDownLatch closeLatch, AtomicBoolean success) { + this.serverAddress = serverAddress; + this.serverPort = serverPort; + this.portToBind = portToBind; + this.latch = latch; + this.closeLatch = closeLatch; + this.success = success; + } + + @Override + public void run() { + final List openedSockets = new ArrayList<>(); + final List blacklistedAddress = new ArrayList<>(); + try { + final boolean allSocketsOpened = awaitBusy(() -> { + try { + final List inetAddressesToBind = Arrays.stream(InetAddressHelper.getAllAddresses()) + .filter(addr -> openedSockets.stream().noneMatch(s -> addr.equals(s.getLocalAddress()))) + .filter(addr -> blacklistedAddress.contains(addr) == false) + .collect(Collectors.toList()); + for (InetAddress localAddress : inetAddressesToBind) { + try { + final Socket socket = openMockSocket(serverAddress, serverPort, localAddress, portToBind); + openedSockets.add(socket); + logger.debug("opened socket [{}]", socket); + } catch (NoRouteToHostException e) { + logger.debug(new ParameterizedMessage("blacklisting address [{}] due to:", localAddress), e); + blacklistedAddress.add(localAddress); + } + } + return true; + } catch (IOException e) { + logger.debug(new ParameterizedMessage("caught exception while opening socket on [{}]", portToBind), e); + return false; + } + }); + + if (allSocketsOpened) { + latch.countDown(); + } else { + success.set(false); + IOUtils.closeWhileHandlingException(openedSockets); + openedSockets.clear(); + latch.countDown(); + return; + } + } catch (InterruptedException e) { + logger.debug(new ParameterizedMessage("interrupted while trying to open sockets on [{}]", portToBind), e); + Thread.currentThread().interrupt(); + } + + try { + closeLatch.await(); + } catch (InterruptedException e) { + logger.debug("caught exception while waiting for close latch", e); + Thread.currentThread().interrupt(); + } finally { + logger.debug("closing sockets on [{}]", portToBind); + IOUtils.closeWhileHandlingException(openedSockets); + } + } + } + static class TestSessionFactory extends SessionFactory { protected TestSessionFactory(RealmConfig config, SSLService sslService, ThreadPool threadPool) {