Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/main/java/com/timgroup/statsd/NonBlockingStatsDClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ String tag() {
public static final int SOCKET_BUFFER_BYTES = -1;
public static final boolean DEFAULT_BLOCKING = false;
public static final boolean DEFAULT_ENABLE_TELEMETRY = true;
public static final boolean DEFAULT_ENABLE_JDK_SOCKET = true;

public static final boolean DEFAULT_ENABLE_AGGREGATION = true;
public static final boolean DEFAULT_ENABLE_ORIGIN_DETECTION = true;
Expand Down Expand Up @@ -248,7 +249,8 @@ public NonBlockingStatsDClient(final NonBlockingStatsDClientBuilder builder)
builder.addressLookup,
builder.timeout,
builder.connectionTimeout,
builder.socketBufferSize);
builder.socketBufferSize,
builder.enableJdkSocket);

ThreadFactory threadFactory =
builder.threadFactory != null
Expand Down Expand Up @@ -296,7 +298,8 @@ public NonBlockingStatsDClient(final NonBlockingStatsDClientBuilder builder)
builder.telemetryAddressLookup,
builder.timeout,
builder.connectionTimeout,
builder.socketBufferSize);
builder.socketBufferSize,
builder.enableJdkSocket);

// similar settings, but a single worker and non-blocking.
telemetryStatsDProcessor =
Expand Down Expand Up @@ -482,7 +485,8 @@ ClientChannel createByteChannel(
Callable<SocketAddress> addressLookup,
int timeout,
int connectionTimeout,
int bufferSize)
int bufferSize,
boolean enableJdkSocket)
throws Exception {
final SocketAddress address = addressLookup.call();
if (address instanceof NamedPipeSocketAddress) {
Expand All @@ -497,7 +501,11 @@ ClientChannel createByteChannel(
switch (unixAddr.getTransportType()) {
case UDS_STREAM:
return new UnixStreamClientChannel(
unixAddr.getAddress(), timeout, connectionTimeout, bufferSize);
unixAddr.getAddress(),
timeout,
connectionTimeout,
bufferSize,
enableJdkSocket);
case UDS_DATAGRAM:
case UDS:
return new UnixDatagramClientChannel(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.timgroup.statsd;

import java.lang.reflect.Method;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
Expand Down Expand Up @@ -52,6 +53,9 @@ public class NonBlockingStatsDClientBuilder implements Cloneable {

public boolean enableAggregation = NonBlockingStatsDClient.DEFAULT_ENABLE_AGGREGATION;

/** Enable native JDK support for UDS. Only available on Java 16+. */
public boolean enableJdkSocket = NonBlockingStatsDClient.DEFAULT_ENABLE_JDK_SOCKET;

/** Telemetry flush interval, in milliseconds. */
public int telemetryFlushInterval = Telemetry.DEFAULT_FLUSH_INTERVAL;

Expand Down Expand Up @@ -322,6 +326,11 @@ public NonBlockingStatsDClientBuilder originDetectionEnabled(boolean val) {
return this;
}

public NonBlockingStatsDClientBuilder enableJdkSocket(boolean val) {
enableJdkSocket = val;
return this;
}

/**
* Request that all metrics from this client to be enriched to specified tag cardinality.
*
Expand Down Expand Up @@ -523,8 +532,30 @@ protected static Callable<SocketAddress> staticUnixResolution(
return new Callable<SocketAddress>() {
@Override
public SocketAddress call() {
final UnixSocketAddress socketAddress = new UnixSocketAddress(path);
return new UnixSocketAddressWithTransport(socketAddress, transportType);
SocketAddress socketAddress;

// Use native JDK support for UDS on Java 16+ and jnr-unixsocket otherwise
if (VersionUtils.isJavaVersionAtLeast(16)
&& NonBlockingStatsDClient.DEFAULT_ENABLE_JDK_SOCKET) {
try {
// Avoid compiling Java 16+ classes in incompatible versions
Class<?> unixDomainSocketAddressClass =
Class.forName("java.net.UnixDomainSocketAddress");
Method ofMethod =
unixDomainSocketAddressClass.getMethod("of", String.class);
socketAddress = (SocketAddress) ofMethod.invoke(null, path);
} catch (Exception e) {
throw new StatsDClientException(
"Failed to create UnixSocketAddress for native JDK UDS implementation",
e);
}
} else {
socketAddress = new UnixSocketAddress(path);
}
UnixSocketAddressWithTransport result =
new UnixSocketAddressWithTransport(socketAddress, transportType);

return result;
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ class UnixDatagramClientChannel extends DatagramClientChannel {
*/
UnixDatagramClientChannel(SocketAddress address, int timeout, int bufferSize)
throws IOException {
// Ideally we could use native JDK UDS support such as with the UnixStreamClientChannel.
// However, DatagramChannels do not support StandardProtocolFamily.UNIX, so this is
// unavailable.
// See this open issue for updates: https://bugs.openjdk.org/browse/JDK-8297837?
super(UnixDatagramChannel.open(), address);
// Set send timeout, to handle the case where the transmission buffer is full
// If no timeout is set, the send becomes blocking
Expand Down
143 changes: 121 additions & 22 deletions src/main/java/com/timgroup/statsd/UnixStreamClientChannel.java
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
package com.timgroup.statsd;

import java.io.IOException;
import java.lang.reflect.Method;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import jnr.unixsocket.UnixSocketAddress;
import jnr.unixsocket.UnixSocketChannel;
import jnr.unixsocket.UnixSocketOptions;

/** A ClientChannel for Unix domain sockets. */
public class UnixStreamClientChannel implements ClientChannel {
private final UnixSocketAddress address;
private final SocketAddress address;
private final int timeout;
private final int connectionTimeout;
private final int bufferSize;
private final boolean enableJdkSocket;

private SocketChannel delegate;
private final ByteBuffer delimiterBuffer =
Expand All @@ -26,13 +30,18 @@ public class UnixStreamClientChannel implements ClientChannel {
* @param address Location of named pipe
*/
UnixStreamClientChannel(
SocketAddress address, int timeout, int connectionTimeout, int bufferSize)
SocketAddress address,
int timeout,
int connectionTimeout,
int bufferSize,
boolean enableJdkSocket)
throws IOException {
this.delegate = null;
this.address = (UnixSocketAddress) address;
this.address = address;
this.timeout = timeout;
this.connectionTimeout = connectionTimeout;
this.bufferSize = bufferSize;
this.enableJdkSocket = enableJdkSocket;
}

@Override
Expand Down Expand Up @@ -87,19 +96,37 @@ public int writeAll(ByteBuffer bb, boolean canReturnOnTimeout, long deadline)
throws IOException {
int remaining = bb.remaining();
int written = 0;
long timeoutMs = timeout;

while (remaining > 0) {
int read = delegate.write(bb);

// If we haven't written anything yet, we can still return
if (read == 0 && canReturnOnTimeout && written == 0) {
return written;
if (read > 0) {
remaining -= read;
written += read;
continue;
}

remaining -= read;
written += read;
if (read == 0) {
if (canReturnOnTimeout && written == 0) {
return written;
}

try (Selector selector = Selector.open()) {
SelectionKey key = delegate.register(selector, SelectionKey.OP_WRITE);
long selectTimeout = timeoutMs;

if (deadline > 0) {
long remainingNs = deadline - System.nanoTime();
if (remainingNs <= 0) {
throw new IOException("Write timed out");
}
selectTimeout = Math.min(timeoutMs, remainingNs / 1_000_000L);
}

if (deadline > 0 && System.nanoTime() > deadline) {
throw new IOException("Write timed out");
if (selector.select(selectTimeout) == 0) {
throw new IOException("Write timed out after " + selectTimeout + "ms");
}
}
}
}
return written;
Expand Down Expand Up @@ -127,40 +154,112 @@ private void connect() throws IOException {
}
}

UnixSocketChannel delegate = UnixSocketChannel.create();

long deadline = System.nanoTime() + connectionTimeout * 1_000_000L;
// Use native JDK support for UDS on Java 16+ and jnr-unixsocket otherwise
if (VersionUtils.isJavaVersionAtLeast(16) && enableJdkSocket) {
try {
// Avoid compiling Java 16+ classes in incompatible versions
Class<?> protocolFamilyClass = Class.forName("java.net.ProtocolFamily");
Class<?> standardProtocolFamilyClass =
Class.forName("java.net.StandardProtocolFamily");
Object unixProtocol =
Enum.valueOf((Class<Enum>) standardProtocolFamilyClass, "UNIX");
Method openMethod = SocketChannel.class.getMethod("open", protocolFamilyClass);
SocketChannel channel = (SocketChannel) openMethod.invoke(null, unixProtocol);

channel.configureBlocking(false);

try {
SocketAddress connectAddress = address;
if (address instanceof UnixSocketAddressWithTransport) {
connectAddress = ((UnixSocketAddressWithTransport) address).getAddress();
}

Method connectMethod =
SocketChannel.class.getMethod("connect", SocketAddress.class);
boolean connected = (boolean) connectMethod.invoke(channel, connectAddress);

if (!connected) {
try (Selector selector = Selector.open()) {
SelectionKey key = channel.register(selector, SelectionKey.OP_CONNECT);
int timeoutMs = connectionTimeout > 0 ? connectionTimeout : 1000;
int ready = selector.select(timeoutMs);

if (ready == 0) {
throw new IOException(
"Connection timed out after " + timeoutMs + "ms");
}

if (key.isConnectable()) {
connected = channel.finishConnect();
if (!connected) {
throw new IOException("Failed to complete connection");
}
}
}
}
} catch (Exception e) {
try {
channel.close();
} catch (IOException __) {
// ignore
}
throw e;
}

this.delegate = channel;
return;
} catch (Exception e) {
Throwable cause = e.getCause();
if (e instanceof java.lang.reflect.InvocationTargetException
&& cause instanceof IOException) {
throw (IOException) cause;
}
throw new IOException(
"Failed to create UnixStreamClientChannel for native UDS implementation",
e);
}
}
// Default to jnr-unixsocket if Java version is < 16 or native support is disabled
UnixSocketChannel channel = UnixSocketChannel.create();

if (connectionTimeout > 0) {
// Set connect timeout, this should work at least on linux
// https://elixir.bootlin.com/linux/v5.7.4/source/net/unix/af_unix.c#L1696
// We'd have better timeout support if we used Java 16's native Unix domain socket
// support (JEP 380)
delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, connectionTimeout);
channel.setOption(UnixSocketOptions.SO_SNDTIMEO, connectionTimeout);
}

try {
if (!delegate.connect(address)) {
UnixSocketAddress unixAddress;
if (address instanceof UnixSocketAddress) {
unixAddress = (UnixSocketAddress) address;
} else {
unixAddress = new UnixSocketAddress(address.toString());
}

if (!channel.connect(unixAddress)) {
if (connectionTimeout > 0 && System.nanoTime() > deadline) {
throw new IOException("Connection timed out");
}
if (!delegate.finishConnect()) {
if (!channel.finishConnect()) {
throw new IOException("Connection failed");
}
}

delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, Math.max(timeout, 0));
channel.setOption(UnixSocketOptions.SO_SNDTIMEO, Math.max(timeout, 0));
if (bufferSize > 0) {
delegate.setOption(UnixSocketOptions.SO_SNDBUF, bufferSize);
channel.setOption(UnixSocketOptions.SO_SNDBUF, bufferSize);
}
} catch (Exception e) {
try {
delegate.close();
channel.close();
} catch (IOException __) {
// ignore
}
throw e;
}

this.delegate = delegate;
this.delegate = channel;
}

@Override
Expand Down
Loading