Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,27 @@ public static Callable<SocketAddress> volatileAddressResolution(final String hos
if (port == 0) {
return new Callable<SocketAddress>() {
@Override public SocketAddress call() throws UnknownHostException {
return new UnixSocketAddressWithTransport(
new UnixSocketAddress(hostname),
UnixSocketAddressWithTransport.TransportType.UDS
);
if (VersionUtils.hasNativeUdsSupport()) {
try {
Class<?> udsAddressClass = Class.forName("java.net.UnixDomainSocketAddress");
SocketAddress udsAddress = (SocketAddress)
udsAddressClass.getMethod("of", String.class).invoke(null, hostname);
System.out.println("================UnixSocketAddressWithTransport returned with: " + udsAddress);
return new UnixSocketAddressWithTransport(
udsAddress,
UnixSocketAddressWithTransport.TransportType.UDS
);
} catch (Exception e) {
throw new UnknownHostException("Failed to create UnixDomainSocketAddress: " + e.getMessage());
}
} else {
SocketAddress jnrAddress = new UnixSocketAddress(hostname);
System.out.println("================UnixSocketAddressWithTransport returned with: " + jnrAddress);
return new UnixSocketAddressWithTransport(
jnrAddress,
UnixSocketAddressWithTransport.TransportType.UDS
);
}
}
};
} else {
Expand Down Expand Up @@ -374,12 +391,29 @@ protected static Callable<SocketAddress> staticNamedPipeResolution(String namedP
protected static Callable<SocketAddress> staticUnixResolution(
final String path,
final UnixSocketAddressWithTransport.TransportType transportType) {
return new Callable<SocketAddress>() {
@Override public SocketAddress call() {
final UnixSocketAddress socketAddress = new UnixSocketAddress(path);
return new UnixSocketAddressWithTransport(socketAddress, transportType);
if (VersionUtils.hasNativeUdsSupport()) {
try {
Class<?> udsAddressClass = Class.forName("java.net.UnixDomainSocketAddress");
final SocketAddress udsAddress = (SocketAddress) udsAddressClass.getMethod("of", String.class).invoke(null, path);
System.out.println("================new Callable<SocketAddress> returned with udsAddress: " + udsAddress);
return new Callable<SocketAddress>() {
@Override public SocketAddress call() {
System.out.println("================UnixSocketAddressWithTransport returned with: " + udsAddress);
return new UnixSocketAddressWithTransport(udsAddress, transportType);
}
};
} catch (Exception e) {
throw new RuntimeException("Failed to create UnixDomainSocketAddress: " + e.getMessage(), e);
}
};
} else {
return new Callable<SocketAddress>() {
@Override public SocketAddress call() {
final UnixSocketAddress jnrAddress = new UnixSocketAddress(path);
System.out.println("================UnixSocketAddressWithTransport returned with: " + jnrAddress);
return new UnixSocketAddressWithTransport(jnrAddress, transportType);
}
};
}
}

private static Callable<SocketAddress> staticAddress(final String hostname, final int port) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import java.io.IOException;
import java.net.SocketAddress;
import java.nio.channels.DatagramChannel;

class UnixDatagramClientChannel extends DatagramClientChannel {
/**
Expand All @@ -16,14 +17,33 @@ class UnixDatagramClientChannel extends DatagramClientChannel {
* @throws IOException if socket options cannot be set
*/
UnixDatagramClientChannel(SocketAddress address, int timeout, int bufferSize) throws IOException {
super(UnixDatagramChannel.open(), address);
// Set send timeout, to handle the case where the transmission buffer is full
// If no timeout is set, the send becomes blocking
if (timeout > 0) {
delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, timeout);
super(createChannel(address), address);
configureChannel(timeout, bufferSize);
}

private static DatagramChannel createChannel(SocketAddress address) throws IOException {
if (VersionUtils.hasNativeUdsSupport()) {
return DatagramChannel.open();
} else {
return UnixDatagramChannel.open();
}
if (bufferSize > 0) {
delegate.setOption(UnixSocketOptions.SO_SNDBUF, bufferSize);
}

private void configureChannel(int timeout, int bufferSize) throws IOException {
if (VersionUtils.hasNativeUdsSupport()) {
if (timeout > 0) {
delegate.socket().setSoTimeout(timeout);
}
if (bufferSize > 0) {
delegate.socket().setSendBufferSize(bufferSize);
}
} else {
if (timeout > 0) {
delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, timeout);
}
if (bufferSize > 0) {
delegate.setOption(UnixSocketOptions.SO_SNDBUF, bufferSize);
}
}
}

Expand Down
126 changes: 88 additions & 38 deletions src/main/java/com/timgroup/statsd/UnixStreamClientChannel.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jnr.unixsocket.UnixSocketOptions;

import java.io.IOException;
import java.lang.reflect.Method;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
Expand All @@ -14,12 +15,11 @@
* A ClientChannel for Unix domain sockets.
*/
public class UnixStreamClientChannel implements ClientChannel {
private final UnixSocketAddress address;
private final SocketAddress address;
private final int timeout;
private final int connectionTimeout;
private final int bufferSize;


private SocketChannel delegate;
private final ByteBuffer delimiterBuffer = ByteBuffer.allocateDirect(Integer.SIZE / Byte.SIZE).order(ByteOrder.LITTLE_ENDIAN);

Expand All @@ -30,10 +30,11 @@ public class UnixStreamClientChannel implements ClientChannel {
*/
UnixStreamClientChannel(SocketAddress address, int timeout, int connectionTimeout, int bufferSize) throws IOException {
this.delegate = null;
this.address = (UnixSocketAddress) address;
this.address = address;
this.timeout = timeout;
this.connectionTimeout = connectionTimeout;
this.bufferSize = bufferSize;
System.out.println("================Created UnixStreamClientChannel with address: " + address);
}

@Override
Expand Down Expand Up @@ -74,35 +75,6 @@ public synchronized int write(ByteBuffer src) throws IOException {
return size;
}

/**
* Writes all bytes from the given buffer to the channel.
* @param bb buffer to write
* @param canReturnOnTimeout if true, we return if the channel is blocking and we haven't written anything yet
* @param deadline deadline for the write
* @return number of bytes written
* @throws IOException if the channel is closed or an error occurs
*/
public int writeAll(ByteBuffer bb, boolean canReturnOnTimeout, long deadline) throws IOException {
int remaining = bb.remaining();
int written = 0;
while (remaining > 0) {
int read = delegate.write(bb);

// If we haven't written anything yet, we can still return
if (read == 0 && canReturnOnTimeout && written == 0) {
return written;
}

remaining -= read;
written += read;

if (deadline > 0 && System.nanoTime() > deadline) {
throw new IOException("Write timed out");
}
}
return written;
}

private void connectIfNeeded() throws IOException {
if (delegate == null) {
connect();
Expand All @@ -125,29 +97,81 @@ private void connect() throws IOException {
}
}

UnixSocketChannel delegate = UnixSocketChannel.create();

long deadline = System.nanoTime() + connectionTimeout * 1_000_000L;
// Use native JDK Unix domain socket support for compatible versions (Java 16+). Fall back to JNR support otherwise.
if (VersionUtils.hasNativeUdsSupport()) {
connectJdkSocket(deadline);
} else {
connectJnrSocket(deadline);
}
}

private void connectJdkSocket(long deadline) throws IOException {
String socketPath = address.toString();
try {
Class<?> udsAddressClass = Class.forName("java.net.UnixDomainSocketAddress");
Object udsAddress = udsAddressClass.getMethod("of", String.class).invoke(null, socketPath);

SocketChannel delegate = SocketChannel.open();
if (connectionTimeout > 0) {
delegate.socket().setSoTimeout(connectionTimeout);
}

delegate.configureBlocking(false);
System.out.println("================Attempting to connect delegate to: " + udsAddress);
if (!delegate.connect((SocketAddress) udsAddress)) {
System.out.println("================Initial connect returned false, checking deadline");
if (connectionTimeout > 0 && System.nanoTime() > deadline) {
throw new IOException("Connection timed out");
}
System.out.println("================Finishing connection");
if (!delegate.finishConnect()) {
throw new IOException("Connection failed");
}
}
System.out.println("================Connection successful");
delegate.configureBlocking(true);
delegate.socket().setSoTimeout(Math.max(timeout, 0));
if (bufferSize > 0) {
delegate.socket().setSendBufferSize(bufferSize);
}
this.delegate = delegate;
System.out.println("================Set up complete.");
} catch (Exception e) {
System.out.println("================Failed to connect to UDS at: " + socketPath);
try {
delegate.close();
} catch (IOException __) {
// ignore
}
throw new IOException("Failed to connect to Unix Domain Socket: " + socketPath, e);
}
}

private void connectJnrSocket(long deadline) throws IOException {
UnixSocketChannel delegate = UnixSocketChannel.create();
if (connectionTimeout > 0) {
// Set connect timeout, this should work at least on linux
// https://elixir.bootlin.com/linux/v5.7.4/source/net/unix/af_unix.c#L1696
// We'd have better timeout support if we used Java 16's native Unix domain socket support (JEP 380)
delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, connectionTimeout);
}
try {
if (!delegate.connect(address)) {
System.out.println("================Attempting to connect delegate to: " + address);
if (!delegate.connect((UnixSocketAddress) address)) {
if (connectionTimeout > 0 && System.nanoTime() > deadline) {
throw new IOException("Connection timed out");
}
if (!delegate.finishConnect()) {
throw new IOException("Connection failed");
}
}

System.out.println("================Connection successful");
delegate.setOption(UnixSocketOptions.SO_SNDTIMEO, Math.max(timeout, 0));
if (bufferSize > 0) {
delegate.setOption(UnixSocketOptions.SO_SNDBUF, bufferSize);
}
this.delegate = delegate;
System.out.println("================Set up complete.");
} catch (Exception e) {
try {
delegate.close();
Expand All @@ -156,9 +180,35 @@ private void connect() throws IOException {
}
throw e;
}
}

/**
* Writes all bytes from the given buffer to the channel.
* @param bb buffer to write
* @param canReturnOnTimeout if true, we return if the channel is blocking and we haven't written anything yet
* @param deadline deadline for the write
* @return number of bytes written
* @throws IOException if the channel is closed or an error occurs
*/
public int writeAll(ByteBuffer bb, boolean canReturnOnTimeout, long deadline) throws IOException {
int remaining = bb.remaining();
int written = 0;
while (remaining > 0) {
int read = delegate.write(bb);

// If we haven't written anything yet, we can still return
if (read == 0 && canReturnOnTimeout && written == 0) {
return written;
}

remaining -= read;
written += read;

this.delegate = delegate;
if (deadline > 0 && System.nanoTime() > deadline) {
throw new IOException("Write timed out");
}
}
return written;
}

@Override
Expand Down
Loading