Skip to content

Commit 5f37aad

Browse files
JoshRosenzsxwing
authored andcommitted
[SPARK-13308] ManagedBuffers passed to OneToOneStreamManager need to be freed in non-error cases
ManagedBuffers that are passed to `OneToOneStreamManager.registerStream` need to be freed by the manager once it's done using them. However, the current code only frees them in certain error-cases and not during typical operation. This isn't a major problem today, but it will cause memory leaks after we implement better locking / pinning in the BlockManager (see #10705). This patch modifies the relevant network code so that the ManagedBuffers are freed as soon as the messages containing them are processed by the lower-level Netty message sending code. /cc zsxwing for review. Author: Josh Rosen <[email protected]> Closes #11193 from JoshRosen/add-missing-release-calls-in-network-layer.
1 parent c7d00a2 commit 5f37aad

File tree

7 files changed

+119
-9
lines changed

7 files changed

+119
-9
lines changed

network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,11 @@ public abstract class ManagedBuffer {
6565
public abstract ManagedBuffer release();
6666

6767
/**
68-
* Convert the buffer into an Netty object, used to write the data out.
68+
* Convert the buffer into an Netty object, used to write the data out. The return value is either
69+
* a {@link io.netty.buffer.ByteBuf} or a {@link io.netty.channel.FileRegion}.
70+
*
71+
* If this method returns a ByteBuf, then that buffer's reference count will be incremented and
72+
* the caller will be responsible for releasing this new reference.
6973
*/
7074
public abstract Object convertToNetty() throws IOException;
7175
}

network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ public ManagedBuffer release() {
6464

6565
@Override
6666
public Object convertToNetty() throws IOException {
67-
return buf.duplicate();
67+
return buf.duplicate().retain();
6868
}
6969

7070
@Override

network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) thro
5454
body = in.body().convertToNetty();
5555
isBodyInFrame = in.isBodyInFrame();
5656
} catch (Exception e) {
57+
in.body().release();
5758
if (in instanceof AbstractResponseMessage) {
5859
AbstractResponseMessage resp = (AbstractResponseMessage) in;
5960
// Re-encode this message as a failure response.
@@ -80,8 +81,10 @@ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) thro
8081
in.encode(header);
8182
assert header.writableBytes() == 0;
8283

83-
if (body != null && bodyLength > 0) {
84-
out.add(new MessageWithHeader(header, body, bodyLength));
84+
if (body != null) {
85+
// We transfer ownership of the reference on in.body() to MessageWithHeader.
86+
// This reference will be freed when MessageWithHeader.deallocate() is called.
87+
out.add(new MessageWithHeader(in.body(), header, body, bodyLength));
8588
} else {
8689
out.add(header);
8790
}

network/common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,29 +19,52 @@
1919

2020
import java.io.IOException;
2121
import java.nio.channels.WritableByteChannel;
22+
import javax.annotation.Nullable;
2223

2324
import com.google.common.base.Preconditions;
2425
import io.netty.buffer.ByteBuf;
2526
import io.netty.channel.FileRegion;
2627
import io.netty.util.AbstractReferenceCounted;
2728
import io.netty.util.ReferenceCountUtil;
2829

30+
import org.apache.spark.network.buffer.ManagedBuffer;
31+
2932
/**
3033
* A wrapper message that holds two separate pieces (a header and a body).
3134
*
3235
* The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion.
3336
*/
3437
class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
3538

39+
@Nullable private final ManagedBuffer managedBuffer;
3640
private final ByteBuf header;
3741
private final int headerLength;
3842
private final Object body;
3943
private final long bodyLength;
4044
private long totalBytesTransferred;
4145

42-
MessageWithHeader(ByteBuf header, Object body, long bodyLength) {
46+
/**
47+
* Construct a new MessageWithHeader.
48+
*
49+
* @param managedBuffer the {@link ManagedBuffer} that the message body came from. This needs to
50+
* be passed in so that the buffer can be freed when this message is
51+
* deallocated. Ownership of the caller's reference to this buffer is
52+
* transferred to this class, so if the caller wants to continue to use the
53+
* ManagedBuffer in other messages then they will need to call retain() on
54+
* it before passing it to this constructor. This may be null if and only if
55+
* `body` is a {@link FileRegion}.
56+
* @param header the message header.
57+
* @param body the message body. Must be either a {@link ByteBuf} or a {@link FileRegion}.
58+
* @param bodyLength the length of the message body, in bytes.
59+
*/
60+
MessageWithHeader(
61+
@Nullable ManagedBuffer managedBuffer,
62+
ByteBuf header,
63+
Object body,
64+
long bodyLength) {
4365
Preconditions.checkArgument(body instanceof ByteBuf || body instanceof FileRegion,
4466
"Body must be a ByteBuf or a FileRegion.");
67+
this.managedBuffer = managedBuffer;
4568
this.header = header;
4669
this.headerLength = header.readableBytes();
4770
this.body = body;
@@ -99,6 +122,9 @@ public long transferTo(final WritableByteChannel target, final long position) th
99122
protected void deallocate() {
100123
header.release();
101124
ReferenceCountUtil.release(body);
125+
if (managedBuffer != null) {
126+
managedBuffer.release();
127+
}
102128
}
103129

104130
private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException {

network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import java.util.Iterator;
2121
import java.util.Map;
2222
import java.util.Random;
23-
import java.util.Set;
2423
import java.util.concurrent.ConcurrentHashMap;
2524
import java.util.concurrent.atomic.AtomicLong;
2625

network/common/src/test/java/org/apache/spark/network/protocol/MessageWithHeaderSuite.java

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,13 @@
2626
import io.netty.channel.FileRegion;
2727
import io.netty.util.AbstractReferenceCounted;
2828
import org.junit.Test;
29+
import org.mockito.Mockito;
2930

3031
import static org.junit.Assert.*;
3132

33+
import org.apache.spark.network.TestManagedBuffer;
34+
import org.apache.spark.network.buffer.ManagedBuffer;
35+
import org.apache.spark.network.buffer.NettyManagedBuffer;
3236
import org.apache.spark.network.util.ByteArrayWritableChannel;
3337

3438
public class MessageWithHeaderSuite {
@@ -46,27 +50,51 @@ public void testShortWrite() throws Exception {
4650
@Test
4751
public void testByteBufBody() throws Exception {
4852
ByteBuf header = Unpooled.copyLong(42);
49-
ByteBuf body = Unpooled.copyLong(84);
50-
MessageWithHeader msg = new MessageWithHeader(header, body, body.readableBytes());
53+
ByteBuf bodyPassedToNettyManagedBuffer = Unpooled.copyLong(84);
54+
assertEquals(1, header.refCnt());
55+
assertEquals(1, bodyPassedToNettyManagedBuffer.refCnt());
56+
ManagedBuffer managedBuf = new NettyManagedBuffer(bodyPassedToNettyManagedBuffer);
5157

58+
Object body = managedBuf.convertToNetty();
59+
assertEquals(2, bodyPassedToNettyManagedBuffer.refCnt());
60+
assertEquals(1, header.refCnt());
61+
62+
MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, managedBuf.size());
5263
ByteBuf result = doWrite(msg, 1);
5364
assertEquals(msg.count(), result.readableBytes());
5465
assertEquals(42, result.readLong());
5566
assertEquals(84, result.readLong());
67+
68+
assert(msg.release());
69+
assertEquals(0, bodyPassedToNettyManagedBuffer.refCnt());
70+
assertEquals(0, header.refCnt());
71+
}
72+
73+
@Test
74+
public void testDeallocateReleasesManagedBuffer() throws Exception {
75+
ByteBuf header = Unpooled.copyLong(42);
76+
ManagedBuffer managedBuf = Mockito.spy(new TestManagedBuffer(84));
77+
ByteBuf body = (ByteBuf) managedBuf.convertToNetty();
78+
assertEquals(2, body.refCnt());
79+
MessageWithHeader msg = new MessageWithHeader(managedBuf, header, body, body.readableBytes());
80+
assert(msg.release());
81+
Mockito.verify(managedBuf, Mockito.times(1)).release();
82+
assertEquals(0, body.refCnt());
5683
}
5784

5885
private void testFileRegionBody(int totalWrites, int writesPerCall) throws Exception {
5986
ByteBuf header = Unpooled.copyLong(42);
6087
int headerLength = header.readableBytes();
6188
TestFileRegion region = new TestFileRegion(totalWrites, writesPerCall);
62-
MessageWithHeader msg = new MessageWithHeader(header, region, region.count());
89+
MessageWithHeader msg = new MessageWithHeader(null, header, region, region.count());
6390

6491
ByteBuf result = doWrite(msg, totalWrites / writesPerCall);
6592
assertEquals(headerLength + region.count(), result.readableBytes());
6693
assertEquals(42, result.readLong());
6794
for (long i = 0; i < 8; i++) {
6895
assertEquals(i, result.readLong());
6996
}
97+
assert(msg.release());
7098
}
7199

72100
private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exception {
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.network.server;
19+
20+
import java.util.ArrayList;
21+
import java.util.List;
22+
23+
import io.netty.channel.Channel;
24+
import org.junit.Test;
25+
import org.mockito.Mockito;
26+
27+
import org.apache.spark.network.TestManagedBuffer;
28+
import org.apache.spark.network.buffer.ManagedBuffer;
29+
30+
public class OneForOneStreamManagerSuite {
31+
32+
@Test
33+
public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception {
34+
OneForOneStreamManager manager = new OneForOneStreamManager();
35+
List<ManagedBuffer> buffers = new ArrayList<>();
36+
TestManagedBuffer buffer1 = Mockito.spy(new TestManagedBuffer(10));
37+
TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20));
38+
buffers.add(buffer1);
39+
buffers.add(buffer2);
40+
long streamId = manager.registerStream("appId", buffers.iterator());
41+
42+
Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS);
43+
manager.registerChannel(dummyChannel, streamId);
44+
45+
manager.connectionTerminated(dummyChannel);
46+
47+
Mockito.verify(buffer1, Mockito.times(1)).release();
48+
Mockito.verify(buffer2, Mockito.times(1)).release();
49+
}
50+
}

0 commit comments

Comments
 (0)