From ec8cc24ddd2b35a074674b06503ca710a17c1af2 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 12 Feb 2016 16:09:44 -0800 Subject: [PATCH 1/7] Add missing ManagedBuffer.release() call. --- .../network/protocol/MessageEncoder.java | 12 ++++++-- .../network/protocol/MessageWithHeader.java | 28 ++++++++++++++++++- .../protocol/MessageWithHeaderSuite.java | 7 +++-- 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index abca22347b78..442b60bdb9b5 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -49,11 +49,15 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) thro // If the message has a body, take it out to enable zero-copy transfer for the payload. if (in.body() != null) { + bodyLength = in.body().size(); + if (bodyLength > 0) { + in.body().retain(); + } try { - bodyLength = in.body().size(); body = in.body().convertToNetty(); isBodyInFrame = in.isBodyInFrame(); } catch (Exception e) { + in.body().release(); if (in instanceof AbstractResponseMessage) { AbstractResponseMessage resp = (AbstractResponseMessage) in; // Re-encode this message as a failure response. @@ -80,8 +84,10 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) thro in.encode(header); assert header.writableBytes() == 0; - if (body != null && bodyLength > 0) { - out.add(new MessageWithHeader(header, body, bodyLength)); + if (body != null) { + // We transfer ownership of the reference on in.body() to MessageWithHeader. + // This reference will be freed when MessageWithHeader.deallocate() is called. + out.add(new MessageWithHeader(in.body(), header, body, bodyLength)); } else { out.add(header); } diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index d686a951467c..66227f96a1a2 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.channels.WritableByteChannel; +import javax.annotation.Nullable; import com.google.common.base.Preconditions; import io.netty.buffer.ByteBuf; @@ -26,6 +27,8 @@ import io.netty.util.AbstractReferenceCounted; import io.netty.util.ReferenceCountUtil; +import org.apache.spark.network.buffer.ManagedBuffer; + /** * A wrapper message that holds two separate pieces (a header and a body). * @@ -33,15 +36,35 @@ */ class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { + @Nullable private final ManagedBuffer managedBuffer; private final ByteBuf header; private final int headerLength; private final Object body; private final long bodyLength; private long totalBytesTransferred; - MessageWithHeader(ByteBuf header, Object body, long bodyLength) { + /** + * Construct a new MessageWithHeader. + * + * @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to + * be passed in so that the buffer can be freed when this message is + * deallocated. Ownership of the caller's reference to this buffer is + * transferred to this class, so if the caller wants to continue to use the + * ManagedBuffer in other messages then they will need to call retain() on + * it before passing it to this constructor. This may be null if and only if + * `body` is a {@link FileRegion}. + * @param header the message header. + * @param body the message body. Must be either a {@link ByteBuf} or a {@link FileRegion}. + * @param bodyLength the length of the message body, in bytes. + */ + MessageWithHeader( + @Nullable ManagedBuffer managedBuffer, + ByteBuf header, + Object body, + long bodyLength) { Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion, "Body must be a ByteBuf or a FileRegion."); + this.managedBuffer = managedBuffer; this.header = header; this.headerLength = header.readableBytes(); this.body = body; @@ -99,6 +122,9 @@ public long transferTo(final WritableByteChannel target, final long position) th protected void deallocate() { header.release(); ReferenceCountUtil.release(body); + if (managedBuffer != null) { + managedBuffer.release(); + } } private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java index 6c98e733b462..50035fba5702 100644 --- a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -29,6 +29,8 @@ import static org.junit.Assert.*; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; import org.apache.spark.network.util.ByteArrayWritableChannel; public class MessageWithHeaderSuite { @@ -47,7 +49,8 @@ public void testShortWrite() throws Exception { public void testByteBufBody() throws Exception { ByteBuf header = Unpooled.copyLong(42); ByteBuf body = Unpooled.copyLong(84); - MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes()); + ManagedBuffer managedBuf = new NettyManagedBuffer(body); + MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes()); ByteBuf result = doWrite(msg, 1); assertEquals(msg.count(), result.readableBytes()); @@ -59,7 +62,7 @@ private void testFileRegionBody(int totalWrites, int writesPerCall) throws Excep ByteBuf header = Unpooled.copyLong(42); int headerLength = header.readableBytes(); TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall); - MessageWithHeader msg = new MessageWithHeader(header, region, region.count()); + MessageWithHeader msg = new MessageWithHeader(null, header, region, region.count()); ByteBuf result = doWrite(msg, totalWrites / writesPerCall); assertEquals(headerLength + region.count(), result.readableBytes()); From 613498911f3867cddb72d0a3b235260e4e7d1433 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 12 Feb 2016 16:51:29 -0800 Subject: [PATCH 2/7] Add a test for OneForOneStreamManager.connectionTerminated --- .../server/OneForOneStreamManager.java | 1 - .../server/OneForOneStreamManagerSuite.java | 51 +++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index e671854da1ca..ea9e735e0a17 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -20,7 +20,6 @@ import java.util.Iterator; import java.util.Map; import java.util.Random; -import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; diff --git a/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java new file mode 100644 index 000000000000..6356ac6c24f8 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.util.ArrayList; +import java.util.List; + +import io.netty.channel.Channel; +import org.junit.Test; +import org.mockito.Mockito; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.TestManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; + +public class OneForOneStreamManagerSuite { + + @Test + public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception { + OneForOneStreamManager manager = new OneForOneStreamManager(); + List buffers = new ArrayList<>(); + TestManagedBuffer buffer1 = Mockito.spy(new TestManagedBuffer(10)); + TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20)); + buffers.add(buffer1); + buffers.add(buffer2); + long streamId = manager.registerStream("appId", buffers.iterator()); + + Channel dummyChannel = Mockito.mock(Channel.class); + manager.registerChannel(dummyChannel, streamId); + + manager.connectionTerminated(dummyChannel); + + Mockito.verify(buffer1, times(1)).release(); + Mockito.verify(buffer2, times(1)).release(); + } +} From c9726c26fe5f92d84265346c3998d7ab40e75cff Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 12 Feb 2016 17:06:35 -0800 Subject: [PATCH 3/7] Add tests covering new release() call. --- .../network/protocol/MessageWithHeaderSuite.java | 16 +++++++++++++++- .../server/OneForOneStreamManagerSuite.java | 7 +++---- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java index 50035fba5702..2f9bd7e81fab 100644 --- a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -26,9 +26,11 @@ import io.netty.channel.FileRegion; import io.netty.util.AbstractReferenceCounted; import org.junit.Test; +import org.mockito.Mockito; import static org.junit.Assert.*; +import org.apache.spark.network.TestManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NettyManagedBuffer; import org.apache.spark.network.util.ByteArrayWritableChannel; @@ -48,7 +50,7 @@ public void testShortWrite() throws Exception { @Test public void testByteBufBody() throws Exception { ByteBuf header = Unpooled.copyLong(42); - ByteBuf body = Unpooled.copyLong(84); + ByteBuf body = Unpooled.copyLong(84).retain(); ManagedBuffer managedBuf = new NettyManagedBuffer(body); MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes()); @@ -56,6 +58,17 @@ public void testByteBufBody() throws Exception { assertEquals(msg.count(), result.readableBytes()); assertEquals(42, result.readLong()); assertEquals(84, result.readLong()); + msg.deallocate(); + } + + @Test + public void testDeallocateReleasesManagedBuffer() throws Exception { + ByteBuf header = Unpooled.copyLong(42); + ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84)); + ByteBuf body = ((ByteBuf) managedBuf.convertToNetty()).retain(); + MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes()); + msg.deallocate(); + Mockito.verify(managedBuf, Mockito.times(1)).release(); } private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { @@ -70,6 +83,7 @@ private void testFileRegionBody(int totalWrites, int writesPerCall) throws Excep for (long i = 0; i < 8; i++) { assertEquals(i, result.readLong()); } + msg.deallocate(); } private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { diff --git a/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java index 6356ac6c24f8..c647525d8f1b 100644 --- a/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java @@ -23,7 +23,6 @@ import io.netty.channel.Channel; import org.junit.Test; import org.mockito.Mockito; -import static org.mockito.Mockito.*; import org.apache.spark.network.TestManagedBuffer; import org.apache.spark.network.buffer.ManagedBuffer; @@ -40,12 +39,12 @@ public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception { buffers.add(buffer2); long streamId = manager.registerStream("appId", buffers.iterator()); - Channel dummyChannel = Mockito.mock(Channel.class); + Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS); manager.registerChannel(dummyChannel, streamId); manager.connectionTerminated(dummyChannel); - Mockito.verify(buffer1, times(1)).release(); - Mockito.verify(buffer2, times(1)).release(); + Mockito.verify(buffer1, Mockito.times(1)).release(); + Mockito.verify(buffer2, Mockito.times(1)).release(); } } From e5cf48d1455e4ebd85791e8909cfa9d938a591fb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 13 Feb 2016 14:19:50 -0800 Subject: [PATCH 4/7] Remove bad retain. --- .../org/apache/spark/network/protocol/MessageEncoder.java | 5 +---- .../java/org/apache/spark/network/TestManagedBuffer.java | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java index 442b60bdb9b5..664df57feca4 100644 --- a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -49,11 +49,8 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) thro // If the message has a body, take it out to enable zero-copy transfer for the payload. if (in.body() != null) { - bodyLength = in.body().size(); - if (bodyLength > 0) { - in.body().retain(); - } try { + bodyLength = in.body().size(); body = in.body().convertToNetty(); isBodyInFrame = in.isBodyInFrame(); } catch (Exception e) { diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java index 83c90f9eff2b..e15b1309efe5 100644 --- a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java +++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java @@ -44,7 +44,7 @@ public TestManagedBuffer(int len) { for (int i = 0; i < len; i ++) { byteArray[i] = (byte) i; } - this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray)); + this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray).retain()); } From 2c00f29272051b8092b6a8a976392e32eeb5488b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 15 Feb 2016 20:12:03 -0800 Subject: [PATCH 5/7] Update to address review feedback regarding retain() call sites. --- .../spark/network/buffer/ManagedBuffer.java | 6 +++++- .../network/buffer/NettyManagedBuffer.java | 2 +- .../spark/network/TestManagedBuffer.java | 2 +- .../protocol/MessageWithHeaderSuite.java | 19 +++++++++++++++---- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java index a415db593a78..1861f8d7fd8f 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java @@ -65,7 +65,11 @@ public abstract class ManagedBuffer { public abstract ManagedBuffer release(); /** - * Convert the buffer into an Netty object, used to write the data out. + * Convert the buffer into an Netty object, used to write the data out. The return value is either + * a {@link io.netty.buffer.ByteBuf} or a {@link io.netty.channel.FileRegion}. + * + * If this method returns a ByteBuf, then that buffer's reference count will be incremented and + * the caller will be responsible for releasing this new reference. */ public abstract Object convertToNetty() throws IOException; } diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java index c806bfa45bef..4c8802af7ae6 100644 --- a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java @@ -64,7 +64,7 @@ public ManagedBuffer release() { @Override public Object convertToNetty() throws IOException { - return buf.duplicate(); + return buf.duplicate().retain(); } @Override diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java index e15b1309efe5..83c90f9eff2b 100644 --- a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java +++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java @@ -44,7 +44,7 @@ public TestManagedBuffer(int len) { for (int i = 0; i < len; i ++) { byteArray[i] = (byte) i; } - this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray).retain()); + this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray)); } diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java index 2f9bd7e81fab..d3311195f6a4 100644 --- a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -50,25 +50,36 @@ public void testShortWrite() throws Exception { @Test public void testByteBufBody() throws Exception { ByteBuf header = Unpooled.copyLong(42); - ByteBuf body = Unpooled.copyLong(84).retain(); - ManagedBuffer managedBuf = new NettyManagedBuffer(body); - MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes()); + ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84); + assertEquals(1, header.refCnt()); + assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt()); + ManagedBuffer managedBuf = new NettyManagedBuffer(bodyPassedToNettyManagedBuffer); + + Object body = managedBuf.convertToNetty(); + assertEquals(2, bodyPassedToNettyManagedBuffer.refCnt()); + assertEquals(1, header.refCnt()); + MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size()); ByteBuf result = doWrite(msg, 1); assertEquals(msg.count(), result.readableBytes()); assertEquals(42, result.readLong()); assertEquals(84, result.readLong()); + msg.deallocate(); + assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt()); + assertEquals(0, header.refCnt()); } @Test public void testDeallocateReleasesManagedBuffer() throws Exception { ByteBuf header = Unpooled.copyLong(42); ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84)); - ByteBuf body = ((ByteBuf) managedBuf.convertToNetty()).retain(); + ByteBuf body = (ByteBuf) managedBuf.convertToNetty(); + assertEquals(2, body.refCnt()); MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes()); msg.deallocate(); Mockito.verify(managedBuf, Mockito.times(1)).release(); + assertEquals(0, body.refCnt()); } private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception { From cb99750c2464d06a0ec043922699385f196f946d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 16 Feb 2016 09:50:54 -0800 Subject: [PATCH 6/7] release() -> deallocate() in tests. --- .../apache/spark/network/protocol/MessageWithHeaderSuite.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java index d3311195f6a4..e3ea7523b534 100644 --- a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -77,7 +77,7 @@ public void testDeallocateReleasesManagedBuffer() throws Exception { ByteBuf body = (ByteBuf) managedBuf.convertToNetty(); assertEquals(2, body.refCnt()); MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes()); - msg.deallocate(); + assert(msg.release()); Mockito.verify(managedBuf, Mockito.times(1)).release(); assertEquals(0, body.refCnt()); } @@ -94,7 +94,7 @@ private void testFileRegionBody(int totalWrites, int writesPerCall) throws Excep for (long i = 0; i < 8; i++) { assertEquals(i, result.readLong()); } - msg.deallocate(); + assert(msg.release()); } private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception { From 014ca9be204edb9428fae569921571012d3fdbde Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 16 Feb 2016 10:25:40 -0800 Subject: [PATCH 7/7] Fix last deallocate(). --- .../apache/spark/network/protocol/MessageWithHeaderSuite.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java index e3ea7523b534..fbbe4b7014ff 100644 --- a/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java @@ -65,7 +65,7 @@ public void testByteBufBody() throws Exception { assertEquals(42, result.readLong()); assertEquals(84, result.readLong()); - msg.deallocate(); + assert(msg.release()); assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt()); assertEquals(0, header.refCnt()); }