Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,19 @@
*/
class StreamInterceptor implements TransportFrameDecoder.Interceptor {

private final TransportResponseHandler handler;
private final String streamId;
private final long byteCount;
private final StreamCallback callback;

private volatile long bytesRead;

StreamInterceptor(String streamId, long byteCount, StreamCallback callback) {
StreamInterceptor(
TransportResponseHandler handler,
String streamId,
long byteCount,
StreamCallback callback) {
this.handler = handler;
this.streamId = streamId;
this.byteCount = byteCount;
this.callback = callback;
Expand All @@ -45,11 +51,13 @@ class StreamInterceptor implements TransportFrameDecoder.Interceptor {

@Override
public void exceptionCaught(Throwable cause) throws Exception {
handler.deactivateStream();
callback.onFailure(streamId, cause);
}

@Override
public void channelInactive() throws Exception {
handler.deactivateStream();
callback.onFailure(streamId, new ClosedChannelException());
}

Expand All @@ -65,8 +73,10 @@ public boolean handle(ByteBuf buf) throws Exception {
RuntimeException re = new IllegalStateException(String.format(
"Read too many bytes? Expected %d, but read %d.", byteCount, bytesRead));
callback.onFailure(streamId, re);
handler.deactivateStream();
throw re;
} else if (bytesRead == byteCount) {
handler.deactivateStream();
callback.onComplete(streamId);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;

import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -56,6 +57,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
private final Map<Long, RpcResponseCallback> outstandingRpcs;

private final Queue<StreamCallback> streamCallbacks;
private volatile boolean streamActive;

/** Records the time (in system nanoseconds) that the last fetch or RPC request was sent. */
private final AtomicLong timeOfLastRequestNs;
Expand Down Expand Up @@ -87,9 +89,15 @@ public void removeRpcRequest(long requestId) {
}

public void addStreamCallback(StreamCallback callback) {
timeOfLastRequestNs.set(System.nanoTime());
streamCallbacks.offer(callback);
}

@VisibleForTesting
public void deactivateStream() {
streamActive = false;
}

/**
* Fire the failure callback for all outstanding requests. This is called when we have an
* uncaught exception or pre-mature connection termination.
Expand Down Expand Up @@ -177,14 +185,16 @@ public void handle(ResponseMessage message) {
StreamResponse resp = (StreamResponse) message;
StreamCallback callback = streamCallbacks.poll();
if (callback != null) {
StreamInterceptor interceptor = new StreamInterceptor(resp.streamId, resp.byteCount,
StreamInterceptor interceptor = new StreamInterceptor(this, resp.streamId, resp.byteCount,
callback);
try {
TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
frameDecoder.setInterceptor(interceptor);
streamActive = true;
} catch (Exception e) {
logger.error("Error installing stream handler.", e);
deactivateStream();
}
} else {
logger.error("Could not find callback for StreamResponse.");
Expand All @@ -208,7 +218,8 @@ public void handle(ResponseMessage message) {

/** Returns total number of outstanding requests (fetch requests + rpcs) */
public int numOutstandingRequests() {
return outstandingFetches.size() + outstandingRpcs.size();
return outstandingFetches.size() + outstandingRpcs.size() + streamCallbacks.size() +
(streamActive ? 1 : 0);
}

/** Returns the time in nanoseconds of when the last request was sent out. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.network;

import io.netty.channel.Channel;
import io.netty.channel.local.LocalChannel;
import org.junit.Test;

Expand All @@ -28,12 +29,16 @@
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.StreamCallback;
import org.apache.spark.network.client.TransportResponseHandler;
import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
import org.apache.spark.network.protocol.StreamChunkId;
import org.apache.spark.network.protocol.StreamFailure;
import org.apache.spark.network.protocol.StreamResponse;
import org.apache.spark.network.util.TransportFrameDecoder;

public class TransportResponseHandlerSuite {
@Test
Expand Down Expand Up @@ -112,4 +117,26 @@ public void handleFailedRPC() {
verify(callback, times(1)).onFailure((Throwable) any());
assertEquals(0, handler.numOutstandingRequests());
}

@Test
public void testActiveStreams() {
Channel c = new LocalChannel();
c.pipeline().addLast(TransportFrameDecoder.HANDLER_NAME, new TransportFrameDecoder());
TransportResponseHandler handler = new TransportResponseHandler(c);

StreamResponse response = new StreamResponse("stream", 1234L, null);
StreamCallback cb = mock(StreamCallback.class);
handler.addStreamCallback(cb);
assertEquals(1, handler.numOutstandingRequests());
handler.handle(response);
assertEquals(1, handler.numOutstandingRequests());
handler.deactivateStream();
assertEquals(0, handler.numOutstandingRequests());

StreamFailure failure = new StreamFailure("stream", "uh-oh");
handler.addStreamCallback(cb);
assertEquals(1, handler.numOutstandingRequests());
handler.handle(failure);
assertEquals(0, handler.numOutstandingRequests());
}
}