@@ -344,7 +344,7 @@ public int chunkFetchHandlerThreads() {
/**
* Whether to use the old protocol while doing the shuffle block fetching.
* It is only enabled while we need the compatibility in the scenario of new spark version
-    * job fetching blocks from old version external shuffle service.
+    * job fetching shuffle blocks from old version external shuffle service.
*/
public boolean useOldFetchProtocol() {
return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false);
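For context, a minimal sketch of how this flag gates the fetch protocol, mirroring the selection logic this PR moves out of OneForOneBlockFetcher (see the constructor change near the end of this diff); the surrounding variables are assumed to be in scope:

  // Inside a fetcher that has appId, execId, blockIds and transportConf available.
  BlockTransferMessage message;
  if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
    // New protocol: one compact FetchShuffleBlocks message per shuffle.
    message = createFetchShuffleBlocksMsg(appId, execId, blockIds);
  } else {
    // Old protocol: an OpenBlocks message listing every block id, kept for
    // compatibility with old-version external shuffle services.
    message = new OpenBlocks(appId, execId, blockIds);
  }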
@@ -29,7 +29,7 @@
public abstract class BlockStoreClient implements Closeable {

/**
-   * Fetch a sequence of blocks from a remote node asynchronously,
+   * Fetch a sequence of shuffle blocks from a remote node asynchronously.
*
* Note that this API takes a sequence so the implementation can batch requests, and does not
* return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
@@ -38,23 +38,36 @@ public abstract class BlockStoreClient implements Closeable {
* @param host the host of the remote node.
* @param port the port of the remote node.
* @param execId the executor id.
+   * @param shuffleGenerationId the shuffle generation id for all block ids to fetch.
* @param blockIds block ids to fetch.
* @param listener the listener to receive block fetching status.
* @param downloadFileManager DownloadFileManager to create and clean temp files.
* If it's not <code>null</code>, the remote blocks will be streamed
* into temp shuffle files to reduce the memory usage, otherwise,
* they will be kept in memory.
*/
-  public abstract void fetchBlocks(
+  public abstract void fetchShuffleBlocks(
String host,
int port,
String execId,
+      int shuffleGenerationId,
Contributor: Does this mean the method can only fetch shuffle blocks? Shall we rename the method?

Member Author: Yep, after the refactor in e7365e3 we can finally rename the method and split shuffle fetch and data fetch into two separate functions, thanks for your guidance and advice!
Done in 67f70a4.

String[] blockIds,
BlockFetchingListener listener,
DownloadFileManager downloadFileManager);

/**
-   * Get the shuffle MetricsSet from BlockStoreClient, this will be used in MetricsSystem to
+   * Fetch a sequence of non-shuffle blocks from a remote node asynchronously.
*/
+  public abstract void fetchDataBlocks(
+      String host,
+      int port,
+      String execId,
+      String[] blockIds,
+      BlockFetchingListener listener,
+      DownloadFileManager downloadFileManager);
+
+  /**
+   * Get the shuffle MetricsSet from BlockStoreClient, this will be used in MetricsSystem to
* get the Shuffle related metrics.
*/
public MetricSet shuffleMetrics() {
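A hedged usage sketch of the split API; the client, listener, port and block ids below are illustrative only:

  // Shuffle blocks carry a shuffle generation id; -1 marks a determinate stage.
  client.fetchShuffleBlocks("host", 7337, "exec-1", 2,
    new String[] {"shuffle_0_1_2"}, listener, null);

  // Non-shuffle blocks (e.g. RDD blocks) take the path without a generation id.
  client.fetchDataBlocks("host", 7337, "exec-1",
    new String[] {"rdd_5_0"}, listener, null);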
@@ -311,7 +311,8 @@ private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) {
assert(idx == 2 * numBlockIds);
size = mapIdAndReduceIds.length;
blockDataForIndexFn = index -> blockManager.getBlockData(msg.appId, msg.execId,
-        msg.shuffleId, mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
+        msg.shuffleId, msg.shuffleGenerationId, mapIdAndReduceIds[index],
+        mapIdAndReduceIds[index + 1]);
}

@Override
@@ -90,22 +90,52 @@ public void init(String appId) {
}

@Override
-  public void fetchBlocks(
+  public void fetchDataBlocks(
String host,
int port,
String execId,
String[] blockIds,
BlockFetchingListener listener,
DownloadFileManager downloadFileManager) {
+    doFetchBlocks(host, port, execId, -1, blockIds, listener, downloadFileManager, false);
+  }
+
+  @Override
+  public void fetchShuffleBlocks(
+      String host,
+      int port,
+      String execId,
+      int shuffleGenerationId,
+      String[] blockIds,
+      BlockFetchingListener listener,
+      DownloadFileManager downloadFileManager) {
+    doFetchBlocks(host, port, execId, shuffleGenerationId, blockIds, listener,
+      downloadFileManager, true);
+  }
+
+  private void doFetchBlocks(
+      String host,
+      int port,
+      String execId,
+      int shuffleGenerationId,
+      String[] blockIds,
+      BlockFetchingListener listener,
+      DownloadFileManager downloadFileManager,
+      boolean isShuffleBlocks) {
checkInit();
logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
try {
RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
-        (blockIds1, listener1) -> {
-          TransportClient client = clientFactory.createClient(host, port);
-          new OneForOneBlockFetcher(client, appId, execId,
-            blockIds1, listener1, conf, downloadFileManager).start();
-        };
+        (blockIds1, listener1) -> {
+          TransportClient client = clientFactory.createClient(host, port);
+          if (isShuffleBlocks) {
+            new OneForOneShuffleBlockFetcher(client, appId, execId, shuffleGenerationId,
+              blockIds1, listener1, conf, downloadFileManager).start();
+          } else {
+            new OneForOneDataBlockFetcher(client, appId, execId,
+              blockIds1, listener1, conf, downloadFileManager).start();
+          }
+        };

int maxRetries = conf.maxIORetries();
if (maxRetries > 0) {
@@ -116,7 +146,7 @@ public void fetchBlocks(
blockFetchStarter.createAndStart(blockIds, listener);
}
} catch (Exception e) {
logger.error("Exception while beginning fetchBlocks", e);
logger.error("Exception while beginning fetchShuffleBlocks", e);
for (String blockId : blockIds) {
listener.onBlockFetchFailure(blockId, e);
}
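For reference, a sketch of how the starter above is driven, based on the maxRetries branch visible in this hunk; the RetryingBlockFetcher constructor arguments follow the call used in this file but should be treated as assumed:

  int maxRetries = conf.maxIORetries();
  if (maxRetries > 0) {
    // Each retry re-invokes the starter, which re-creates the appropriate fetcher.
    new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start();
  } else {
    // No retries configured: run the one-shot starter directly.
    blockFetchStarter.createAndStart(blockIds, listener);
  }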
@@ -164,6 +164,18 @@ public void registerExecutor(
executors.put(fullId, executorInfo);
}

+  /**
+   * Overload of getBlockData that passes -1 as the shuffleGenerationId, marking it absent.
+   */
+  public ManagedBuffer getBlockData(
+      String appId,
+      String execId,
+      int shuffleId,
+      int mapId,
+      int reduceId) {
+    return getBlockData(appId, execId, shuffleId, -1, mapId, reduceId);
+  }
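A quick illustration of the two entry points; the resolver variable and all arguments are made up:

  // Legacy five-argument form: the generation id defaults to the invalid value -1.
  resolver.getBlockData("app-1", "exec-2", 3, 7, 0);
  // New form: the shuffle generation id is threaded through explicitly.
  resolver.getBlockData("app-1", "exec-2", 3, 2, 7, 0);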

/**
* Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId). We make assumptions
* about how the hash and sort based shuffles store their data.
@@ -172,14 +184,16 @@ public ManagedBuffer getBlockData(
String appId,
String execId,
int shuffleId,
+      int shuffleGenerationId,
int mapId,
int reduceId) {
ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId));
if (executor == null) {
throw new RuntimeException(
String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId));
}
-    return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId);
+    return getSortBasedShuffleBlockData(
+      executor, shuffleId, shuffleGenerationId, mapId, reduceId);
}

public ManagedBuffer getRddBlockData(
@@ -291,22 +305,29 @@ private void deleteNonShuffleServiceServedFiles(String[] dirs) {
}

/**
-   * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file
-   * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockResolver,
+   * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data
+   * file called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockResolver,
    * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId.
+   * When the shuffle data and index files are generated by an indeterminate stage,
+   * the ShuffleDataBlockId and ShuffleIndexBlockId are extended with the shuffle generation id.
*/
private ManagedBuffer getSortBasedShuffleBlockData(
-      ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) {
+      ExecutorShuffleInfo executor, int shuffleId, int shuffleGenerationId,
+      int mapId, int reduceId) {
+    String baseFileName = "shuffle_" + shuffleId + "_" + mapId + "_0";
Contributor: Not related to your change, but do you know what _0 means here?

Member Author: IIUC, it's IndexShuffleBlockResolver.NOOP_REDUCE_ID, as described in its comment:

// No-op reduce ID used in interactions with disk store.
// The disk store currently expects puts to relate to a (map, reduce) pair, but in the sort
// shuffle outputs for several reduces are glommed into a single file.

After all blocks were consolidated into a single file, we stopped using the reduceId in the shuffle file name; we just use the offset read from the index file to find the block in the shuffle data file.

Contributor: It was from the hash shuffle algorithm from long ago -- that last id was the reducePartitionId. But now we always merge all reducePartitions into one file.

+    if (shuffleGenerationId != -1) {
+      baseFileName = baseFileName + "_" + shuffleGenerationId;
+    }
File indexFile = ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir,
"shuffle_" + shuffleId + "_" + mapId + "_0.index");
baseFileName + ".index");

try {
ShuffleIndexInformation shuffleIndexInformation = shuffleIndexCache.get(indexFile);
ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex(reduceId);
return new FileSegmentManagedBuffer(
conf,
-        ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir,
-          "shuffle_" + shuffleId + "_" + mapId + "_0.data"),
+        ExecutorDiskUtils.getFile(
+          executor.localDirs, executor.subDirsPerLocalDir, baseFileName + ".data"),
shuffleIndexRecord.getOffset(),
shuffleIndexRecord.getLength());
} catch (ExecutionException e) {
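To make the naming rule concrete, a small hypothetical helper distilling the logic above:

  static String shuffleBaseName(int shuffleId, int mapId, int shuffleGenerationId) {
    // "_0" is IndexShuffleBlockResolver.NOOP_REDUCE_ID (see the discussion above).
    String base = "shuffle_" + shuffleId + "_" + mapId + "_0";
    return shuffleGenerationId == -1 ? base : base + "_" + shuffleGenerationId;
  }

  // shuffleBaseName(3, 7, -1) -> "shuffle_3_7_0"    (+ ".index" / ".data")
  // shuffleBaseName(3, 7, 2)  -> "shuffle_3_7_0_2"  (generation 2 of an indeterminate stage)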
@@ -19,11 +19,8 @@

import java.io.IOException;
import java.nio.ByteBuffer;
-import java.util.ArrayList;
import java.util.Arrays;
-import java.util.HashMap;

-import com.google.common.primitives.Ints;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -34,44 +31,33 @@
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
-import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
-import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.TransportConf;

/**
* Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and
* invokes the BlockFetchingListener appropriately. This class is agnostic to the actual RPC
- * handler, as long as there is a single "open blocks" message which returns a ShuffleStreamHandle,
+ * handler, as long as there is a single "open blocks" message which returns a StreamHandle,
* and Java serialization is used.
*
* Note that this typically corresponds to a
* {@link org.apache.spark.network.server.OneForOneStreamManager} on the server side.
*/
-public class OneForOneBlockFetcher {
+public abstract class OneForOneBlockFetcher {
private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);

private final TransportClient client;
-  private final BlockTransferMessage message;
-  private final String[] blockIds;
+  protected final String appId;
+  protected final String execId;
+  protected final String[] blockIds;
private final BlockFetchingListener listener;
private final ChunkReceivedCallback chunkCallback;
-  private final TransportConf transportConf;
+  protected final TransportConf transportConf;
private final DownloadFileManager downloadFileManager;

private StreamHandle streamHandle = null;

-  public OneForOneBlockFetcher(
-      TransportClient client,
-      String appId,
-      String execId,
-      String[] blockIds,
-      BlockFetchingListener listener,
-      TransportConf transportConf) {
-    this(client, appId, execId, blockIds, listener, transportConf, null);
-  }
-
-  public OneForOneBlockFetcher(
+  protected OneForOneBlockFetcher(
TransportClient client,
String appId,
String execId,
@@ -80,6 +66,8 @@ public OneForOneBlockFetcher(
TransportConf transportConf,
DownloadFileManager downloadFileManager) {
this.client = client;
+    this.appId = appId;
+    this.execId = execId;
this.blockIds = blockIds;
this.listener = listener;
this.chunkCallback = new ChunkCallback();
@@ -88,64 +76,12 @@ public OneForOneBlockFetcher(
if (blockIds.length == 0) {
throw new IllegalArgumentException("Zero-sized blockIds array");
}
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
-      this.message = createFetchShuffleBlocksMsg(appId, execId, blockIds);
-    } else {
-      this.message = new OpenBlocks(appId, execId, blockIds);
-    }
}

-  private boolean isShuffleBlocks(String[] blockIds) {
-    for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
-        return false;
-      }
-    }
-    return true;
-  }
-
/**
-   * Analyze the pass in blockIds and create FetchShuffleBlocks message.
-   * The blockIds has been sorted by mapId and reduceId. It's produced in
-   * org.apache.spark.MapOutputTracker.convertMapStatuses.
+   * Create the corresponding message for this fetcher.
*/
-  private FetchShuffleBlocks createFetchShuffleBlocksMsg(
-      String appId, String execId, String[] blockIds) {
-    int shuffleId = splitBlockId(blockIds[0])[0];
-    HashMap<Integer, ArrayList<Integer>> mapIdToReduceIds = new HashMap<>();
-    for (String blockId : blockIds) {
-      int[] blockIdParts = splitBlockId(blockId);
-      if (blockIdParts[0] != shuffleId) {
-        throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
-          ", got:" + blockId);
-      }
-      int mapId = blockIdParts[1];
-      if (!mapIdToReduceIds.containsKey(mapId)) {
-        mapIdToReduceIds.put(mapId, new ArrayList<>());
-      }
-      mapIdToReduceIds.get(mapId).add(blockIdParts[2]);
-    }
-    int[] mapIds = Ints.toArray(mapIdToReduceIds.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
-    for (int i = 0; i < mapIds.length; i++) {
-      reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i]));
-    }
-    return new FetchShuffleBlocks(appId, execId, shuffleId, mapIds, reduceIdArr);
-  }
-
-  /** Split the shuffleBlockId and return shuffleId, mapId and reduceId. */
-  private int[] splitBlockId(String blockId) {
-    String[] blockIdParts = blockId.split("_");
-    if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
-      throw new IllegalArgumentException(
-        "Unexpected shuffle block id format: " + blockId);
-    }
-    return new int[] {
-      Integer.parseInt(blockIdParts[1]),
-      Integer.parseInt(blockIdParts[2]),
-      Integer.parseInt(blockIdParts[3])
-    };
-  }
+  public abstract BlockTransferMessage createBlockTransferMessage();
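A hedged sketch of how a concrete subclass might implement this hook, reusing the grouping logic removed above; OneForOneShuffleBlockFetcher is named in this PR but its body is not shown here, so treat the details as assumptions:

  // Hypothetical OneForOneShuffleBlockFetcher:
  @Override
  public BlockTransferMessage createBlockTransferMessage() {
    if (transportConf.useOldFetchProtocol()) {
      // Old-version shuffle services only understand OpenBlocks.
      return new OpenBlocks(appId, execId, blockIds);
    }
    // Group the sorted blockIds into a single FetchShuffleBlocks message,
    // as the removed createFetchShuffleBlocksMsg did, now carrying the generation id.
    return createFetchShuffleBlocksMsg(appId, execId, blockIds);
  }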

/** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
private class ChunkCallback implements ChunkReceivedCallback {
@@ -169,7 +105,7 @@ public void onFailure(int chunkIndex, Throwable e) {
* {@link StreamHandle}. We will send all fetch requests immediately, without throttling.
*/
public void start() {
-    client.sendRpc(message.toByteBuffer(), new RpcResponseCallback() {
+    client.sendRpc(createBlockTransferMessage().toByteBuffer(), new RpcResponseCallback() {
@Override
public void onSuccess(ByteBuffer response) {
try {