Skip to content

Commit 373c67d

Browse files
authored
Add DirectByteBuffer strategy for transport-nio (#36289)
This is related to #27260. In Elasticsearch all of the messages that we serialize to write to the network are composed of heap bytes. When you read or write to a nio socket in java, the heap memory you passed down must be copied to/from direct memory. The JVM internally does some buffering of the direct memory, however it is essentially unbounded. This commit introduces a simple mechanism of buffering and copying the memory in transport-nio. Each network event loop is given a 64kb DirectByteBuffer. When we go to read we use this buffer and copy the data after the read. Additionally, when we go to write, we copy the data to the direct memory before calling write. 64KB is chosen as this is the default receive buffer size we use for transport-netty4 (NETTY_RECEIVE_PREDICTOR_SIZE). Since we only have one buffer per thread, we could afford larger. However, if we the buffer is large and not all of the data is flushed in a write call, we will do excess copies. This is something we can explore in the future.
1 parent fc85c37 commit 373c67d

File tree

6 files changed

+321
-69
lines changed

6 files changed

+321
-69
lines changed

libs/nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,12 @@ public BytesChannelContext(NioSocketChannel channel, NioSelector selector, Consu
3838

3939
@Override
4040
public int read() throws IOException {
41-
if (channelBuffer.getRemaining() == 0) {
42-
// Requiring one additional byte will ensure that a new page is allocated.
43-
channelBuffer.ensureCapacity(channelBuffer.getCapacity() + 1);
44-
}
45-
46-
int bytesRead = readFromChannel(channelBuffer.sliceBuffersFrom(channelBuffer.getIndex()));
41+
int bytesRead = readFromChannel(channelBuffer);
4742

4843
if (bytesRead == 0) {
4944
return 0;
5045
}
5146

52-
channelBuffer.incrementIndex(bytesRead);
53-
5447
handleReadBytes();
5548

5649
return bytesRead;
@@ -91,8 +84,7 @@ public boolean selectorShouldClose() {
9184
* Returns a boolean indicating if the operation was fully flushed.
9285
*/
9386
private boolean singleFlush(FlushOperation flushOperation) throws IOException {
94-
int written = flushToChannel(flushOperation.getBuffersToWrite());
95-
flushOperation.incrementIndex(written);
87+
flushToChannel(flushOperation);
9688
return flushOperation.isFullyFlushed();
9789
}
9890
}

libs/nio/src/main/java/org/elasticsearch/nio/NioSelector.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import java.io.Closeable;
2323
import java.io.IOException;
24+
import java.nio.ByteBuffer;
2425
import java.nio.channels.CancelledKeyException;
2526
import java.nio.channels.ClosedChannelException;
2627
import java.nio.channels.ClosedSelectorException;
@@ -51,6 +52,7 @@ public class NioSelector implements Closeable {
5152
private final ConcurrentLinkedQueue<ChannelContext<?>> channelsToRegister = new ConcurrentLinkedQueue<>();
5253
private final EventHandler eventHandler;
5354
private final Selector selector;
55+
private final ByteBuffer ioBuffer;
5456

5557
private final ReentrantLock runLock = new ReentrantLock();
5658
private final CountDownLatch exitedLoop = new CountDownLatch(1);
@@ -65,6 +67,18 @@ public NioSelector(EventHandler eventHandler) throws IOException {
6567
public NioSelector(EventHandler eventHandler, Selector selector) {
6668
this.selector = selector;
6769
this.eventHandler = eventHandler;
70+
this.ioBuffer = ByteBuffer.allocateDirect(1 << 16);
71+
}
72+
73+
/**
74+
* Returns a cached direct byte buffer for network operations. It is cleared on every get call.
75+
*
76+
* @return the byte buffer
77+
*/
78+
public ByteBuffer getIoBuffer() {
79+
assertOnSelectorThread();
80+
ioBuffer.clear();
81+
return ioBuffer;
6882
}
6983

7084
public Selector rawSelector() {

libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java

Lines changed: 85 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
*/
4545
public abstract class SocketChannelContext extends ChannelContext<SocketChannel> {
4646

47-
public static final Predicate<NioSocketChannel> ALWAYS_ALLOW_CHANNEL = (c) -> true;
47+
protected static final Predicate<NioSocketChannel> ALWAYS_ALLOW_CHANNEL = (c) -> true;
4848

4949
protected final NioSocketChannel channel;
5050
protected final InboundChannelBuffer channelBuffer;
@@ -234,49 +234,113 @@ protected boolean closeNow() {
234234
return closeNow;
235235
}
236236

237+
238+
// When you read or write to a nio socket in java, the heap memory passed down must be copied to/from
239+
// direct memory. The JVM internally does some buffering of the direct memory, however we can save space
240+
// by reusing a thread-local direct buffer (provided by the NioSelector).
241+
//
242+
// Each network event loop is given a 64kb DirectByteBuffer. When we read we use this buffer and copy the
243+
// data after the read. When we go to write, we copy the data to the direct memory before calling write.
244+
// The choice of 64KB is rather arbitrary. We can explore different sizes in the future. However, any
245+
// data that is copied to the buffer for a write, but not successfully flushed immediately, must be
246+
// copied again on the next call.
247+
237248
protected int readFromChannel(ByteBuffer buffer) throws IOException {
249+
ByteBuffer ioBuffer = getSelector().getIoBuffer();
250+
ioBuffer.limit(Math.min(buffer.remaining(), ioBuffer.limit()));
251+
int bytesRead;
238252
try {
239-
int bytesRead = rawChannel.read(buffer);
240-
if (bytesRead < 0) {
241-
closeNow = true;
242-
bytesRead = 0;
243-
}
244-
return bytesRead;
253+
bytesRead = rawChannel.read(ioBuffer);
245254
} catch (IOException e) {
246255
closeNow = true;
247256
throw e;
248257
}
258+
if (bytesRead < 0) {
259+
closeNow = true;
260+
return 0;
261+
} else {
262+
ioBuffer.flip();
263+
buffer.put(ioBuffer);
264+
return bytesRead;
265+
}
249266
}
250267

251-
protected int readFromChannel(ByteBuffer[] buffers) throws IOException {
268+
protected int readFromChannel(InboundChannelBuffer channelBuffer) throws IOException {
269+
ByteBuffer ioBuffer = getSelector().getIoBuffer();
270+
int bytesRead;
252271
try {
253-
int bytesRead = (int) rawChannel.read(buffers);
254-
if (bytesRead < 0) {
255-
closeNow = true;
256-
bytesRead = 0;
257-
}
258-
return bytesRead;
272+
bytesRead = rawChannel.read(ioBuffer);
259273
} catch (IOException e) {
260274
closeNow = true;
261275
throw e;
262276
}
277+
if (bytesRead < 0) {
278+
closeNow = true;
279+
return 0;
280+
} else {
281+
ioBuffer.flip();
282+
channelBuffer.ensureCapacity(channelBuffer.getIndex() + ioBuffer.remaining());
283+
ByteBuffer[] buffers = channelBuffer.sliceBuffersFrom(channelBuffer.getIndex());
284+
int j = 0;
285+
while (j < buffers.length && ioBuffer.remaining() > 0) {
286+
ByteBuffer buffer = buffers[j++];
287+
copyBytes(ioBuffer, buffer);
288+
}
289+
channelBuffer.incrementIndex(bytesRead);
290+
return bytesRead;
291+
}
263292
}
264293

265294
protected int flushToChannel(ByteBuffer buffer) throws IOException {
295+
int initialPosition = buffer.position();
296+
ByteBuffer ioBuffer = getSelector().getIoBuffer();
297+
copyBytes(buffer, ioBuffer);
298+
ioBuffer.flip();
299+
int bytesWritten;
266300
try {
267-
return rawChannel.write(buffer);
301+
bytesWritten = rawChannel.write(ioBuffer);
268302
} catch (IOException e) {
269303
closeNow = true;
304+
buffer.position(initialPosition);
270305
throw e;
271306
}
307+
buffer.position(initialPosition + bytesWritten);
308+
return bytesWritten;
272309
}
273310

274-
protected int flushToChannel(ByteBuffer[] buffers) throws IOException {
275-
try {
276-
return (int) rawChannel.write(buffers);
277-
} catch (IOException e) {
278-
closeNow = true;
279-
throw e;
311+
protected int flushToChannel(FlushOperation flushOperation) throws IOException {
312+
ByteBuffer ioBuffer = getSelector().getIoBuffer();
313+
314+
boolean continueFlush = flushOperation.isFullyFlushed() == false;
315+
int totalBytesFlushed = 0;
316+
while (continueFlush) {
317+
ioBuffer.clear();
318+
int j = 0;
319+
ByteBuffer[] buffers = flushOperation.getBuffersToWrite();
320+
while (j < buffers.length && ioBuffer.remaining() > 0) {
321+
ByteBuffer buffer = buffers[j++];
322+
copyBytes(buffer, ioBuffer);
323+
}
324+
ioBuffer.flip();
325+
int bytesFlushed;
326+
try {
327+
bytesFlushed = rawChannel.write(ioBuffer);
328+
} catch (IOException e) {
329+
closeNow = true;
330+
throw e;
331+
}
332+
flushOperation.incrementIndex(bytesFlushed);
333+
totalBytesFlushed += bytesFlushed;
334+
continueFlush = ioBuffer.hasRemaining() == false && flushOperation.isFullyFlushed() == false;
280335
}
336+
return totalBytesFlushed;
337+
}
338+
339+
private void copyBytes(ByteBuffer from, ByteBuffer to) {
340+
int nBytesToCopy = Math.min(to.remaining(), from.remaining());
341+
int initialLimit = from.limit();
342+
from.limit(from.position() + nBytesToCopy);
343+
to.put(from);
344+
from.limit(initialLimit);
281345
}
282346
}

libs/nio/src/test/java/org/elasticsearch/nio/BytesChannelContextTests.java

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import java.util.function.Consumer;
3232

3333
import static org.mockito.Matchers.any;
34-
import static org.mockito.Matchers.anyInt;
34+
import static org.mockito.Matchers.eq;
3535
import static org.mockito.Mockito.mock;
3636
import static org.mockito.Mockito.times;
3737
import static org.mockito.Mockito.verify;
@@ -64,14 +64,19 @@ public void init() {
6464
context = new BytesChannelContext(channel, selector, mock(Consumer.class), handler, channelBuffer);
6565

6666
when(selector.isOnCurrentThread()).thenReturn(true);
67+
ByteBuffer buffer = ByteBuffer.allocate(1 << 14);
68+
when(selector.getIoBuffer()).thenAnswer(invocationOnMock -> {
69+
buffer.clear();
70+
return buffer;
71+
});
6772
}
6873

6974
public void testSuccessfulRead() throws IOException {
7075
byte[] bytes = createMessage(messageLength);
7176

72-
when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenAnswer(invocationOnMock -> {
73-
ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
74-
buffers[0].put(bytes);
77+
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
78+
ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0];
79+
buffer.put(bytes);
7580
return bytes.length;
7681
});
7782

@@ -87,9 +92,9 @@ public void testSuccessfulRead() throws IOException {
8792
public void testMultipleReadsConsumed() throws IOException {
8893
byte[] bytes = createMessage(messageLength * 2);
8994

90-
when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenAnswer(invocationOnMock -> {
91-
ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
92-
buffers[0].put(bytes);
95+
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
96+
ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0];
97+
buffer.put(bytes);
9398
return bytes.length;
9499
});
95100

@@ -105,9 +110,9 @@ public void testMultipleReadsConsumed() throws IOException {
105110
public void testPartialRead() throws IOException {
106111
byte[] bytes = createMessage(messageLength);
107112

108-
when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenAnswer(invocationOnMock -> {
109-
ByteBuffer[] buffers = (ByteBuffer[]) invocationOnMock.getArguments()[0];
110-
buffers[0].put(bytes);
113+
when(rawChannel.read(any(ByteBuffer.class))).thenAnswer(invocationOnMock -> {
114+
ByteBuffer buffer = (ByteBuffer) invocationOnMock.getArguments()[0];
115+
buffer.put(bytes);
111116
return bytes.length;
112117
});
113118

@@ -130,22 +135,22 @@ public void testPartialRead() throws IOException {
130135

131136
public void testReadThrowsIOException() throws IOException {
132137
IOException ioException = new IOException();
133-
when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenThrow(ioException);
138+
when(rawChannel.read(any(ByteBuffer.class))).thenThrow(ioException);
134139

135140
IOException ex = expectThrows(IOException.class, () -> context.read());
136141
assertSame(ioException, ex);
137142
}
138143

139144
public void testReadThrowsIOExceptionMeansReadyForClose() throws IOException {
140-
when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenThrow(new IOException());
145+
when(rawChannel.read(any(ByteBuffer.class))).thenThrow(new IOException());
141146

142147
assertFalse(context.selectorShouldClose());
143148
expectThrows(IOException.class, () -> context.read());
144149
assertTrue(context.selectorShouldClose());
145150
}
146151

147152
public void testReadLessThanZeroMeansReadyForClose() throws IOException {
148-
when(rawChannel.read(any(ByteBuffer[].class), anyInt(), anyInt())).thenReturn(-1L);
153+
when(rawChannel.read(any(ByteBuffer.class))).thenReturn(-1);
149154

150155
assertEquals(0, context.read());
151156

@@ -164,11 +169,13 @@ public void testQueuedWriteIsFlushedInFlushCall() throws Exception {
164169
assertTrue(context.readyForFlush());
165170

166171
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
167-
when(flushOperation.isFullyFlushed()).thenReturn(true);
172+
when(flushOperation.isFullyFlushed()).thenReturn(false, true);
168173
when(flushOperation.getListener()).thenReturn(listener);
169174
context.flushChannel();
170175

171-
verify(rawChannel).write(buffers, 0, buffers.length);
176+
ByteBuffer buffer = buffers[0].duplicate();
177+
buffer.flip();
178+
verify(rawChannel).write(eq(buffer));
172179
verify(selector).executeListener(listener, null);
173180
assertFalse(context.readyForFlush());
174181
}
@@ -180,7 +187,7 @@ public void testPartialFlush() throws IOException {
180187
assertTrue(context.readyForFlush());
181188

182189
when(flushOperation.isFullyFlushed()).thenReturn(false);
183-
when(flushOperation.getBuffersToWrite()).thenReturn(new ByteBuffer[0]);
190+
when(flushOperation.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
184191
context.flushChannel();
185192

186193
verify(listener, times(0)).accept(null, null);
@@ -194,8 +201,8 @@ public void testMultipleWritesPartialFlushes() throws IOException {
194201
BiConsumer<Void, Exception> listener2 = mock(BiConsumer.class);
195202
FlushReadyWrite flushOperation1 = mock(FlushReadyWrite.class);
196203
FlushReadyWrite flushOperation2 = mock(FlushReadyWrite.class);
197-
when(flushOperation1.getBuffersToWrite()).thenReturn(new ByteBuffer[0]);
198-
when(flushOperation2.getBuffersToWrite()).thenReturn(new ByteBuffer[0]);
204+
when(flushOperation1.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
205+
when(flushOperation2.getBuffersToWrite()).thenReturn(new ByteBuffer[] {ByteBuffer.allocate(3)});
199206
when(flushOperation1.getListener()).thenReturn(listener);
200207
when(flushOperation2.getListener()).thenReturn(listener2);
201208

@@ -204,15 +211,15 @@ public void testMultipleWritesPartialFlushes() throws IOException {
204211

205212
assertTrue(context.readyForFlush());
206213

207-
when(flushOperation1.isFullyFlushed()).thenReturn(true);
214+
when(flushOperation1.isFullyFlushed()).thenReturn(false, true);
208215
when(flushOperation2.isFullyFlushed()).thenReturn(false);
209216
context.flushChannel();
210217

211218
verify(selector).executeListener(listener, null);
212219
verify(listener2, times(0)).accept(null, null);
213220
assertTrue(context.readyForFlush());
214221

215-
when(flushOperation2.isFullyFlushed()).thenReturn(true);
222+
when(flushOperation2.isFullyFlushed()).thenReturn(false, true);
216223

217224
context.flushChannel();
218225

@@ -231,7 +238,7 @@ public void testWhenIOExceptionThrownListenerIsCalled() throws IOException {
231238

232239
IOException exception = new IOException();
233240
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
234-
when(rawChannel.write(buffers, 0, buffers.length)).thenThrow(exception);
241+
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception);
235242
when(flushOperation.getListener()).thenReturn(listener);
236243
expectThrows(IOException.class, () -> context.flushChannel());
237244

@@ -246,7 +253,7 @@ public void testWriteIOExceptionMeansChannelReadyToClose() throws IOException {
246253

247254
IOException exception = new IOException();
248255
when(flushOperation.getBuffersToWrite()).thenReturn(buffers);
249-
when(rawChannel.write(buffers, 0, buffers.length)).thenThrow(exception);
256+
when(rawChannel.write(any(ByteBuffer.class))).thenThrow(exception);
250257

251258
assertFalse(context.selectorShouldClose());
252259
expectThrows(IOException.class, () -> context.flushChannel());

0 commit comments

Comments
 (0)