Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,37 +69,22 @@ public void close() {
public void handleBytes(TcpChannel channel, ReleasableBytesReference reference) throws IOException {
pending.add(reference.retain());

final ReleasableBytesReference composite;
if (pending.size() == 1) {
composite = pending.peekFirst();
} else {
final ReleasableBytesReference[] bytesReferences = pending.toArray(new ReleasableBytesReference[0]);
final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences);
composite = new ReleasableBytesReference(new CompositeBytesReference(bytesReferences), releasable);
}

final ArrayList<Object> fragments = fragmentList.get();
int bytesConsumed = 0;
boolean continueHandling = true;

while (continueHandling && isClosed == false) {
boolean continueDecoding = true;
while (continueDecoding) {
final int remaining = composite.length() - bytesConsumed;
if (remaining != 0) {
try (ReleasableBytesReference slice = composite.retainedSlice(bytesConsumed, remaining)) {
final int bytesDecoded = decoder.decode(slice, fragments::add);
if (bytesDecoded != 0) {
bytesConsumed += bytesDecoded;
if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) {
continueDecoding = false;
}
} else {
while (continueDecoding && pending.isEmpty() == false) {
try (ReleasableBytesReference toDecode = getPendingBytes()) {
final int bytesDecoded = decoder.decode(toDecode, fragments::add);
if (bytesDecoded != 0) {
releasePendingBytes(bytesDecoded);
if (fragments.isEmpty() == false && endOfMessage(fragments.get(fragments.size() - 1))) {
continueDecoding = false;
}
} else {
continueDecoding = false;
}
} else {
continueDecoding = false;
}
}

Expand All @@ -118,8 +103,6 @@ public void handleBytes(TcpChannel channel, ReleasableBytesReference reference)
}
}
}

releasePendingBytes(bytesConsumed);
}

private void forwardFragments(TcpChannel channel, ArrayList<Object> fragments) throws IOException {
Expand Down Expand Up @@ -155,11 +138,22 @@ private boolean endOfMessage(Object fragment) {
return fragment == InboundDecoder.PING || fragment == InboundDecoder.END_CONTENT || fragment instanceof Exception;
}

private void releasePendingBytes(int bytesConsumed) {
if (isClosed) {
// Are released by the close method
return;
private ReleasableBytesReference getPendingBytes() {
if (pending.size() == 1) {
return pending.peekFirst().retain();
} else {
final ReleasableBytesReference[] bytesReferences = new ReleasableBytesReference[pending.size()];
int index = 0;
for (ReleasableBytesReference pendingReference : pending) {
bytesReferences[index] = pendingReference.retain();
++index;
}
final Releasable releasable = () -> Releasables.closeWhileHandlingException(bytesReferences);
return new ReleasableBytesReference(new CompositeBytesReference(bytesReferences), releasable);
}
}

private void releasePendingBytes(int bytesConsumed) {
int bytesToRelease = bytesConsumed;
while (bytesToRelease != 0) {
try (ReleasableBytesReference reference = pending.pollFirst()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.PageCacheRecycler;
import org.elasticsearch.common.util.concurrent.ThreadContext;
Expand All @@ -35,6 +36,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;

import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -172,6 +174,53 @@ public void testPipelineHandling() throws IOException {
}
}

public void testEnsureBodyIsNotPrematurelyReleased() throws IOException {
final PageCacheRecycler recycler = PageCacheRecycler.NON_RECYCLING_INSTANCE;
BiConsumer<TcpChannel, InboundMessage> messageHandler = (c, m) -> {};
BiConsumer<TcpChannel, Tuple<Header, Exception>> errorHandler = (c, e) -> {};
final InboundPipeline pipeline = new InboundPipeline(Version.CURRENT, recycler, messageHandler, errorHandler);

try (BytesStreamOutput streamOutput = new BytesStreamOutput()) {
String actionName = "actionName";
final Version version = Version.CURRENT;
final String value = randomAlphaOfLength(1000);
final boolean isRequest = randomBoolean();
final long requestId = randomNonNegativeLong();

OutboundMessage message;
if (isRequest) {
message = new OutboundMessage.Request(threadContext, new String[0], new TestRequest(value),
version, actionName, requestId, false, false);
} else {
message = new OutboundMessage.Response(threadContext, Collections.emptySet(), new TestResponse(value),
version, requestId, false, false);
}

final BytesReference reference = message.serialize(streamOutput);
final int fixedHeaderSize = TcpHeader.headerSize(Version.CURRENT);
final int variableHeaderSize = reference.getInt(fixedHeaderSize - 4);
final int totalHeaderSize = fixedHeaderSize + variableHeaderSize;
final AtomicBoolean bodyReleased = new AtomicBoolean(false);
for (int i = 0; i < totalHeaderSize - 1; ++i) {
try (ReleasableBytesReference slice = ReleasableBytesReference.wrap(reference.slice(i, 1))) {
pipeline.handleBytes(new FakeTcpChannel(), slice);
}
}

final Releasable releasable = () -> bodyReleased.set(true);
final int from = totalHeaderSize - 1;
final BytesReference partHeaderPartBody = reference.slice(from, reference.length() - from - 1);
try (ReleasableBytesReference slice = new ReleasableBytesReference(partHeaderPartBody, releasable)) {
pipeline.handleBytes(new FakeTcpChannel(), slice);
}
assertFalse(bodyReleased.get());
try (ReleasableBytesReference slice = new ReleasableBytesReference(reference.slice(reference.length() - 1, 1), releasable)) {
pipeline.handleBytes(new FakeTcpChannel(), slice);
}
assertTrue(bodyReleased.get());
}
}

private static class MessageData {

private final Version version;
Expand Down