diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index e7b66a6f33a8..b81c25afc737 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -140,8 +140,24 @@ private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOExcept // SPARK-24578: cap the sub-region's size of returned nio buffer to improve the performance // for the case that the passed-in buffer has too many components. int length = Math.min(buf.readableBytes(), NIO_BUFFER_LIMIT); - ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length); - int written = target.write(buffer); + // If the ByteBuf holds more than one ByteBuffer we should better call nioBuffers(...) + // to eliminate extra memory copies. + int written = 0; + if (buf.nioBufferCount() == 1) { + ByteBuffer buffer = buf.nioBuffer(buf.readerIndex(), length); + written = target.write(buffer); + } else { + ByteBuffer[] buffers = buf.nioBuffers(buf.readerIndex(), length); + for (ByteBuffer buffer: buffers) { + int remaining = buffer.remaining(); + int w = target.write(buffer); + written += w; + if (w < remaining) { + // Could not write all, we need to break now. 
+ break; + } + } + } buf.skipBytes(written); return written; } diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java index ecb66fcf2ff7..3bff34e210e3 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -22,6 +22,7 @@ import java.nio.channels.WritableByteChannel; import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; import org.apache.spark.network.util.AbstractFileRegion; import org.junit.Test; @@ -48,7 +49,36 @@ public void testShortWrite() throws Exception { @Test public void testByteBufBody() throws Exception { + testByteBufBody(Unpooled.copyLong(42)); + } + + @Test + public void testCompositeByteBufBodySingleBuffer() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer(); + compositeByteBuf.addComponent(true, header); + assertEquals(1, compositeByteBuf.nioBufferCount()); + testByteBufBody(compositeByteBuf); + } + + @Test + public void testCompositeByteBufBodyMultipleBuffers() throws Exception { ByteBuf header = Unpooled.copyLong(42); + CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer(); + compositeByteBuf.addComponent(true, header.retainedSlice(0, 4)); + compositeByteBuf.addComponent(true, header.slice(4, 4)); + assertEquals(2, compositeByteBuf.nioBufferCount()); + testByteBufBody(compositeByteBuf); + } + + /** + * Test writing a {@link MessageWithHeader} using the given {@link ByteBuf} as header. + * + * @param header the header to use. + * @throws Exception thrown on error. 
+ */ + private void testByteBufBody(ByteBuf header) throws Exception { + long expectedHeaderValue = header.getLong(header.readerIndex()); ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84); assertEquals(1, header.refCnt()); assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt()); @@ -61,7 +91,7 @@ public void testByteBufBody() throws Exception { MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size()); ByteBuf result = doWrite(msg, 1); assertEquals(msg.count(), result.readableBytes()); - assertEquals(42, result.readLong()); + assertEquals(expectedHeaderValue, result.readLong()); assertEquals(84, result.readLong()); assertTrue(msg.release());