Skip to content

Commit beb1454

Browse files
committed
[SPARK-25115] [Core] Eliminate extra memory copy done when a ByteBuf is used that is backed by > 1 ByteBuffer.
Check how many ByteBuffer are used and depending on it do either call nioBuffer(...) or nioBuffers(...) to eliminate extra memory copies. This is related to netty/netty#8176. Unit tests added.
1 parent 80784a1 commit beb1454

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,24 @@ private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOExcept
140140
// SPARK-24578: cap the sub-region's size of returned nio buffer to improve the performance
141141
// for the case that the passed-in buffer has too many components.
142142
int length = Math.min(buf.readableBytes(), NIO_BUFFER_LIMIT);
143-
ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length);
144-
int written = target.write(buffer);
143+
// If the ByteBuf holds more then one ByteBuffer we should better call nioBuffers(...)
144+
// to eliminate extra memory copies.
145+
int written = 0;
146+
if (buf.nioBufferCount() == 1) {
147+
ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length);
148+
written = target.write(buffer);
149+
} else {
150+
ByteBuffer[] buffers = buf.nioBuffers(buf.readerIndex(), length);
151+
for (ByteBuffer buffer: buffers) {
152+
int remaining = buffer.remaining();
153+
int w = target.write(buffer);
154+
written += w;
155+
if (w < remaining) {
156+
// Could not write all, we need to break now.
157+
break;
158+
}
159+
}
160+
}
145161
buf.skipBytes(written);
146162
return written;
147163
}

common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.nio.channels.WritableByteChannel;
2323

2424
import io.netty.buffer.ByteBuf;
25+
import io.netty.buffer.CompositeByteBuf;
2526
import io.netty.buffer.Unpooled;
2627
import org.apache.spark.network.util.AbstractFileRegion;
2728
import org.junit.Test;
@@ -48,7 +49,36 @@ public void testShortWrite() throws Exception {
4849

4950
@Test
5051
public void testByteBufBody() throws Exception {
52+
testByteBufBody(Unpooled.copyLong(42));
53+
}
54+
55+
@Test
56+
public void testCompositeByteBufBodySingleBuffer() throws Exception {
57+
ByteBuf header = Unpooled.copyLong(42);
58+
CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
59+
compositeByteBuf.addComponent(true, header);
60+
assertEquals(1, compositeByteBuf.nioBufferCount());
61+
testByteBufBody(compositeByteBuf);
62+
}
63+
64+
@Test
65+
public void testCompositeByteBufBodyMultipleBuffers() throws Exception {
5166
ByteBuf header = Unpooled.copyLong(42);
67+
CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer();
68+
compositeByteBuf.addComponent(true, header.retainedSlice(0, 4));
69+
compositeByteBuf.addComponent(true, header.slice(4, 4));
70+
assertEquals(2, compositeByteBuf.nioBufferCount());
71+
testByteBufBody(compositeByteBuf);
72+
}
73+
74+
/**
75+
* Test writing a {@link MessageWithHeader} using the given {@link ByteBuf} as header.
76+
*
77+
* @param header the header to use.
78+
* @throws Exception thrown on error.
79+
*/
80+
private void testByteBufBody(ByteBuf header) throws Exception {
81+
long expectedHeaderValue = header.getLong(header.readerIndex());
5282
ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84);
5383
assertEquals(1, header.refCnt());
5484
assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt());
@@ -61,7 +91,7 @@ public void testByteBufBody() throws Exception {
6191
MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size());
6292
ByteBuf result = doWrite(msg, 1);
6393
assertEquals(msg.count(), result.readableBytes());
64-
assertEquals(42, result.readLong());
94+
assertEquals(expectedHeaderValue, result.readLong());
6595
assertEquals(84, result.readLong());
6696

6797
assertTrue(msg.release());

0 commit comments

Comments
 (0)