Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.buffer;

import java.io.File;

import com.google.common.base.Objects;

import org.apache.spark.network.util.TransportConf;

/**
* A {@link ManagedBuffer} backed by a segment in a file with digest.
*/
public final class DigestFileSegmentManagedBuffer extends FileSegmentManagedBuffer {

private final long digest;

public DigestFileSegmentManagedBuffer(TransportConf conf, File file, long offset, long length,
long digest) {
super(conf, file, offset, length);
this.digest = digest;
}

public long getDigest() { return digest; }

@Override
public String toString() {
return Objects.toStringHelper(this)
.add("file", getFile())
.add("offset", getOffset())
.add("length", getLength())
.add("digest", digest)
.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
/**
* A {@link ManagedBuffer} backed by a segment in a file.
*/
public final class FileSegmentManagedBuffer extends ManagedBuffer {
public class FileSegmentManagedBuffer extends ManagedBuffer {
private final TransportConf conf;
private final File file;
private final long offset;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ public interface ChunkReceivedCallback {
*/
void onSuccess(int chunkIndex, ManagedBuffer buffer);

/** Called with a extra digest parameter upon receipt of a particular chunk.*/
default void onSuccess(int chunkIndex, ManagedBuffer buffer, long digest) {
onSuccess(chunkIndex, buffer);
}

/**
* Called upon failure to fetch a particular chunk. Note that this may actually be called due
* to failure to fetch a prior chunk in this stream.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ public interface StreamCallback {
/** Called when all data from the stream has been received. */
void onComplete(String streamId) throws IOException;

/** Called with a extra digest when all data from the stream has been received. */
default void onComplete(String streamId, long digest) throws IOException {
onComplete(streamId);
}

/** Called if there's an error reading data from the stream. */
void onFailure(String streamId, Throwable cause) throws IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public class StreamInterceptor<T extends Message> implements TransportFrameDecod
private final long byteCount;
private final StreamCallback callback;
private long bytesRead;
private long digest = -1L;

public StreamInterceptor(
MessageHandler<T> handler,
Expand All @@ -50,6 +51,16 @@ public StreamInterceptor(
this.bytesRead = 0;
}

public StreamInterceptor(
MessageHandler<T> handler,
String streamId,
long byteCount,
StreamCallback callback,
long digest) {
this(handler, streamId, byteCount, callback);
this.digest = digest;
}

@Override
public void exceptionCaught(Throwable cause) throws Exception {
deactivateStream();
Expand Down Expand Up @@ -86,7 +97,11 @@ public boolean handle(ByteBuf buf) throws Exception {
throw re;
} else if (bytesRead == byteCount) {
deactivateStream();
callback.onComplete(streamId);
if (digest < 0) {
callback.onComplete(streamId);
} else {
callback.onComplete(streamId, digest);
}
}

return bytesRead != byteCount;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

import org.apache.spark.network.protocol.ChunkFetchFailure;
import org.apache.spark.network.protocol.ChunkFetchSuccess;
import org.apache.spark.network.protocol.DigestChunkFetchSuccess;
import org.apache.spark.network.protocol.DigestStreamResponse;
import org.apache.spark.network.protocol.ResponseMessage;
import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcResponse;
Expand Down Expand Up @@ -246,6 +248,45 @@ public void handle(ResponseMessage message) throws Exception {
} else {
logger.warn("Stream failure with unknown callback: {}", resp.error);
}
} else if (message instanceof DigestChunkFetchSuccess) {
DigestChunkFetchSuccess resp = (DigestChunkFetchSuccess) message;
ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
if (listener == null) {
logger.warn("Ignoring response for block {} from {} since it is not outstanding",
resp.streamChunkId, getRemoteAddress(channel));
resp.body().release();
} else {
outstandingFetches.remove(resp.streamChunkId);
listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body(), resp.digest);
resp.body().release();
}
} else if (message instanceof DigestStreamResponse) {
DigestStreamResponse resp = (DigestStreamResponse) message;
Pair<String, StreamCallback> entry = streamCallbacks.poll();
if (entry != null) {
StreamCallback callback = entry.getValue();
if (resp.byteCount > 0) {
StreamInterceptor interceptor = new StreamInterceptor(
this, resp.streamId, resp.byteCount, callback, resp.digest);
try {
TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
frameDecoder.setInterceptor(interceptor);
streamActive = true;
} catch (Exception e) {
logger.error("Error installing stream handler.", e);
deactivateStream();
}
} else {
try {
callback.onComplete(resp.streamId, resp.digest);
} catch (Exception e) {
logger.warn("Error in stream handler onComplete().", e);
}
}
} else {
logger.error("Could not find callback for StreamResponse.");
}
} else {
throw new IllegalStateException("Unknown response type: " + message.type());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.protocol;

import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NettyManagedBuffer;

/**
* Response to {@link ChunkFetchRequest} when a chunk exists with a digest and has been
* successfully fetched.
*
* Note that the server-side encoding of this messages does NOT include the buffer itself, as this
* may be written by Netty in a more efficient manner (i.e., zero-copy write).
* Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer.
*/
public final class DigestChunkFetchSuccess extends AbstractResponseMessage {
public final StreamChunkId streamChunkId;
public final long digest;

public DigestChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer, long digest) {
super(buffer, true);
this.streamChunkId = streamChunkId;
this.digest = digest;
}

@Override
public Message.Type type() { return Type.DigestChunkFetchSuccess; }

@Override
public int encodedLength() {
return streamChunkId.encodedLength() + 8;
}

/** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
@Override
public void encode(ByteBuf buf) {
streamChunkId.encode(buf);
buf.writeLong(digest);
}

@Override
public ResponseMessage createFailureResponse(String error) {
return new ChunkFetchFailure(streamChunkId, error);
}

/** Decoding uses the given ByteBuf as our data, and will retain() it. */
public static DigestChunkFetchSuccess decode(ByteBuf buf) {
StreamChunkId streamChunkId = StreamChunkId.decode(buf);
long digest = buf.readLong();
buf.retain();
NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate());
return new DigestChunkFetchSuccess(streamChunkId, managedBuf, digest);
}

@Override
public int hashCode() {
return Objects.hashCode(streamChunkId, body(), digest);
}

@Override
public boolean equals(Object other) {
if (other instanceof DigestChunkFetchSuccess) {
DigestChunkFetchSuccess o = (DigestChunkFetchSuccess) other;
return streamChunkId.equals(o.streamChunkId) && super.equals(o) && digest == o.digest;
}
return false;
}

@Override
public String toString() {
return Objects.toStringHelper(this)
.add("streamChunkId", streamChunkId)
.add("digest", digest)
.add("buffer", body())
.toString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.protocol;

import com.google.common.base.Objects;
import io.netty.buffer.ByteBuf;
import org.apache.spark.network.buffer.ManagedBuffer;

/**
* Response to {@link StreamRequest} with digest when the stream has been successfully opened.
* <p>
* Note the message itself does not contain the stream data. That is written separately by the
* sender. The receiver is expected to set a temporary channel handler that will consume the
* number of bytes this message says the stream has.
*/
public final class DigestStreamResponse extends AbstractResponseMessage {
public final String streamId;
public final long byteCount;
public final long digest;

public DigestStreamResponse(String streamId, long byteCount, ManagedBuffer buffer, long digest) {
super(buffer, false);
this.streamId = streamId;
this.byteCount = byteCount;
this.digest = digest;
}

@Override
public Message.Type type() { return Type.DigestStreamResponse; }

@Override
public int encodedLength() {
return 8 + Encoders.Strings.encodedLength(streamId) + 8;
}

/** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
@Override
public void encode(ByteBuf buf) {
Encoders.Strings.encode(buf, streamId);
buf.writeLong(byteCount);
buf.writeLong(digest);
}

@Override
public ResponseMessage createFailureResponse(String error) {
return new StreamFailure(streamId, error);
}

public static DigestStreamResponse decode(ByteBuf buf) {
String streamId = Encoders.Strings.decode(buf);
long byteCount = buf.readLong();
long digest = buf.readLong();
return new DigestStreamResponse(streamId, byteCount, null, digest);
}

@Override
public int hashCode() {
return Objects.hashCode(byteCount, streamId, body(), digest);
}

@Override
public boolean equals(Object other) {
if (other instanceof DigestStreamResponse) {
DigestStreamResponse o = (DigestStreamResponse) other;
return byteCount == o.byteCount && streamId.equals(o.streamId) && digest == o.digest;
}
return false;
}

@Override
public String toString() {
return Objects.toStringHelper(this)
.add("streamId", streamId)
.add("byteCount", byteCount)
.add("digest", digest)
.add("body", body())
.toString();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ enum Type implements Encodable {
ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
RpcRequest(3), RpcResponse(4), RpcFailure(5),
StreamRequest(6), StreamResponse(7), StreamFailure(8),
OneWayMessage(9), UploadStream(10), User(-1);
OneWayMessage(9), UploadStream(10), DigestChunkFetchSuccess(11),
DigestStreamResponse(12), User(-1);

private final byte id;

Expand Down Expand Up @@ -66,6 +67,8 @@ public static Type decode(ByteBuf buf) {
case 8: return StreamFailure;
case 9: return OneWayMessage;
case 10: return UploadStream;
case 11: return DigestChunkFetchSuccess;
case 12: return DigestStreamResponse;
case -1: throw new IllegalArgumentException("User type messages cannot be decoded.");
default: throw new IllegalArgumentException("Unknown message type: " + id);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {

private Message decode(Message.Type msgType, ByteBuf in) {
switch (msgType) {
case DigestChunkFetchSuccess:
return DigestChunkFetchSuccess.decode(in);

case DigestStreamResponse:
return DigestStreamResponse.decode(in);

case ChunkFetchRequest:
return ChunkFetchRequest.decode(in);

Expand Down
Loading