diff --git a/src/main/java/com/timgroup/statsd/NonBlockingStatsDClientBuilder.java b/src/main/java/com/timgroup/statsd/NonBlockingStatsDClientBuilder.java index 2ceedccd..c9e0b694 100644 --- a/src/main/java/com/timgroup/statsd/NonBlockingStatsDClientBuilder.java +++ b/src/main/java/com/timgroup/statsd/NonBlockingStatsDClientBuilder.java @@ -329,10 +329,27 @@ public static Callable volatileAddressResolution(final String hos if (port == 0) { return new Callable() { @Override public SocketAddress call() throws UnknownHostException { - return new UnixSocketAddressWithTransport( - new UnixSocketAddress(hostname), - UnixSocketAddressWithTransport.TransportType.UDS - ); + if (VersionUtils.hasNativeUdsSupport()) { + try { + Class udsAddressClass = Class.forName("java.net.UnixDomainSocketAddress"); + SocketAddress udsAddress = (SocketAddress) + udsAddressClass.getMethod("of", String.class).invoke(null, hostname); + System.out.println("================UnixSocketAddressWithTransport returned with: " + udsAddress); + return new UnixSocketAddressWithTransport( + udsAddress, + UnixSocketAddressWithTransport.TransportType.UDS + ); + } catch (Exception e) { + throw new UnknownHostException("Failed to create UnixDomainSocketAddress: " + e.getMessage()); + } + } else { + SocketAddress jnrAddress = new UnixSocketAddress(hostname); + System.out.println("================UnixSocketAddressWithTransport returned with: " + jnrAddress); + return new UnixSocketAddressWithTransport( + jnrAddress, + UnixSocketAddressWithTransport.TransportType.UDS + ); + } } }; } else { @@ -374,12 +391,29 @@ protected static Callable staticNamedPipeResolution(String namedP protected static Callable staticUnixResolution( final String path, final UnixSocketAddressWithTransport.TransportType transportType) { - return new Callable() { - @Override public SocketAddress call() { - final UnixSocketAddress socketAddress = new UnixSocketAddress(path); - return new UnixSocketAddressWithTransport(socketAddress, transportType); + if (VersionUtils.hasNativeUdsSupport()) { + try { + Class udsAddressClass = Class.forName("java.net.UnixDomainSocketAddress"); + final SocketAddress udsAddress = (SocketAddress) udsAddressClass.getMethod("of", String.class).invoke(null, path); + System.out.println("================new Callable returned with udsAddress: " + udsAddress); + return new Callable() { + @Override public SocketAddress call() { + System.out.println("================UnixSocketAddressWithTransport returned with: " + udsAddress); + return new UnixSocketAddressWithTransport(udsAddress, transportType); + } + }; + } catch (Exception e) { + throw new RuntimeException("Failed to create UnixDomainSocketAddress: " + e.getMessage(), e); } - }; + } else { + return new Callable() { + @Override public SocketAddress call() { + final UnixSocketAddress jnrAddress = new UnixSocketAddress(path); + System.out.println("================UnixSocketAddressWithTransport returned with: " + jnrAddress); + return new UnixSocketAddressWithTransport(jnrAddress, transportType); + } + }; + } } private static Callable staticAddress(final String hostname, final int port) { diff --git a/src/main/java/com/timgroup/statsd/UnixDatagramClientChannel.java b/src/main/java/com/timgroup/statsd/UnixDatagramClientChannel.java index 7d996963..b3dc5a5b 100644 --- a/src/main/java/com/timgroup/statsd/UnixDatagramClientChannel.java +++ b/src/main/java/com/timgroup/statsd/UnixDatagramClientChannel.java @@ -5,6 +5,7 @@ import java.io.IOException; import java.net.SocketAddress; +import java.nio.channels.DatagramChannel; class UnixDatagramClientChannel extends DatagramClientChannel { /** @@ -16,14 +17,33 @@ class UnixDatagramClientChannel extends DatagramClientChannel { * @throws IOException if socket options cannot be set */ UnixDatagramClientChannel(SocketAddress address, int timeout, int bufferSize) throws IOException { - super(UnixDatagramChannel.open(), address); - // Set send timeout, to handle the case where the transmission buffer is full - // If no timeout is set, the send becomes blocking - if (timeout > 0) { - delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, timeout); + super(createChannel(address), address); + configureChannel(timeout, bufferSize); + } + + private static DatagramChannel createChannel(SocketAddress address) throws IOException { + if (VersionUtils.hasNativeUdsSupport()) { + return DatagramChannel.open(); + } else { + return UnixDatagramChannel.open(); } - if (bufferSize > 0) { - delegate.setOption(UnixSocketOptions.SO_SNDBUF, bufferSize); + } + + private void configureChannel(int timeout, int bufferSize) throws IOException { + if (VersionUtils.hasNativeUdsSupport()) { + if (timeout > 0) { + delegate.socket().setSoTimeout(timeout); + } + if (bufferSize > 0) { + delegate.socket().setSendBufferSize(bufferSize); + } + } else { + if (timeout > 0) { + delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, timeout); + } + if (bufferSize > 0) { + delegate.setOption(UnixSocketOptions.SO_SNDBUF, bufferSize); + } } } diff --git a/src/main/java/com/timgroup/statsd/UnixStreamClientChannel.java b/src/main/java/com/timgroup/statsd/UnixStreamClientChannel.java index d910d786..b3a163f4 100644 --- a/src/main/java/com/timgroup/statsd/UnixStreamClientChannel.java +++ b/src/main/java/com/timgroup/statsd/UnixStreamClientChannel.java @@ -5,6 +5,7 @@ import jnr.unixsocket.UnixSocketOptions; import java.io.IOException; +import java.lang.reflect.Method; import java.net.SocketAddress; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -14,12 +15,11 @@ * A ClientChannel for Unix domain sockets. */ public class UnixStreamClientChannel implements ClientChannel { - private final UnixSocketAddress address; + private final SocketAddress address; private final int timeout; private final int connectionTimeout; private final int bufferSize; - private SocketChannel delegate; private final ByteBuffer delimiterBuffer = ByteBuffer.allocateDirect(Integer.SIZE / Byte.SIZE).order(ByteOrder.LITTLE_ENDIAN); @@ -30,10 +30,11 @@ public class UnixStreamClientChannel implements ClientChannel { */ UnixStreamClientChannel(SocketAddress address, int timeout, int connectionTimeout, int bufferSize) throws IOException { this.delegate = null; - this.address = (UnixSocketAddress) address; + this.address = address; this.timeout = timeout; this.connectionTimeout = connectionTimeout; this.bufferSize = bufferSize; + System.out.println("================Created UnixStreamClientChannel with address: " + address); } @Override @@ -74,35 +75,6 @@ public synchronized int write(ByteBuffer src) throws IOException { return size; } - /** - * Writes all bytes from the given buffer to the channel. - * @param bb buffer to write - * @param canReturnOnTimeout if true, we return if the channel is blocking and we haven't written anything yet - * @param deadline deadline for the write - * @return number of bytes written - * @throws IOException if the channel is closed or an error occurs - */ - public int writeAll(ByteBuffer bb, boolean canReturnOnTimeout, long deadline) throws IOException { - int remaining = bb.remaining(); - int written = 0; - while (remaining > 0) { - int read = delegate.write(bb); - - // If we haven't written anything yet, we can still return - if (read == 0 && canReturnOnTimeout && written == 0) { - return written; - } - - remaining -= read; - written += read; - - if (deadline > 0 && System.nanoTime() > deadline) { - throw new IOException("Write timed out"); - } - } - return written; - } - private void connectIfNeeded() throws IOException { if (delegate == null) { connect(); @@ -125,17 +97,67 @@ private void connect() throws IOException { } } - UnixSocketChannel delegate = UnixSocketChannel.create(); - long deadline = System.nanoTime() + connectionTimeout * 1_000_000L; + // Use native JDK Unix domain socket support for compatible versions (Java 16+). Fall back to JNR support otherwise. + if (VersionUtils.hasNativeUdsSupport()) { + connectJdkSocket(deadline); + } else { + connectJnrSocket(deadline); + } + } + + private void connectJdkSocket(long deadline) throws IOException { + String socketPath = address.toString(); + try { + Class udsAddressClass = Class.forName("java.net.UnixDomainSocketAddress"); + Object udsAddress = udsAddressClass.getMethod("of", String.class).invoke(null, socketPath); + + SocketChannel delegate = SocketChannel.open(); + if (connectionTimeout > 0) { + delegate.socket().setSoTimeout(connectionTimeout); + } + + delegate.configureBlocking(false); + System.out.println("================Attempting to connect delegate to: " + udsAddress); + if (!delegate.connect((SocketAddress) udsAddress)) { + System.out.println("================Initial connect returned false, checking deadline"); + if (connectionTimeout > 0 && System.nanoTime() > deadline) { + throw new IOException("Connection timed out"); + } + System.out.println("================Finishing connection"); + if (!delegate.finishConnect()) { + throw new IOException("Connection failed"); + } + } + System.out.println("================Connection successful"); + delegate.configureBlocking(true); + delegate.socket().setSoTimeout(Math.max(timeout, 0)); + if (bufferSize > 0) { + delegate.socket().setSendBufferSize(bufferSize); + } + this.delegate = delegate; + System.out.println("================Set up complete."); + } catch (Exception e) { + System.out.println("================Failed to connect to UDS at: " + socketPath); + try { + delegate.close(); + } catch (IOException __) { + // ignore + } + throw new IOException("Failed to connect to Unix Domain Socket: " + socketPath, e); + } + } + + private void connectJnrSocket(long deadline) throws IOException { + UnixSocketChannel delegate = UnixSocketChannel.create(); if (connectionTimeout > 0) { // Set connect timeout, this should work at least on linux // https://elixir.bootlin.com/linux/v5.7.4/source/net/unix/af_unix.c#L1696 - // We'd have better timeout support if we used Java 16's native Unix domain socket support (JEP 380) delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, connectionTimeout); } try { - if (!delegate.connect(address)) { + System.out.println("================Attempting to connect delegate to: " + address); + if (!delegate.connect((UnixSocketAddress) address)) { if (connectionTimeout > 0 && System.nanoTime() > deadline) { throw new IOException("Connection timed out"); } @@ -143,11 +165,13 @@ private void connect() throws IOException { throw new IOException("Connection failed"); } } - + System.out.println("================Connection successful"); delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, Math.max(timeout, 0)); if (bufferSize > 0) { delegate.setOption(UnixSocketOptions.SO_SNDBUF, bufferSize); } + this.delegate = delegate; + System.out.println("================Set up complete."); } catch (Exception e) { try { delegate.close(); @@ -156,9 +180,35 @@ private void connect() throws IOException { } throw e; } + } + + /** + * Writes all bytes from the given buffer to the channel. + * @param bb buffer to write + * @param canReturnOnTimeout if true, we return if the channel is blocking and we haven't written anything yet + * @param deadline deadline for the write + * @return number of bytes written + * @throws IOException if the channel is closed or an error occurs + */ + public int writeAll(ByteBuffer bb, boolean canReturnOnTimeout, long deadline) throws IOException { + int remaining = bb.remaining(); + int written = 0; + while (remaining > 0) { + int read = delegate.write(bb); + // If we haven't written anything yet, we can still return + if (read == 0 && canReturnOnTimeout && written == 0) { + return written; + } + + remaining -= read; + written += read; - this.delegate = delegate; + if (deadline > 0 && System.nanoTime() > deadline) { + throw new IOException("Write timed out"); + } + } + return written; } @Override diff --git a/src/main/java/com/timgroup/statsd/VersionUtils.java b/src/main/java/com/timgroup/statsd/VersionUtils.java new file mode 100644 index 00000000..f7c082c2 --- /dev/null +++ b/src/main/java/com/timgroup/statsd/VersionUtils.java @@ -0,0 +1,113 @@ +package com.timgroup.statsd; + +import java.util.ArrayList; +import java.util.List; + +// Logic copied from dd-trace-java Platform class. See: +// https://github.com/DataDog/dd-trace-java/blob/master/internal-api/src/main/java/datadog/trace/api/Platform.java +public class VersionUtils { + private static final Version JAVA_VERSION = parseJavaVersion(System.getProperty("java.version")); + private static final int NATIVE_UDS_MIN_VERSION = 16; // Java 16+ has native Unix Domain Socket support + + private static Version parseJavaVersion(String javaVersion) { + // Remove pre-release part, usually -ea + final int indexOfDash = javaVersion.indexOf('-'); + if (indexOfDash >= 0) { + javaVersion = javaVersion.substring(0, indexOfDash); + } + + int major = 0; + int minor = 0; + int update = 0; + + try { + List nums = splitDigits(javaVersion); + major = nums.get(0); + + // for java 1.6/1.7/1.8 + if (major == 1) { + major = nums.get(1); + minor = nums.get(2); + update = nums.get(3); + } else { + minor = nums.get(1); + update = nums.get(2); + } + } catch (NumberFormatException | IndexOutOfBoundsException e) { + // unable to parse version string - do nothing + } + return new Version(major, minor, update); + } + + private static List splitDigits(String str) { + List results = new ArrayList<>(); + + int len = str.length(); + + int value = 0; + for (int i = 0; i < len; i++) { + char ch = str.charAt(i); + if (ch >= '0' && ch <= '9') { + value = value * 10 + (ch - '0'); + } else if (ch == '.' || ch == '_' || ch == '+') { + results.add(value); + value = 0; + } else { + throw new NumberFormatException(); + } + } + results.add(value); + return results; + } + + static final class Version { + public final int major; + public final int minor; + public final int update; + + public Version(int major, int minor, int update) { + this.major = major; + this.minor = minor; + this.update = update; + } + + public boolean is(int major) { + return this.major == major; + } + + public boolean is(int major, int minor) { + return this.major == major && this.minor == minor; + } + + public boolean is(int major, int minor, int update) { + return this.major == major && this.minor == minor && this.update == update; + } + + public boolean isAtLeast(int major, int minor, int update) { + return isAtLeast(this.major, this.minor, this.update, major, minor, update); + } + + private static boolean isAtLeast( + int major, int minor, int update, int atLeastMajor, int atLeastMinor, int atLeastUpdate) { + return (major > atLeastMajor) + || (major == atLeastMajor && minor > atLeastMinor) + || (major == atLeastMajor && minor == atLeastMinor && update >= atLeastUpdate); + } + } + + public static boolean isJavaVersionAtLeast(int major) { + return isJavaVersionAtLeast(major, 0, 0); + } + + public static boolean isJavaVersionAtLeast(int major, int minor) { + return isJavaVersionAtLeast(major, minor, 0); + } + + public static boolean isJavaVersionAtLeast(int major, int minor, int update) { + return JAVA_VERSION.isAtLeast(major, minor, update); + } + + public static boolean hasNativeUdsSupport() { + return isJavaVersionAtLeast(NATIVE_UDS_MIN_VERSION); + } +} diff --git a/src/test/java/com/timgroup/statsd/UnixStreamSocketDummyStatsDServer.java b/src/test/java/com/timgroup/statsd/UnixStreamSocketDummyStatsDServer.java index ea743c8e..6ba45555 100644 --- a/src/test/java/com/timgroup/statsd/UnixStreamSocketDummyStatsDServer.java +++ b/src/test/java/com/timgroup/statsd/UnixStreamSocketDummyStatsDServer.java @@ -23,6 +23,7 @@ public UnixStreamSocketDummyStatsDServer(String socketPath) throws IOException { server = UnixServerSocketChannel.open(); server.configureBlocking(true); server.socket().bind(new UnixSocketAddress(socketPath)); + System.out.println("================Server bound to " + socketPath); this.listen(); } @@ -39,6 +40,7 @@ protected void receive(ByteBuffer packet) throws IOException { @Override protected void listen() { logger.info("Listening on " + server.getLocalSocketAddress()); + System.out.println("================Server listening on " + server.getLocalSocketAddress()); Thread thread = new Thread(new Runnable() { @Override public void run() { @@ -48,7 +50,9 @@ public void run() { } try { logger.info("Waiting for connection"); + System.out.println("================Server waiting for connection"); UnixSocketChannel clientChannel = server.accept(); + System.out.println("================Server accepted connection"); if (clientChannel != null) { clientChannel.configureBlocking(true); try { @@ -60,6 +64,7 @@ public void run() { readChannel(clientChannel); } } catch (IOException e) { + System.out.println("================Server caught IOException: " + e); } } }