diff --git a/src/main/java/com/uber/rss/clients/DataBlockSyncWriteClient.java b/src/main/java/com/uber/rss/clients/DataBlockSyncWriteClient.java index 7ce44d3..627082b 100644 --- a/src/main/java/com/uber/rss/clients/DataBlockSyncWriteClient.java +++ b/src/main/java/com/uber/rss/clients/DataBlockSyncWriteClient.java @@ -118,6 +118,25 @@ public void startUpload(ShuffleMapTaskAttemptId shuffleMapTaskAttemptId, int num writeControlMessageNotWaitResponseStatus(startUploadMessage); } + // TODO do not need mapId/taskAttemptId for StartUploadMessage + public void startUpload(ShuffleMapTaskAttemptId shuffleMapTaskAttemptId, int stageId, int numMaps, int numPartitions, ShuffleWriteConfig shuffleWriteConfig) { + logger.debug(String.format("Starting upload %s, %s", shuffleMapTaskAttemptId, connectionInfo)); + + startUploadShuffleByteSnapshot = totalWriteBytes; + + StartUploadMessage startUploadMessage = new StartUploadMessage( + shuffleMapTaskAttemptId.getShuffleId(), + shuffleMapTaskAttemptId.getMapId(), + shuffleMapTaskAttemptId.getTaskAttemptId(), + numMaps, + numPartitions, + "", + shuffleWriteConfig.getNumSplits(), + stageId); + + writeControlMessageNotWaitResponseStatus(startUploadMessage); + } + public void writeData(int partitionId, long taskAttemptId, ByteBuf data) { final int headerByteCount = Integer.BYTES + Long.BYTES + Integer.BYTES; final int dataByteCount = data.readableBytes(); diff --git a/src/main/java/com/uber/rss/clients/NotifyClient.java b/src/main/java/com/uber/rss/clients/NotifyClient.java index e14dfba..ab8cdd5 100644 --- a/src/main/java/com/uber/rss/clients/NotifyClient.java +++ b/src/main/java/com/uber/rss/clients/NotifyClient.java @@ -17,6 +17,7 @@ import com.uber.rss.exceptions.RssInvalidStateException; import com.uber.rss.messages.FinishApplicationAttemptRequestMessage; import com.uber.rss.messages.FinishApplicationJobRequestMessage; +import com.uber.rss.messages.FinishApplicationStageRequestMessage; import com.uber.rss.messages.MessageConstants; import com.uber.rss.messages.ConnectNotifyRequest; import com.uber.rss.messages.ConnectNotifyResponse; @@ -80,6 +81,12 @@ public void finishApplicationAttempt(String appId, String appAttempt) { writeControlMessageAndWaitResponseStatus(request); } + public void finishApplicationStage(String appId, String appAttempt, int stageId) { + FinishApplicationStageRequestMessage request = new FinishApplicationStageRequestMessage(appId, appAttempt, stageId); + + writeControlMessageAndWaitResponseStatus(request); + } + @Override public void close() { super.close(); diff --git a/src/main/java/com/uber/rss/clients/ShuffleDataSyncWriteClientBase.java b/src/main/java/com/uber/rss/clients/ShuffleDataSyncWriteClientBase.java index f384b83..44e2731 100644 --- a/src/main/java/com/uber/rss/clients/ShuffleDataSyncWriteClientBase.java +++ b/src/main/java/com/uber/rss/clients/ShuffleDataSyncWriteClientBase.java @@ -78,7 +78,7 @@ public ConnectUploadResponse connect() { public void startUpload(AppTaskAttemptId appTaskAttemptId, int numMaps, int numPartitions) { shuffleMapTaskAttemptId = appTaskAttemptId.getShuffleMapTaskAttemptId(); - dataBlockSyncWriteClient.startUpload(shuffleMapTaskAttemptId, numMaps, numPartitions, shuffleWriteConfig); + dataBlockSyncWriteClient.startUpload(shuffleMapTaskAttemptId, appTaskAttemptId.getStageId(), numMaps, numPartitions, shuffleWriteConfig); } @Override diff --git a/src/main/java/com/uber/rss/common/AppTaskAttemptId.java b/src/main/java/com/uber/rss/common/AppTaskAttemptId.java index 178bd33..7c6bce3 100644 --- a/src/main/java/com/uber/rss/common/AppTaskAttemptId.java +++ b/src/main/java/com/uber/rss/common/AppTaskAttemptId.java @@ -26,6 +26,9 @@ public class AppTaskAttemptId { private final int mapId; private final long taskAttemptId; + // if not associated with startUpload pipeline, value will be -1 + private final int stageId; + public AppTaskAttemptId(AppShuffleId appShuffleId, int mapId, long taskAttemptId) { this(appShuffleId.getAppId(), appShuffleId.getAppAttempt(), appShuffleId.getShuffleId(), mapId, taskAttemptId); } @@ -35,11 +38,16 @@ public AppTaskAttemptId(AppMapId appMapId, long taskAttemptId) { } public AppTaskAttemptId(String appId, String appAttempt, int shuffleId, int mapId, long taskAttemptId) { + this(appId, appAttempt, shuffleId, mapId, taskAttemptId, -1); + } + + public AppTaskAttemptId(String appId, String appAttempt, int shuffleId, int mapId, long taskAttemptId, int stageId) { this.appId = appId; this.appAttempt = appAttempt; this.shuffleId = shuffleId; this.mapId = mapId; this.taskAttemptId = taskAttemptId; + this.stageId = stageId; } public String getAppId() { @@ -74,6 +82,10 @@ public ShuffleMapTaskAttemptId getShuffleMapTaskAttemptId() { return new ShuffleMapTaskAttemptId(shuffleId, mapId, taskAttemptId); } + public int getStageId() { + return stageId; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -83,13 +95,14 @@ public boolean equals(Object o) { mapId == that.mapId && taskAttemptId == that.taskAttemptId && Objects.equals(appId, that.appId) && - Objects.equals(appAttempt, that.appAttempt); + Objects.equals(appAttempt, that.appAttempt) && + stageId == that.stageId; } @Override public int hashCode() { - return Objects.hash(appId, appAttempt, shuffleId, mapId, taskAttemptId); + return Objects.hash(appId, appAttempt, shuffleId, mapId, taskAttemptId, stageId); } @Override @@ -100,6 +113,7 @@ public String toString() { ", shuffleId=" + shuffleId + ", mapId=" + mapId + ", taskAttemptId=" + taskAttemptId + + ", stageId=" + stageId + '}'; } } diff --git a/src/main/java/com/uber/rss/decoders/StreamServerMessageDecoder.java b/src/main/java/com/uber/rss/decoders/StreamServerMessageDecoder.java index 1eee236..937f2cb 100644 --- a/src/main/java/com/uber/rss/decoders/StreamServerMessageDecoder.java +++ b/src/main/java/com/uber/rss/decoders/StreamServerMessageDecoder.java @@ -27,6 +27,7 @@ import com.uber.rss.messages.ConnectUploadResponse; import com.uber.rss.messages.FinishApplicationAttemptRequestMessage; import com.uber.rss.messages.FinishApplicationJobRequestMessage; +import com.uber.rss.messages.FinishApplicationStageRequestMessage; import com.uber.rss.messages.FinishUploadMessage; import com.uber.rss.messages.GetBusyStatusRequest; import com.uber.rss.messages.GetBusyStatusResponse; @@ -359,6 +360,8 @@ private Object getControlMessage(ChannelHandlerContext ctx, return FinishApplicationJobRequestMessage.deserialize(in); case MessageConstants.MESSAGE_FinishApplicationAttemptRequest: return FinishApplicationAttemptRequestMessage.deserialize(in); + case MessageConstants.MESSAGE_FinishApplicationStageRequest: + return FinishApplicationStageRequestMessage.deserialize(in); case MessageConstants.MESSAGE_ConnectRegistryRequest: return ConnectRegistryRequest.deserialize(in); case MessageConstants.MESSAGE_ConnectRegistryResponse: diff --git a/src/main/java/com/uber/rss/decoders/StreamServerVersionDecoder.java b/src/main/java/com/uber/rss/decoders/StreamServerVersionDecoder.java index 61c9ca5..31bef2d 100644 --- a/src/main/java/com/uber/rss/decoders/StreamServerVersionDecoder.java +++ b/src/main/java/com/uber/rss/decoders/StreamServerVersionDecoder.java @@ -84,7 +84,7 @@ private void addVersionDecoder(ChannelHandlerContext ctx, byte type, byte versio } else if (type == MessageConstants.NOTIFY_UPLINK_MAGIC_BYTE && version == MessageConstants.NOTIFY_UPLINK_VERSION_3) { newDecoder = new StreamServerMessageDecoder(null); - NotifyChannelInboundHandler channelInboundHandler = new NotifyChannelInboundHandler(serverId); + NotifyChannelInboundHandler channelInboundHandler = new NotifyChannelInboundHandler(serverId, executor); channelInboundHandler.processChannelActive(ctx); newHandler = channelInboundHandler; } else if (type == MessageConstants.REGISTRY_UPLINK_MAGIC_BYTE && diff --git a/src/main/java/com/uber/rss/exceptions/RssInvalidStageException.java b/src/main/java/com/uber/rss/exceptions/RssInvalidStageException.java new file mode 100644 index 0000000..a9f72a3 --- /dev/null +++ b/src/main/java/com/uber/rss/exceptions/RssInvalidStageException.java @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2020 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.uber.rss.exceptions; + +/*** + * This exception is thrown where there is an error with a shuffle stage in the shuffle server. + * (e.g. the stageId -> shuffleId lookup is passed a stage that doesn't have a mapping) + */ +public class RssInvalidStageException extends RssException { + public RssInvalidStageException(String message) { + super(message); + } +} diff --git a/src/main/java/com/uber/rss/execution/ShuffleExecutor.java b/src/main/java/com/uber/rss/execution/ShuffleExecutor.java index 225ef62..aafbd68 100644 --- a/src/main/java/com/uber/rss/execution/ShuffleExecutor.java +++ b/src/main/java/com/uber/rss/execution/ShuffleExecutor.java @@ -18,6 +18,7 @@ import com.uber.m3.tally.Gauge; import com.uber.rss.clients.ShuffleWriteConfig; import com.uber.rss.common.*; +import com.uber.rss.exceptions.RssInvalidStageException; import com.uber.rss.exceptions.RssShuffleCorruptedException; import com.uber.rss.exceptions.RssShuffleStageNotStartedException; import com.uber.rss.exceptions.RssTooMuchDataException; @@ -92,6 +93,11 @@ public class ShuffleExecutor { private final ConcurrentHashMap stageStates = new ConcurrentHashMap<>(); + // TODO should rename so its clear its only for writing stages? + // This field stores the shuffleId for any associated WRITING stage + private final ConcurrentHashMap stageIdToShuffleIdMap + = new ConcurrentHashMap<>(); + private final StateStore stateStore; private final ShuffleStorage storage; @@ -449,6 +455,20 @@ public void checkAppMaxWriteBytes(String appId) { checkAppMaxWriteBytes(appId, appWriteBytes); } + public void registerShuffleId(int stageId, AppShuffleId appShuffleId) { + stageIdToShuffleIdMap.putIfAbsent(stageId, appShuffleId); + } + + public AppShuffleId getShuffleId(int stageId) { + AppShuffleId shuffleId = stageIdToShuffleIdMap.get(stageId); + if (shuffleId != null) { + return shuffleId; + } + String error = String.format("unable to get a shuffleId for stage= %s could be because this stage doesn't write", stageId); + logger.warn(error); + throw new RssInvalidStageException(error); + } + private void checkAppMaxWriteBytes(AppShuffleId appShuffleId, long currentAppWriteBytes) { if (currentAppWriteBytes > appMaxWriteBytes) { numTruncatedApplications.inc(1); @@ -485,7 +505,7 @@ private ExecutorAppState getAppState(String appId) { } } - private ExecutorShuffleStageState getStageState(AppShuffleId appShuffleId) { + public ExecutorShuffleStageState getStageState(AppShuffleId appShuffleId) { ExecutorShuffleStageState state = stageStates.get(appShuffleId); if (state != null) { return state; diff --git a/src/main/java/com/uber/rss/handlers/NotifyChannelInboundHandler.java b/src/main/java/com/uber/rss/handlers/NotifyChannelInboundHandler.java index bcada1f..af290eb 100644 --- a/src/main/java/com/uber/rss/handlers/NotifyChannelInboundHandler.java +++ b/src/main/java/com/uber/rss/handlers/NotifyChannelInboundHandler.java @@ -15,9 +15,11 @@ package com.uber.rss.handlers; import com.uber.rss.exceptions.RssInvalidDataException; +import com.uber.rss.execution.ShuffleExecutor; import com.uber.rss.messages.FinishApplicationAttemptRequestMessage; import com.uber.rss.messages.FinishApplicationJobRequestMessage; import com.uber.rss.messages.ConnectNotifyRequest; +import com.uber.rss.messages.FinishApplicationStageRequestMessage; import com.uber.rss.metrics.M3Stats; import com.uber.rss.util.NettyUtils; import io.netty.channel.ChannelHandlerContext; @@ -33,8 +35,8 @@ public class NotifyChannelInboundHandler extends ChannelInboundHandlerAdapter { private final NotifyServerHandler serverHandler; - public NotifyChannelInboundHandler(String serverId) { - serverHandler = new NotifyServerHandler(serverId); + public NotifyChannelInboundHandler(String serverId, ShuffleExecutor executor) { + serverHandler = new NotifyServerHandler(serverId, executor); } @Override @@ -61,6 +63,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { serverHandler.handleMessage(ctx, (FinishApplicationJobRequestMessage)msg); } else if (msg instanceof FinishApplicationAttemptRequestMessage) { serverHandler.handleMessage(ctx, (FinishApplicationAttemptRequestMessage)msg); + } else if (msg instanceof FinishApplicationStageRequestMessage) { + serverHandler.handleMessage(ctx, (FinishApplicationStageRequestMessage) msg); } else { throw new RssInvalidDataException(String.format("Unsupported message: %s, %s", msg, connectionInfo)); } diff --git a/src/main/java/com/uber/rss/handlers/NotifyServerHandler.java b/src/main/java/com/uber/rss/handlers/NotifyServerHandler.java index 83d4188..e9f6862 100644 --- a/src/main/java/com/uber/rss/handlers/NotifyServerHandler.java +++ b/src/main/java/com/uber/rss/handlers/NotifyServerHandler.java @@ -14,11 +14,17 @@ package com.uber.rss.handlers; +import com.uber.rss.common.AppShuffleId; +import com.uber.rss.exceptions.RssException; +import com.uber.rss.exceptions.RssShuffleStageNotStartedException; +import com.uber.rss.execution.ShuffleExecutor; import com.uber.rss.messages.FinishApplicationJobRequestMessage; import com.uber.rss.messages.FinishApplicationAttemptRequestMessage; +import com.uber.rss.messages.FinishApplicationStageRequestMessage; import com.uber.rss.messages.MessageConstants; import com.uber.rss.messages.ConnectNotifyRequest; import com.uber.rss.messages.ConnectNotifyResponse; +import com.uber.rss.messages.ShuffleStageStatus; import com.uber.rss.metrics.ApplicationJobStatusMetrics; import com.uber.rss.metrics.ApplicationMetrics; import com.uber.rss.metrics.NotifyServerMetricsContainer; @@ -37,11 +43,13 @@ public class NotifyServerHandler { private static final NotifyServerMetricsContainer metricsContainer = new NotifyServerMetricsContainer(); private final String serverId; + private final ShuffleExecutor executor; private String user; - public NotifyServerHandler(String serverId) { + public NotifyServerHandler(String serverId, ShuffleExecutor executor) { this.serverId = serverId; + this.executor = executor; } public void handleMessage(ChannelHandlerContext ctx, ConnectNotifyRequest msg) { @@ -73,6 +81,40 @@ public void handleMessage(ChannelHandlerContext ctx, FinishApplicationAttemptReq metrics.getNumApplications().inc(1); } + public void handleMessage(ChannelHandlerContext ctx, FinishApplicationStageRequestMessage msg) { + writeAndFlushByte(ctx, MessageConstants.RESPONSE_STATUS_OK); + + logger.info("finishApplicationStage, appId: {}, appAttempt: {}, stageId: {}", + msg.getAppId(), + msg.getAppAttempt(), + msg.getStageId()); + + // TODO investigate further whether stageId->shuffleId is 1-1. initial investigations suggest so but would + // be worth knowing 100% + AppShuffleId shuffleId; + try { + shuffleId = executor.getShuffleId(msg.getStageId()); + } catch (RssException e) { + logger.debug("Shuffle Stage {} does not do any writing", msg.getStageId(), e); + return; + } + + ShuffleStageStatus status = executor.getShuffleStageStatus(shuffleId); + if (status.getFileStatus() == ShuffleStageStatus.FILE_STATUS_SHUFFLE_STAGE_NOT_STARTED) { + // This case "should" never occur unless thread handling uploadMessage got stuck and this ran first + String error = String.format("Shuffle stage was not started for stage=%s shuffle=%s unable to close shuffle files", msg.getStageId(), shuffleId); + logger.error(error); + throw new RssShuffleStageNotStartedException(error); + } + + // TODO investigate whether its possible for next stage to start before this handler is done running + // in current rss implementation, this would be a problem as download requests start before the shuffle + // files had made it to a storage like s3 which is slower than local or hdfs + executor.getStageState(shuffleId).closeWriters(); + + logger.info("writing is complete for stage= {}, shuffleId= {} ", msg.getStageId(), shuffleId); + } + private void writeAndFlushByte(ChannelHandlerContext ctx, byte value) { ByteBuf buf = ctx.alloc().buffer(1); buf.writeByte(value); diff --git a/src/main/java/com/uber/rss/handlers/UploadChannelInboundHandler.java b/src/main/java/com/uber/rss/handlers/UploadChannelInboundHandler.java index 96ba1cd..ffdf8cf 100644 --- a/src/main/java/com/uber/rss/handlers/UploadChannelInboundHandler.java +++ b/src/main/java/com/uber/rss/handlers/UploadChannelInboundHandler.java @@ -18,6 +18,7 @@ import com.uber.m3.tally.Gauge; import com.uber.rss.RssBuildInfo; import com.uber.rss.clients.ShuffleWriteConfig; +import com.uber.rss.common.AppShuffleId; import com.uber.rss.common.AppTaskAttemptId; import com.uber.rss.exceptions.RssInvalidDataException; import com.uber.rss.exceptions.RssMaxConnectionsException; @@ -178,11 +179,14 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { appAttempt, startUploadMessage.getShuffleId(), startUploadMessage.getMapId(), - startUploadMessage.getAttemptId()); + startUploadMessage.getAttemptId(), + startUploadMessage.getStageId()); ShuffleWriteConfig writeConfig = new ShuffleWriteConfig(startUploadMessage.getNumSplits()); uploadServerHandler.initializeAppTaskAttempt(appTaskAttemptId, startUploadMessage.getNumPartitions(), writeConfig, ctx); + uploadServerHandler.registerStageId(startUploadMessage.getStageId(), + new AppShuffleId(appId, appAttempt, startUploadMessage.getShuffleId())); } else if (msg instanceof FinishUploadMessage) { logger.debug("FinishUploadMessage, {}, {}", msg, connectionInfo); FinishUploadMessage finishUploadMessage = (FinishUploadMessage)msg; diff --git a/src/main/java/com/uber/rss/handlers/UploadServerHandler.java b/src/main/java/com/uber/rss/handlers/UploadServerHandler.java index a7e94a5..d1fb93c 100644 --- a/src/main/java/com/uber/rss/handlers/UploadServerHandler.java +++ b/src/main/java/com/uber/rss/handlers/UploadServerHandler.java @@ -16,8 +16,10 @@ import com.uber.rss.clients.ShuffleWriteConfig; import com.uber.rss.common.AppMapId; +import com.uber.rss.common.AppShuffleId; import com.uber.rss.common.AppTaskAttemptId; import com.uber.rss.exceptions.RssInvalidDataException; +import com.uber.rss.exceptions.RssInvalidStageException; import com.uber.rss.exceptions.RssInvalidStateException; import com.uber.rss.exceptions.RssMaxConnectionsException; import com.uber.rss.execution.ShuffleDataWrapper; @@ -119,6 +121,17 @@ public void finishUpload(long taskAttemptId) { finishUploadImpl(appTaskAttemptIdToFinishUpload); } + public void registerStageId(int stageId, AppShuffleId appShuffleId) { + // stageId default is -1. this would only occur in cases if method not called for StartUploadMessage or error occurred + if (stageId != -1) { + executor.registerShuffleId(stageId, appShuffleId); + } else { + String error = String.format("registerStageId called not using StartUploadMessage or stageId never set for shuffle=%s", appShuffleId); + logger.error(error); + throw new RssInvalidStageException(error); + } + } + private void finishUploadImpl(AppTaskAttemptId appTaskAttemptIdToFinishUpload) { lazyStartUpload(appTaskAttemptIdToFinishUpload); executor.finishUpload(appTaskAttemptIdToFinishUpload.getAppShuffleId(), diff --git a/src/main/java/com/uber/rss/messages/FinishApplicationStageRequestMessage.java b/src/main/java/com/uber/rss/messages/FinishApplicationStageRequestMessage.java new file mode 100644 index 0000000..bc0a11d --- /dev/null +++ b/src/main/java/com/uber/rss/messages/FinishApplicationStageRequestMessage.java @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2020 Uber Technologies, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * http://www.apache.org/licenses/LICENSE-2.0 + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.uber.rss.messages; + +import com.uber.rss.util.ByteBufUtils; +import io.netty.buffer.ByteBuf; + +/** + * Notifies RSS when a Spark Stage is completed. + */ +public class FinishApplicationStageRequestMessage extends ControlMessage { + private final String appId; + private final String appAttempt; + private final int stageId; + + public FinishApplicationStageRequestMessage(String appId, String appAttempt, int stageId) { + this.appId = appId; + this.appAttempt = appAttempt; + this.stageId = stageId; + } + + @Override + public int getMessageType() { + return MessageConstants.MESSAGE_FinishApplicationStageRequest; + } + + @Override + public void serialize(ByteBuf buf) { + ByteBufUtils.writeLengthAndString(buf, appId); + ByteBufUtils.writeLengthAndString(buf, appAttempt); + buf.writeInt(stageId); + } + + public static FinishApplicationStageRequestMessage deserialize(ByteBuf buf) { + String appId = ByteBufUtils.readLengthAndString(buf); + String appAttempt = ByteBufUtils.readLengthAndString(buf); + int stageId = buf.readInt(); + return new FinishApplicationStageRequestMessage(appId, appAttempt, stageId); + } + + public String getAppId() { + return appId; + } + + public String getAppAttempt() { + return appAttempt; + } + + public int getStageId() { + return stageId; + } + + @Override + public String toString() { + return "FinishApplicationStageRequestMessage{" + + "appId='" + appId + '\'' + + ", appAttempt='" + appAttempt + '\'' + + ", stageId=" + stageId + + '}'; + } +} diff --git a/src/main/java/com/uber/rss/messages/MessageConstants.java b/src/main/java/com/uber/rss/messages/MessageConstants.java index 5b78a76..be6e34c 100644 --- a/src/main/java/com/uber/rss/messages/MessageConstants.java +++ b/src/main/java/com/uber/rss/messages/MessageConstants.java @@ -38,6 +38,7 @@ public class MessageConstants { public final static int MESSAGE_RegisterServerRequest = -9; public final static int MESSAGE_GetServersRequest = -10; public final static int MESSAGE_FinishApplicationJobRequest = -12; + public final static int MESSAGE_FinishApplicationStageRequest = -13; public final static int MESSAGE_GetServersResponse = -16; public final static int MESSAGE_RegisterServerResponse = -19; diff --git a/src/main/java/com/uber/rss/messages/StartUploadMessage.java b/src/main/java/com/uber/rss/messages/StartUploadMessage.java index 4e3791f..c1d417e 100644 --- a/src/main/java/com/uber/rss/messages/StartUploadMessage.java +++ b/src/main/java/com/uber/rss/messages/StartUploadMessage.java @@ -19,15 +19,20 @@ public class StartUploadMessage extends BaseMessage { - private int shuffleId; - private int mapId; - private long attemptId; - private int numMaps; - private int numPartitions; - private String fileCompressionCodec; - private short numSplits; + private final int shuffleId; + private final int mapId; + private final long attemptId; + private final int numMaps; + private final int numPartitions; + private final String fileCompressionCodec; + private final short numSplits; + private final int stageId; public StartUploadMessage(int shuffleId, int mapId, long attemptId, int numMaps, int numPartitions, String fileCompressionCodec, short numSplits) { + this(shuffleId, mapId, attemptId, numMaps, numPartitions, fileCompressionCodec, numSplits, -1); + } + + public StartUploadMessage(int shuffleId, int mapId, long attemptId, int numMaps, int numPartitions, String fileCompressionCodec, short numSplits, int stageId) { this.shuffleId = shuffleId; this.mapId = mapId; this.attemptId = attemptId; @@ -35,6 +40,7 @@ public StartUploadMessage(int shuffleId, int mapId, long attemptId, int numMaps, this.numPartitions = numPartitions; this.fileCompressionCodec = fileCompressionCodec; this.numSplits = numSplits; + this.stageId = stageId; } @Override @@ -51,6 +57,7 @@ public void serialize(ByteBuf buf) { buf.writeInt(numPartitions); ByteBufUtils.writeLengthAndString(buf, fileCompressionCodec); buf.writeShort(numSplits); + buf.writeInt(stageId); } public static StartUploadMessage deserialize(ByteBuf buf) { @@ -61,7 +68,8 @@ public static StartUploadMessage deserialize(ByteBuf buf) { int numPartitions = buf.readInt(); String fileCompressionCodec = ByteBufUtils.readLengthAndString(buf); short numSplits = buf.readShort(); - return new StartUploadMessage(shuffleId, mapId, attemptId, numMaps, numPartitions, fileCompressionCodec, numSplits); + int stageId = buf.readInt(); + return new StartUploadMessage(shuffleId, mapId, attemptId, numMaps, numPartitions, fileCompressionCodec, numSplits, stageId); } public int getShuffleId() { @@ -92,6 +100,10 @@ public short getNumSplits() { return numSplits; } + public int getStageId() { + return stageId; + } + @Override public String toString() { return "StartUploadMessage{" + @@ -102,6 +114,7 @@ public String toString() { ", numPartitions=" + numPartitions + ", fileCompressionCodec='" + fileCompressionCodec + '\'' + ", numSplits=" + numSplits + + ", stageId= " + stageId + '}'; } } diff --git a/src/main/scala/org/apache/spark/shuffle/RssOpts.scala b/src/main/scala/org/apache/spark/shuffle/RssOpts.scala index 34325bb..3168fa2 100644 --- a/src/main/scala/org/apache/spark/shuffle/RssOpts.scala +++ b/src/main/scala/org/apache/spark/shuffle/RssOpts.scala @@ -177,4 +177,9 @@ object RssOpts { .doc("Create lazy connections from mappers to RSS servers just before sending the shuffle data") .booleanConf .createWithDefault(true) + val enableListenForOnStageCompleted: ConfigEntry[Boolean] = + ConfigBuilder("spark.shuffle.rss.enableListenForOnStageCompleted") + .doc("Tell RSS to process onStageCompleted events from SparkListener") + .booleanConf + .createWithDefault(false) } diff --git a/src/main/scala/org/apache/spark/shuffle/RssShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/RssShuffleManager.scala index 4b6e825..dd769bf 100644 --- a/src/main/scala/org/apache/spark/shuffle/RssShuffleManager.scala +++ b/src/main/scala/org/apache/spark/shuffle/RssShuffleManager.scala @@ -118,6 +118,7 @@ class RssShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { RssSparkListener.registerSparkListenerOnlyOnce(sparkContext, () => new RssSparkListener( + conf, user, conf.getAppId, appAttempt, @@ -158,7 +159,8 @@ class RssShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { rssShuffleHandle.appAttempt, handle.shuffleId, mapId, - context.taskAttemptId() + context.taskAttemptId(), + context.stageId() ) logDebug( s"getWriter $mapInfo" ) diff --git a/src/main/scala/org/apache/spark/shuffle/rss/RssSparkListener.scala b/src/main/scala/org/apache/spark/shuffle/rss/RssSparkListener.scala index 47b129a..e98d3a5 100644 --- a/src/main/scala/org/apache/spark/shuffle/rss/RssSparkListener.scala +++ b/src/main/scala/org/apache/spark/shuffle/rss/RssSparkListener.scala @@ -21,9 +21,10 @@ import com.uber.rss.clients.{MultiServerHeartbeatClient, NotifyClient} import com.uber.rss.metrics.M3Stats import com.uber.rss.util.ServerHostAndPort import org.apache.commons.lang3.exception.ExceptionUtils -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.scheduler.{JobFailed, _} +import org.apache.spark.shuffle.RssOpts object RssSparkListener extends Logging { @@ -64,7 +65,7 @@ object RssSparkListener extends Logging { * @param attemptId * @param notifyServers */ -class RssSparkListener(val user: String, val appId: String, val attemptId: String, val notifyServers: Array[String], val networkTimeoutMillis: Int) +class RssSparkListener(val conf: SparkConf, val user: String, val appId: String, val attemptId: String, val notifyServers: Array[String], val networkTimeoutMillis: Int) extends SparkListener with Logging { private val m3Tags: util.Map[String, String] = new util.HashMap[String, String] @@ -134,6 +135,41 @@ class RssSparkListener(val user: String, val appId: String, val attemptId: Strin }) } + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + if (notifyServers == null || notifyServers.length == 0) { + return + } + + val processOnStageCompleted = conf.get(RssOpts.enableListenForOnStageCompleted) + + if (processOnStageCompleted) { + // ensures each server writes its partitions to s3 when stage completes vs picking a random one + invokeAllServers(client => { + client.finishApplicationStage(appId, attemptId, stageCompleted.stageInfo.stageId) + }) + } + } + + private def invokeAllServers(run: NotifyClient=>Unit): Unit = { + var client: NotifyClient = null + for (notifyServer <- notifyServers) { + try { + logInfo(s"Invoking control server $notifyServer") + val server = ServerHostAndPort.fromString(notifyServer) + client = new NotifyClient(server.getHost, server.getPort, networkTimeoutMillis, user) + client.connect() + run(client) + } catch { + case e: Throwable => { + logWarning("Failed to invoke control server", e) + M3Stats.addException(e, this.getClass().getSimpleName()) + } + } finally { + client.close() + } + } + } + private def invokeRandomNotifyServer(run: NotifyClient=>Unit) = { var client: NotifyClient = null try { diff --git a/src/test/java/com/uber/rss/clients/NotifyClientTest.java b/src/test/java/com/uber/rss/clients/NotifyClientTest.java index 2b35307..30cbdb4 100644 --- a/src/test/java/com/uber/rss/clients/NotifyClientTest.java +++ b/src/test/java/com/uber/rss/clients/NotifyClientTest.java @@ -56,4 +56,21 @@ public void finishApplicationAttempt() { } } + @Test + public void finishApplicationStage() { + TestStreamServer testServer = TestStreamServer.createRunningServer(); + + try (NotifyClient client = new NotifyClient("localhost", testServer.getShufflePort(), TestConstants.NETWORK_TIMEOUT, "user1")) { + client.connect(); + // send same request twice to make sure it is still good + client.finishApplicationStage("app1", "exec1", 1); + client.finishApplicationStage("app1", "exec1", 1); + // send different request to make sure it is still good + client.finishApplicationStage("app1", "exec2", 2); + client.finishApplicationStage("app2", "exec2", 2); + } finally { + testServer.shutdown(); + } + } + }