From 18b1a12ab26662685a7a2ddae230f0cbd459381f Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 19 Nov 2018 11:06:12 +0800 Subject: [PATCH 1/9] Support rolling back a shuffle map stage and re-generate the shuffle files fix bug fix Address comment from Wenchen fix Address Wenchen comment Address comments from Wenchen 2 Address all shuffleGenerationId parameter reorder Address comments Address comments Address comments about tests --- .../spark/network/util/TransportConf.java | 2 +- .../network/shuffle/BlockStoreClient.java | 2 + .../network/shuffle/ExternalBlockHandler.java | 3 +- .../shuffle/ExternalBlockStoreClient.java | 3 +- .../shuffle/ExternalShuffleBlockResolver.java | 35 ++++- .../shuffle/OneForOneBlockFetcher.java | 10 +- .../shuffle/protocol/FetchShuffleBlocks.java | 12 +- .../shuffle/BlockTransferMessagesSuite.java | 2 +- ...anupNonShuffleServiceServedFilesSuite.java | 2 +- .../shuffle/ExternalBlockHandlerSuite.java | 10 +- .../ExternalShuffleBlockResolverSuite.java | 26 +++- .../shuffle/ExternalShuffleCleanupSuite.java | 2 +- .../ExternalShuffleIntegrationSuite.java | 40 ++++- .../shuffle/OneForOneBlockFetcherSuite.java | 5 +- .../shuffle/TestShuffleDataContext.java | 4 +- .../sort/BypassMergeSortShuffleWriter.java | 7 +- .../shuffle/sort/UnsafeShuffleWriter.java | 8 +- .../scala/org/apache/spark/SparkContext.scala | 1 + .../scala/org/apache/spark/TaskContext.scala | 9 ++ .../spark/internal/config/package.scala | 8 + .../spark/network/BlockDataManager.scala | 10 ++ .../spark/network/BlockTransferService.scala | 22 ++- .../network/netty/NettyBlockRpcServer.scala | 5 +- .../netty/NettyBlockTransferService.scala | 5 +- .../apache/spark/scheduler/DAGScheduler.scala | 55 +++++-- .../spark/scheduler/ShuffleMapStage.scala | 5 +- .../org/apache/spark/scheduler/Stage.scala | 6 +- .../shuffle/BlockStoreShuffleReader.scala | 3 +- .../shuffle/IndexShuffleBlockResolver.scala | 46 ++++-- .../spark/shuffle/ShuffleBlockResolver.scala | 2 +- 
.../shuffle/sort/SortShuffleManager.scala | 24 +-- .../shuffle/sort/SortShuffleWriter.scala | 12 +- .../org/apache/spark/storage/BlockId.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 19 ++- .../storage/ShuffleBlockFetcherIterator.scala | 17 ++- .../sort/UnsafeShuffleWriterSuite.java | 10 +- .../spark/ExternalShuffleServiceSuite.scala | 17 ++- .../ExternalShuffleServiceDbSuite.scala | 2 +- .../network/BlockTransferServiceSuite.scala | 3 +- .../NettyBlockTransferSecuritySuite.scala | 5 +- .../NettyBlockTransferServiceSuite.scala | 2 +- .../spark/scheduler/DAGSchedulerSuite.scala | 139 ++++++++++++++++-- .../org/apache/spark/scheduler/FakeTask.scala | 4 +- .../scheduler/TaskSchedulerImplSuite.scala | 7 +- .../BlockStoreShuffleReaderSuite.scala | 4 +- .../BypassMergeSortShuffleWriterSuite.scala | 15 +- .../sort/IndexShuffleBlockResolverSuite.scala | 86 +++++++++-- .../shuffle/sort/SortShuffleWriterSuite.scala | 4 +- .../spark/storage/BlockManagerSuite.scala | 1 + .../ShuffleBlockFetcherIteratorSuite.scala | 45 +++--- 50 files changed, 611 insertions(+), 157 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 589dfcbefb6ea..6e2b724fa591c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -344,7 +344,7 @@ public int chunkFetchHandlerThreads() { /** * Whether to use the old protocol while doing the shuffle block fetching. * It is only enabled while we need the compatibility in the scenario of new spark version - * job fetching blocks from old version external shuffle service. + * job fetching shuffle blocks from old version external shuffle service. 
*/ public boolean useOldFetchProtocol() { return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java index fbbe8ac0f1f9b..008bb5795e44b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java @@ -38,6 +38,7 @@ public abstract class BlockStoreClient implements Closeable { * @param host the host of the remote node. * @param port the port of the remote node. * @param execId the executor id. + * @param shuffleGenerationId the shuffle generation id for all block ids to fetch. * @param blockIds block ids to fetch. * @param listener the listener to receive block fetching status. * @param downloadFileManager DownloadFileManager to create and clean temp files. 
@@ -49,6 +50,7 @@ public abstract void fetchBlocks( String host, int port, String execId, + int shuffleGenerationId, String[] blockIds, BlockFetchingListener listener, DownloadFileManager downloadFileManager); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java index 037e5cf7e5222..2ecd186a42ac6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java @@ -311,7 +311,8 @@ private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) { assert(idx == 2 * numBlockIds); size = mapIdAndReduceIds.length; blockDataForIndexFn = index -> blockManager.getBlockData(msg.appId, msg.execId, - msg.shuffleId, mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]); + msg.shuffleId, msg.shuffleGenerationId, mapIdAndReduceIds[index], + mapIdAndReduceIds[index + 1]); } @Override diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java index b8e52c8621fb6..cc86cf6e64700 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -94,6 +94,7 @@ public void fetchBlocks( String host, int port, String execId, + int shuffleGenerationId, String[] blockIds, BlockFetchingListener listener, DownloadFileManager downloadFileManager) { @@ -103,7 +104,7 @@ public void fetchBlocks( RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = (blockIds1, listener1) -> { TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, 
appId, execId, + new OneForOneBlockFetcher(client, appId, execId, shuffleGenerationId, blockIds1, listener1, conf, downloadFileManager).start(); }; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 50f16fc700f12..8ed1887ed9fbc 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -164,6 +164,18 @@ public void registerExecutor( executors.put(fullId, executorInfo); } + /** + * Overload of getBlockData that passes -1 as shuffleGenerationId (no generation id is used). + */ + public ManagedBuffer getBlockData( + String appId, + String execId, + int shuffleId, + int mapId, + int reduceId) { + return getBlockData(appId, execId, shuffleId, -1, mapId, reduceId); + } + /** * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId). We make assumptions * about how the hash and sort based shuffles store their data.
@@ -172,6 +184,7 @@ public ManagedBuffer getBlockData( String appId, String execId, int shuffleId, + int shuffleGenerationId, int mapId, int reduceId) { ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); @@ -179,7 +192,8 @@ public ManagedBuffer getBlockData( throw new RuntimeException( String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); } - return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); + return getSortBasedShuffleBlockData( + executor, shuffleId, shuffleGenerationId, mapId, reduceId); } public ManagedBuffer getRddBlockData( @@ -291,22 +305,29 @@ private void deleteNonShuffleServiceServedFiles(String[] dirs) { } /** - * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file - * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockResolver, + * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data + * file called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockResolver, * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId. + * When the shuffle data and index files are generated by an indeterminate stage, + * the ShuffleDataBlockId and ShuffleIndexBlockId are extended with the shuffle generation id.
*/ private ManagedBuffer getSortBasedShuffleBlockData( - ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) { + ExecutorShuffleInfo executor, int shuffleId, int shuffleGenerationId, + int mapId, int reduceId) { + String baseFileName = "shuffle_" + shuffleId + "_" + mapId + "_0"; + if (shuffleGenerationId != -1) { + baseFileName = baseFileName + "_" + shuffleGenerationId; + } File indexFile = ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, - "shuffle_" + shuffleId + "_" + mapId + "_0.index"); + baseFileName + ".index"); try { ShuffleIndexInformation shuffleIndexInformation = shuffleIndexCache.get(indexFile); ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex(reduceId); return new FileSegmentManagedBuffer( conf, - ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, - "shuffle_" + shuffleId + "_" + mapId + "_0.data"), + ExecutorDiskUtils.getFile( + executor.localDirs, executor.subDirsPerLocalDir, baseFileName + ".data"), shuffleIndexRecord.getOffset(), shuffleIndexRecord.getLength()); } catch (ExecutionException e) { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index cc11e92067375..53ee4351203c5 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -68,13 +68,14 @@ public OneForOneBlockFetcher( String[] blockIds, BlockFetchingListener listener, TransportConf transportConf) { - this(client, appId, execId, blockIds, listener, transportConf, null); + this(client, appId, execId, -1, blockIds, listener, transportConf, null); } public OneForOneBlockFetcher( TransportClient client, String appId, String execId, + int shuffleGenerationId, String[] blockIds, 
BlockFetchingListener listener, TransportConf transportConf, @@ -89,7 +90,7 @@ public OneForOneBlockFetcher( throw new IllegalArgumentException("Zero-sized blockIds array"); } if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) { - this.message = createFetchShuffleBlocksMsg(appId, execId, blockIds); + this.message = createFetchShuffleBlocksMsg(appId, execId, shuffleGenerationId, blockIds); } else { this.message = new OpenBlocks(appId, execId, blockIds); } @@ -110,7 +111,7 @@ private boolean isShuffleBlocks(String[] blockIds) { * org.apache.spark.MapOutputTracker.convertMapStatuses. */ private FetchShuffleBlocks createFetchShuffleBlocksMsg( - String appId, String execId, String[] blockIds) { + String appId, String execId, int shuffleGenerationId, String[] blockIds) { int shuffleId = splitBlockId(blockIds[0])[0]; HashMap> mapIdToReduceIds = new HashMap<>(); for (String blockId : blockIds) { @@ -130,7 +131,8 @@ private FetchShuffleBlocks createFetchShuffleBlocksMsg( for (int i = 0; i < mapIds.length; i++) { reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i])); } - return new FetchShuffleBlocks(appId, execId, shuffleId, mapIds, reduceIdArr); + return new FetchShuffleBlocks( + appId, execId, shuffleId, shuffleGenerationId, mapIds, reduceIdArr); } /** Split the shuffleBlockId and return shuffleId, mapId and reduceId. 
*/ diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java index 466eeb3e048a8..b9e1228a5c152 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java @@ -32,6 +32,7 @@ public class FetchShuffleBlocks extends BlockTransferMessage { public final String appId; public final String execId; public final int shuffleId; + public final int shuffleGenerationId; // The length of mapIds must equal to reduceIds.size(), for the i-th mapId in mapIds, // it corresponds to the i-th int[] in reduceIds, which contains all reduce id for this map id. public final int[] mapIds; @@ -41,11 +42,13 @@ public FetchShuffleBlocks( String appId, String execId, int shuffleId, + int shuffleGenerationId, int[] mapIds, int[][] reduceIds) { this.appId = appId; this.execId = execId; this.shuffleId = shuffleId; + this.shuffleGenerationId = shuffleGenerationId; this.mapIds = mapIds; this.reduceIds = reduceIds; assert(mapIds.length == reduceIds.length); @@ -60,6 +63,7 @@ public String toString() { .add("appId", appId) .add("execId", execId) .add("shuffleId", shuffleId) + .add("shuffleGenerationId", shuffleGenerationId) .add("mapIds", Arrays.toString(mapIds)) .add("reduceIds", Arrays.deepToString(reduceIds)) .toString(); @@ -73,6 +77,7 @@ public boolean equals(Object o) { FetchShuffleBlocks that = (FetchShuffleBlocks) o; if (shuffleId != that.shuffleId) return false; + if (shuffleGenerationId != that.shuffleGenerationId) return false; if (!appId.equals(that.appId)) return false; if (!execId.equals(that.execId)) return false; if (!Arrays.equals(mapIds, that.mapIds)) return false; @@ -84,6 +89,7 @@ public int hashCode() { int result = appId.hashCode(); result = 31 * result + 
execId.hashCode(); result = 31 * result + shuffleId; + result = 31 * result + shuffleGenerationId; result = 31 * result + Arrays.hashCode(mapIds); result = 31 * result + Arrays.deepHashCode(reduceIds); return result; @@ -98,6 +104,7 @@ public int encodedLength() { return Encoders.Strings.encodedLength(appId) + Encoders.Strings.encodedLength(execId) + 4 /* encoded length of shuffleId */ + + 4 /* encoded length of shuffleGenerationId */ + Encoders.IntArrays.encodedLength(mapIds) + 4 /* encoded length of reduceIds.size() */ + encodedLengthOfReduceIds; @@ -108,6 +115,7 @@ public void encode(ByteBuf buf) { Encoders.Strings.encode(buf, appId); Encoders.Strings.encode(buf, execId); buf.writeInt(shuffleId); + buf.writeInt(shuffleGenerationId); Encoders.IntArrays.encode(buf, mapIds); buf.writeInt(reduceIds.length); for (int[] ids: reduceIds) { @@ -119,12 +127,14 @@ public static FetchShuffleBlocks decode(ByteBuf buf) { String appId = Encoders.Strings.decode(buf); String execId = Encoders.Strings.decode(buf); int shuffleId = buf.readInt(); + int shuffleGenerationId = buf.readInt(); int[] mapIds = Encoders.IntArrays.decode(buf); int reduceIdsSize = buf.readInt(); int[][] reduceIds = new int[reduceIdsSize][]; for (int i = 0; i < reduceIdsSize; i++) { reduceIds[i] = Encoders.IntArrays.decode(buf); } - return new FetchShuffleBlocks(appId, execId, shuffleId, mapIds, reduceIds); + return new FetchShuffleBlocks( + appId, execId, shuffleId, shuffleGenerationId, mapIds, reduceIds); } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index 649c471dc1679..9a03965b53285 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -29,7 +29,7 @@ public class 
BlockTransferMessagesSuite { public void serializeOpenShuffleBlocks() { checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" })); checkSerializeDeserialize(new FetchShuffleBlocks( - "app-1", "exec-2", 0, new int[] {0, 1}, + "app-1", "exec-2", 0, -1, new int[] {0, 1}, new int[][] {{ 0, 1 }, { 0, 1, 2 }})); checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"))); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/CleanupNonShuffleServiceServedFilesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/CleanupNonShuffleServiceServedFilesSuite.java index e38442327e22d..5136f2495f363 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/CleanupNonShuffleServiceServedFilesSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/CleanupNonShuffleServiceServedFilesSuite.java @@ -243,7 +243,7 @@ private static void createFilesToKeep(TestShuffleDataContext dataContext) throws Random rand = new Random(123); dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), new byte[][] { "ABC".getBytes(StandardCharsets.UTF_8), - "DEF".getBytes(StandardCharsets.UTF_8)}); + "DEF".getBytes(StandardCharsets.UTF_8)}, false); dataContext.insertCachedRddData(12, 34, new byte[] { 42 }); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 9c623a70424b6..dadb95ef804b2 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -97,15 +97,15 @@ public void testCompatibilityWithOldVersion() { @Test 
public void testFetchShuffleBlocks() { - when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(blockMarkers[0]); - when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(blockMarkers[1]); + when(blockResolver.getBlockData("app0", "exec1", 0, 1, 0, 0)).thenReturn(blockMarkers[0]); + when(blockResolver.getBlockData("app0", "exec1", 0, 1, 0, 1)).thenReturn(blockMarkers[1]); FetchShuffleBlocks fetchShuffleBlocks = new FetchShuffleBlocks( - "app0", "exec1", 0, new int[] { 0 }, new int[][] {{ 0, 1 }}); + "app0", "exec1", 0, 1, new int[] { 0 }, new int[][] {{ 0, 1 }}); checkOpenBlocksReceive(fetchShuffleBlocks, blockMarkers); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 1, 0, 0); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 1, 0, 1); verifyOpenBlockLatencyMetrics(); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 09eb699be305a..2ff434debf037 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -53,7 +53,10 @@ public static void beforeAll() throws IOException { // Write some sort data. 
dataContext.insertSortShuffleData(0, 0, new byte[][] { sortBlock0.getBytes(StandardCharsets.UTF_8), - sortBlock1.getBytes(StandardCharsets.UTF_8)}); + sortBlock1.getBytes(StandardCharsets.UTF_8)}, false); + dataContext.insertSortShuffleData(0, 0, new byte[][] { + sortBlock0.getBytes(StandardCharsets.UTF_8), + sortBlock1.getBytes(StandardCharsets.UTF_8)}, true); } @AfterClass @@ -113,6 +116,27 @@ public void testSortShuffleBlocks() throws IOException { } } + @Test + public void testIndeterminateSortShuffleBlocks() throws IOException { + ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); + resolver.registerExecutor("app0", "exec0", + dataContext.createExecutorInfo(SORT_MANAGER)); + + try (InputStream block0Stream = resolver.getBlockData( + "app0", "exec0", 0, 0, 0, 0).createInputStream()) { + String block0 = + CharStreams.toString(new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); + assertEquals(sortBlock0, block0); + } + + try (InputStream block1Stream = resolver.getBlockData( + "app0", "exec0", 0, 0, 0, 1).createInputStream()) { + String block1 = + CharStreams.toString(new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); + assertEquals(sortBlock1, block1); + } + } + @Test public void jsonSerializationOfExecutorRegistration() throws IOException { ObjectMapper mapper = new ObjectMapper(); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index 47c087088a8a2..259cc2e04b7d6 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -142,7 +142,7 @@ private static TestShuffleDataContext createSomeData() throws IOException { dataContext.create(); 
dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), new byte[][] { "ABC".getBytes(StandardCharsets.UTF_8), - "DEF".getBytes(StandardCharsets.UTF_8)}); + "DEF".getBytes(StandardCharsets.UTF_8)}, false); return dataContext; } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 61a58e9e456fd..d44f57461a59f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -101,7 +101,8 @@ public static void beforeAll() throws IOException { dataContext0 = new TestShuffleDataContext(2, 5); dataContext0.create(); - dataContext0.insertSortShuffleData(0, 0, exec0Blocks); + dataContext0.insertSortShuffleData(0, 0, exec0Blocks, false); + dataContext0.insertSortShuffleData(0, 0, exec0Blocks, true); dataContext0.insertCachedRddData(RDD_ID, SPLIT_INDEX_VALID_BLOCK, exec0RddBlockValid); dataContext0.insertCachedRddData(RDD_ID, SPLIT_INDEX_VALID_BLOCK_TO_RM, exec0RddBlockToRemove); @@ -159,7 +160,13 @@ public void releaseBuffers() { // Fetch a set of blocks from a pre-registered executor. private FetchResult fetchBlocks(String execId, String[] blockIds) throws Exception { - return fetchBlocks(execId, blockIds, conf, server.getPort()); + return fetchBlocks(execId, blockIds, conf, server.getPort(), -1); + } + + // Fetch a set of blocks from a pre-registered executor. + private FetchResult fetchBlocks( + String execId, String[] blockIds, int shuffleGenerationId) throws Exception { + return fetchBlocks(execId, blockIds, conf, server.getPort(), shuffleGenerationId); } // Fetch a set of blocks from a pre-registered executor. 
Connects to the server on the given port, @@ -168,7 +175,8 @@ private FetchResult fetchBlocks( String execId, String[] blockIds, TransportConf clientConf, - int port) throws Exception { + int port, + int shuffleGenerationId) throws Exception { final FetchResult res = new FetchResult(); res.successBlocks = Collections.synchronizedSet(new HashSet()); res.failedBlocks = Collections.synchronizedSet(new HashSet()); @@ -179,7 +187,7 @@ private FetchResult fetchBlocks( try (ExternalBlockStoreClient client = new ExternalBlockStoreClient( clientConf, null, false, 5000)) { client.init(APP_ID); - client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, + client.fetchBlocks(TestUtils.getLocalHost(), port, execId, shuffleGenerationId, blockIds, new BlockFetchingListener() { @Override public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { @@ -233,6 +241,28 @@ public void testFetchThreeSort() throws Exception { exec0Fetch.releaseBuffers(); } + @Test + public void testFetchOneSortWithShuffleGenerationId() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult exec0Fetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" }, 0); + assertEquals(Sets.newHashSet("shuffle_0_0_0"), exec0Fetch.successBlocks); + assertTrue(exec0Fetch.failedBlocks.isEmpty()); + assertBufferListsEqual(exec0Fetch.buffers, Arrays.asList(exec0Blocks[0])); + exec0Fetch.releaseBuffers(); + } + + @Test + public void testFetchThreeSortWithShuffleGenerationId() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult exec0Fetch = fetchBlocks("exec-0", + new String[] { "shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2" }, 0); + assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2"), + exec0Fetch.successBlocks); + assertTrue(exec0Fetch.failedBlocks.isEmpty()); + assertBufferListsEqual(exec0Fetch.buffers, Arrays.asList(exec0Blocks)); + 
exec0Fetch.releaseBuffers(); + } + @Test (expected = RuntimeException.class) public void testRegisterInvalidExecutor() throws Exception { registerExecutor("exec-1", dataContext0.createExecutorInfo("unknown sort manager")); @@ -325,7 +355,7 @@ public void testFetchNoServer() throws Exception { new MapConfigProvider(ImmutableMap.of("spark.shuffle.io.maxRetries", "0"))); registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); FetchResult execFetch = fetchBlocks("exec-0", - new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, clientConf, 1 /* port */); + new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, clientConf, 1 /* port */, -1); assertTrue(execFetch.successBlocks.isEmpty()); assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 66633cc7a3595..45830f77b438b 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -64,7 +64,7 @@ public void testFetchOne() { BlockFetchingListener listener = fetchBlocks( blocks, blockIds, - new FetchShuffleBlocks("app-id", "exec-id", 0, new int[] { 0 }, new int[][] {{ 0 }}), + new FetchShuffleBlocks("app-id", "exec-id", 0, -1, new int[] { 0 }, new int[][] {{ 0 }}), conf); verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0")); @@ -100,7 +100,8 @@ public void testFetchThreeShuffleBlocks() { BlockFetchingListener listener = fetchBlocks( blocks, blockIds, - new FetchShuffleBlocks("app-id", "exec-id", 0, new int[] { 0 }, new int[][] {{ 0, 1, 2 }}), + new FetchShuffleBlocks( + "app-id", "exec-id", 0, -1, new int[] { 0 }, new int[][] {{ 0, 1, 2 }}), conf); for (int i = 0; i < 3; 
i ++) { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 457805feeac45..db4d761757888 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -67,8 +67,10 @@ public void cleanup() { } /** Creates reducer blocks in a sort-based data format within our local dirs. */ - public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { + public void insertSortShuffleData( + int shuffleId, int mapId, byte[][] blocks, boolean indeterminateBlock) throws IOException { String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; + if (indeterminateBlock) blockId += "_0"; OutputStream dataStream = null; DataOutputStream indexStream = null; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 3ccee703619b4..d167145de00e8 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -39,6 +39,7 @@ import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; +import org.apache.spark.TaskContext; import org.apache.spark.shuffle.api.ShuffleExecutorComponents; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; @@ -85,6 +86,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final Partitioner partitioner; private final ShuffleWriteMetricsReporter writeMetrics; private final int shuffleId; + private final int 
shuffleGenerationId; private final int mapId; private final long mapTaskAttemptId; private final Serializer serializer; @@ -107,7 +109,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { BlockManager blockManager, BypassMergeSortShuffleHandle handle, int mapId, - long mapTaskAttemptId, + TaskContext taskContext, SparkConf conf, ShuffleWriteMetricsReporter writeMetrics, ShuffleExecutorComponents shuffleExecutorComponents) { @@ -117,8 +119,9 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.blockManager = blockManager; final ShuffleDependency dep = handle.dependency(); this.mapId = mapId; - this.mapTaskAttemptId = mapTaskAttemptId; + this.mapTaskAttemptId = taskContext.taskAttemptId(); this.shuffleId = dep.shuffleId(); + this.shuffleGenerationId = taskContext.getShuffleGenerationId(dep.shuffleId()); this.partitioner = dep.partitioner(); this.numPartitions = partitioner.numPartitions(); this.writeMetrics = writeMetrics; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 9d05f03613ce9..66e476dbe611d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -75,6 +75,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final Partitioner partitioner; private final ShuffleWriteMetricsReporter writeMetrics; private final int shuffleId; + private final int shuffleGenerationId; private final int mapId; private final TaskContext taskContext; private final SparkConf sparkConf; @@ -137,6 +138,7 @@ public UnsafeShuffleWriter( this.mapId = mapId; final ShuffleDependency dep = handle.dependency(); this.shuffleId = dep.shuffleId(); + this.shuffleGenerationId = taskContext.getShuffleGenerationId(dep.shuffleId()); this.serializer = dep.serializer().newInstance(); this.partitioner = dep.partitioner(); 
this.writeMetrics = writeMetrics; @@ -231,7 +233,8 @@ void closeAndWriteOutput() throws IOException { final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; final long[] partitionLengths; - final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); + final File output = shuffleBlockResolver.getDataFile( + shuffleId, shuffleGenerationId, mapId); final File tmp = Utils.tempFileWith(output); try { try { @@ -243,7 +246,8 @@ void closeAndWriteOutput() throws IOException { } } } - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + shuffleBlockResolver.writeIndexFileAndCommit( + shuffleId, shuffleGenerationId, mapId, partitionLengths, tmp); } finally { if (tmp.exists() && !tmp.delete()) { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index aa71b21caa30e..f2ff7cbd49748 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2591,6 +2591,7 @@ object SparkContext extends Logging { private[spark] val SPARK_SCHEDULER_POOL = "spark.scheduler.pool" private[spark] val RDD_SCOPE_KEY = "spark.rdd.scope" private[spark] val RDD_SCOPE_NO_OVERRIDE_KEY = "spark.rdd.scope.noOverride" + private[spark] val SHUFFLE_GENERATION_ID_PREFIX = "_shuffle_generation_id_" /** * Executor id for the driver. 
In earlier versions of Spark, this was ``, but this was diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 2299c54e2624b..655ba91d706e6 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -185,6 +185,15 @@ abstract class TaskContext extends Serializable { @Evolving def resources(): Map[String, ResourceInformation] + /** + * The shuffle generation ID of the stage that this task belongs to. It returns the stage + * attempt number when the stage is indeterminate, and -1 otherwise. + */ + private[spark] def getShuffleGenerationId(shuffleId: Int): Int = { + Option(getLocalProperty(SparkContext.SHUFFLE_GENERATION_ID_PREFIX + shuffleId)) + .map(_.toInt).getOrElse(-1) + } + @DeveloperApi def taskMetrics(): TaskMetrics diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index e0147218d3eb7..0542c6ed13be3 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1030,6 +1030,14 @@ package object config { .checkValue(v => v > 0, "The value should be a positive integer.") .createWithDefault(2000) + private[spark] val SHUFFLE_USE_OLD_FETCH_PROTOCOL = + ConfigBuilder("spark.shuffle.useOldFetchProtocol") + .doc("Whether to use the old protocol while doing the shuffle block fetching.
" + + "It is only enabled while we need the compatibility in the scenario of new spark " + + "version job fetching shuffle blocks from old version external shuffle service.") + .booleanConf + .createWithDefault(false) + private[spark] val MEMORY_MAP_LIMIT_FOR_TESTS = ConfigBuilder("spark.storage.memoryMapLimitForTests") .internal() diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 4993519aa3843..133ffcaaf51d1 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -33,6 +33,16 @@ trait BlockDataManager { */ def getBlockData(blockId: BlockId): ManagedBuffer + /** + * Interface to get shuffle block data. Throws an exception if the block cannot be found or + * cannot be read successfully. + */ + def getShuffleBlockData( + shuffleId: Int, + shuffleGenerationId: Int, + mapId: Int, + reduceId: Int): ManagedBuffer + /** * Put the block locally, using the given storage level. * diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 70a159f3eeecf..259e6a56869a2 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -53,6 +53,23 @@ abstract class BlockTransferService extends BlockStoreClient with Logging { */ def hostName: String + /** + * Fetch a sequence of blocks from a remote node asynchronously, + * available only after [[init]] is invoked. + * + * Note that this API takes a sequence so the implementation can batch requests, and does not + * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as + * the data of a block is fetched, rather than waiting for all blocks to be fetched. 
+ */ + override def fetchBlocks( + host: String, + port: Int, + execId: String, + shuffleGenerationId: Int, + blockIds: Array[String], + listener: BlockFetchingListener, + tempFileManager: DownloadFileManager): Unit + /** * Upload a single block to a remote node, available only after [[init]] is invoked. */ @@ -76,9 +93,12 @@ abstract class BlockTransferService extends BlockStoreClient with Logging { execId: String, blockId: String, tempFileManager: DownloadFileManager): ManagedBuffer = { + // Make sure ShuffleBlockId will not enter this function, so we call fetchBlocks with the + // invalid shuffleGenerationId -1 here for this special case of fetching a non-shuffle block. + assert(!BlockId.apply(blockId).isShuffle) // A monitor for the thread to wait on. val result = Promise[ManagedBuffer]() - fetchBlocks(host, port, execId, Array(blockId), + fetchBlocks(host, port, execId, -1, Array(blockId), new BlockFetchingListener { override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { result.failure(exception) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index b2ab31488e4c1..2019679fa5fc9 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -66,8 +66,9 @@ class NettyBlockRpcServer( case fetchShuffleBlocks: FetchShuffleBlocks => val blocks = fetchShuffleBlocks.mapIds.zipWithIndex.flatMap { case (mapId, index) => fetchShuffleBlocks.reduceIds.apply(index).map { reduceId => - blockManager.getBlockData( - ShuffleBlockId(fetchShuffleBlocks.shuffleId, mapId, reduceId)) + blockManager.getShuffleBlockData( + fetchShuffleBlocks.shuffleId, fetchShuffleBlocks.shuffleGenerationId, + mapId, reduceId) } } val numBlockIds = fetchShuffleBlocks.reduceIds.map(_.length).sum diff --git 
a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index b12cd4254f19e..656ec4436f46d 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -110,6 +110,7 @@ private[spark] class NettyBlockTransferService( host: String, port: Int, execId: String, + shuffleGenerationId: Int, blockIds: Array[String], listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit = { @@ -119,8 +120,8 @@ private[spark] class NettyBlockTransferService( override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) try { - new OneForOneBlockFetcher(client, appId, execId, blockIds, listener, - transportConf, tempFileManager).start() + new OneForOneBlockFetcher(client, appId, execId, shuffleGenerationId, blockIds, + listener, transportConf, tempFileManager).start() } catch { case e: IOException => Try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 482691c94f87e..94ab4087f7157 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -39,7 +39,7 @@ import org.apache.spark.internal.config import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} -import org.apache.spark.rdd.{DeterministicLevel, RDD, RDDCheckpointData} +import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ import 
org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat @@ -1100,7 +1100,14 @@ private[spark] class DAGScheduler( private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") - // First figure out the indexes of partition ids to compute. + // Before finding the missing partitions, clean up the intermediate state first. + stage match { + case sms: ShuffleMapStage if stage.isIndeterminate() => + mapOutputTracker.unregisterAllMapOutput(sms.shuffleDep.shuffleId) + case _ => + } + + // Figure out the indexes of partition ids to compute. val partitionsToCompute: Seq[Int] = stage.findMissingPartitions() // Use the scheduling pool, job group, description, etc. from an ActiveJob associated @@ -1139,12 +1146,28 @@ private[spark] class DAGScheduler( } stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq) - - // If there are tasks to execute, record the submission time of the stage. Otherwise, - // post the even without the submission time, which indicates that this stage was - // skipped. if (partitionsToCompute.nonEmpty) { + // If there are tasks to execute, record the submission time of the stage. Otherwise, + // post the event without the submission time, which indicates that this stage was + // skipped. stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) + + // When an indeterminate stage is retried, the stage attempt id will be used to extend the + // shuffle file in shuffle write task, and then the mapping of shuffle id to indeterminate + // stage id will be used for shuffle reader task. + val stageAttemptId = stage.latestInfo.attemptNumber() + if (stageAttemptId > 0 && stage.isIndeterminate()) { + // deal with shuffle writer side property.
+ stage match { + case sms: ShuffleMapStage => + properties.setProperty( + SparkContext.SHUFFLE_GENERATION_ID_PREFIX + sms.shuffleDep.shuffleId, + stageAttemptId.toString) + logInfo(s"Set SHUFFLE_GENERATION_ID for $stage(shuffleId:" + + s" ${sms.shuffleDep.shuffleId}) to $stageAttemptId") + case _ => + } + } } listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties)) @@ -1558,7 +1581,6 @@ private[spark] class DAGScheduler( } abortStage(failedStage, abortMessage, None) } else { // update failedStages and make sure a ResubmitFailedStages event is enqueued - // TODO: Cancel running tasks in the failed stage -- cf. SPARK-17064 val noResubmitEnqueued = !failedStages.contains(failedStage) failedStages += failedStage failedStages += mapStage @@ -1570,7 +1592,7 @@ private[spark] class DAGScheduler( // Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is // guaranteed to be determinate, so the input data of the reducers will not change // even if the map tasks are re-tried. - if (mapStage.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) { + if (mapStage.isIndeterminate()) { // It's a little tricky to find all the succeeding stages of `failedStage`, because // each stage only know its parents not children. Here we traverse the stages from // the leaf nodes (the result stages of active jobs), and rollback all the stages @@ -1602,11 +1624,18 @@ private[spark] class DAGScheduler( case mapStage: ShuffleMapStage => val numMissingPartitions = mapStage.findMissingPartitions().length if (numMissingPartitions < mapStage.numTasks) { - // TODO: support to rollback shuffle files. - // Currently the shuffle writing is "first write wins", so we can't re-run a - // shuffle map stage and overwrite existing shuffle files. We have to finish - // SPARK-8029 first. 
- abortStage(mapStage, generateErrorMessage(mapStage), None) + if (sc.getConf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) { + val reason = "A shuffle map stage with indeterminate output was failed " + + "and retried. However, Spark can only do this while using the new " + + "shuffle block fetching protocol. Please check the config " + + "'spark.shuffle.useOldFetchProtocol', see more detail in " + + "SPARK-27665 and SPARK-25341." + abortStage(mapStage, reason, None) + } else { + logInfo(s"The indeterminate stage $mapStage will be resubmitted," + + " the stage self and all indeterminate parent stage will be" + + " rollback and whole stage rerun.") + } } case resultStage: ResultStage if resultStage.activeJob.isDefined => diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index 1b44d0aee3195..522fc99e73a85 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -87,7 +87,10 @@ private[spark] class ShuffleMapStage( */ def isAvailable: Boolean = numAvailableOutputs == numPartitions - /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ + /** + * Returns the sequence of partition ids that are missing (i.e. needs to be computed). + * If the current stage is indeterminate, missing partition is all partitions every time. 
+ */ override def findMissingPartitions(): Seq[Int] = { mapOutputTrackerMaster .findMissingPartitions(shuffleDep.shuffleId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 26cca334d3bd5..135d0036ef734 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.HashSet import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.util.CallSite /** @@ -116,4 +116,8 @@ private[scheduler] abstract class Stage( /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ def findMissingPartitions(): Seq[Int] + + def isIndeterminate(): Boolean = { + rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE + } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 4329824b1b627..40e405ffa80db 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -56,7 +56,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), - readMetrics).toCompletionIterator + readMetrics, + context.getShuffleGenerationId(handle.shuffleId)).toCompletionIterator val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 
d3f1c7ec1bbee..3c2356a3e5f6e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -51,26 +51,47 @@ private[spark] class IndexShuffleBlockResolver( private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") - def getDataFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) + def getDataFile( + shuffleId: Int, + shuffleGenerationId: Int, + mapId: Int): File = { + blockManager.diskBlockManager.getFile( + generateFileName(shuffleGenerationId, ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID))) + } + + private def getIndexFile( + shuffleId: Int, + shuffleGenerationId: Int, + mapId: Int): File = { + blockManager.diskBlockManager.getFile( + generateFileName(shuffleGenerationId, ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID))) } - private def getIndexFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) + /** + * Generate the file name from shuffle generation id and block id. + */ + private def generateFileName(shuffleGenerationId: Int, blockId: BlockId): String = { + val generationId = if (shuffleGenerationId != -1) { + "_" + shuffleGenerationId + } else { + "" + } + blockId.name.replace(".", generationId + ".") } /** * Remove data file and index file that contain the output data from one map. 
*/ - def removeDataByMap(shuffleId: Int, mapId: Int): Unit = { - var file = getDataFile(shuffleId, mapId) + def removeDataByMap( + shuffleId: Int, shuffleGenerationId: Int, mapId: Int): Unit = { + var file = getDataFile(shuffleId, shuffleGenerationId, mapId) if (file.exists()) { if (!file.delete()) { logWarning(s"Error deleting data ${file.getPath()}") } } - file = getIndexFile(shuffleId, mapId) + file = getIndexFile(shuffleId, shuffleGenerationId, mapId) if (file.exists()) { if (!file.delete()) { logWarning(s"Error deleting index ${file.getPath()}") @@ -135,13 +156,14 @@ private[spark] class IndexShuffleBlockResolver( */ def writeIndexFileAndCommit( shuffleId: Int, + shuffleGenerationId: Int, mapId: Int, lengths: Array[Long], dataTmp: File): Unit = { - val indexFile = getIndexFile(shuffleId, mapId) + val indexFile = getIndexFile(shuffleId, shuffleGenerationId, mapId) val indexTmp = Utils.tempFileWith(indexFile) try { - val dataFile = getDataFile(shuffleId, mapId) + val dataFile = getDataFile(shuffleId, shuffleGenerationId, mapId) // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure // the following check and rename are atomic. 
synchronized { @@ -190,10 +212,10 @@ private[spark] class IndexShuffleBlockResolver( } } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { + override def getBlockData(shuffleGenerationId: Int, blockId: ShuffleBlockId): ManagedBuffer = { // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index - val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) + val indexFile = getIndexFile(blockId.shuffleId, shuffleGenerationId, blockId.mapId) // SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code // which is incorrectly using our file descriptor then this code will fetch the wrong offsets @@ -215,7 +237,7 @@ private[spark] class IndexShuffleBlockResolver( } new FileSegmentManagedBuffer( transportConf, - getDataFile(blockId.shuffleId, blockId.mapId), + getDataFile(blockId.shuffleId, shuffleGenerationId, blockId.mapId), offset, nextOffset - offset) } finally { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index d1ecbc1bf0178..1486321f6fb70 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -34,7 +34,7 @@ trait ShuffleBlockResolver { * Retrieve the data for the specified block. If the data for that block is not available, * throws an unspecified exception. 
*/ - def getBlockData(blockId: ShuffleBlockId): ManagedBuffer + def getBlockData(shuffleGenerationId: Int, blockId: ShuffleBlockId): ManagedBuffer def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 17719f516a0a1..6f2ea444fa9d3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -19,6 +19,8 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConverters._ + import org.apache.spark._ import org.apache.spark.internal.{config, Logging} import org.apache.spark.shuffle._ @@ -79,9 +81,10 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager } /** - * A mapping from shuffle ids to the number of mappers producing output for those shuffles. + * A mapping from shuffle ids to the tuple of number of mappers producing output and + * shuffle generation id for those shuffles. 
*/ - private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + private[this] val infoMapsForShuffle = new ConcurrentHashMap[Int, (Int, Int)]() private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) @@ -133,8 +136,10 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager mapId: Int, context: TaskContext, metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { - numMapsForShuffle.putIfAbsent( - handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) + infoMapsForShuffle.putIfAbsent( + handle.shuffleId, + (handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps, + context.getShuffleGenerationId(handle.shuffleId))) val env = SparkEnv.get handle match { case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => @@ -152,7 +157,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager env.blockManager, bypassMergeSortHandle, mapId, - context.taskAttemptId(), + context, env.conf, metrics, shuffleExecutorComponents) @@ -163,10 +168,11 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager /** Remove a shuffle's metadata from the ShuffleManager. 
*/ override def unregisterShuffle(shuffleId: Int): Boolean = { - Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps => - (0 until numMaps).foreach { mapId => - shuffleBlockResolver.removeDataByMap(shuffleId, mapId) - } + Option(infoMapsForShuffle.remove(shuffleId)).foreach { + case (numMaps, shuffleGenerationId) => + (0 until numMaps).foreach { mapId => + shuffleBlockResolver.removeDataByMap(shuffleId, shuffleGenerationId, mapId) + } } true } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 16058de8bf3ff..833616cba77d4 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -47,6 +47,8 @@ private[spark] class SortShuffleWriter[K, V, C]( private val writeMetrics = context.taskMetrics().shuffleWriteMetrics + private val shuffleGenerationId = context.getShuffleGenerationId(handle.shuffleId) + /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { sorter = if (dep.mapSideCombine) { @@ -64,12 +66,16 @@ private[spark] class SortShuffleWriter[K, V, C]( // Don't bother including the time to open the merged output file in the shuffle write time, // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). 
- val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId) + val output = shuffleBlockResolver.getDataFile(dep.shuffleId, shuffleGenerationId, mapId) val tmp = Utils.tempFileWith(output) try { - val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) + val blockId = ShuffleBlockId( + dep.shuffleId, + mapId, + IndexShuffleBlockResolver.NOOP_REDUCE_ID) val partitionLengths = sorter.writePartitionedFile(blockId, tmp) - shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) + shuffleBlockResolver.writeIndexFileAndCommit( + dep.shuffleId, shuffleGenerationId, mapId, partitionLengths, tmp) mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } finally { if (tmp.exists() && !tmp.delete()) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 7ac2c71c18eb3..c23f0a6e1d422 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -98,7 +98,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId { @DeveloperApi class UnrecognizedBlockId(name: String) - extends SparkException(s"Failed to parse $name into a block ID") + extends SparkException(s"Failed to parse $name into a block ID") @DeveloperApi object BlockId { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index b6c5c98b7a457..c5bb6d9e2370d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -548,7 +548,11 @@ private[spark] class BlockManager( */ override def getBlockData(blockId: BlockId): ManagedBuffer = { if (blockId.isShuffle) { - shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + // This branch is for the 
compatibility of the old shuffle protocol, which uses OpenBlock + // to fetch shuffle blocks. After Spark 3.0, all shuffle block fetching will use the new + // protocol FetchShuffleBlocks and use getShuffleBlockData in BlockManager. + // See more details in SPARK-27665. + shuffleManager.shuffleBlockResolver.getBlockData(-1, blockId.asInstanceOf[ShuffleBlockId]) } else { getLocalBytes(blockId) match { case Some(blockData) => @@ -563,6 +567,19 @@ private[spark] class BlockManager( } } + /** + * Interface to get shuffle block data. Throws an exception if the block cannot be found or + * cannot be read successfully. + */ + override def getShuffleBlockData( + shuffleId: Int, + shuffleGenerationId: Int, + mapId: Int, + reduceId: Int): ManagedBuffer = { + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + shuffleManager.shuffleBlockResolver.getBlockData(shuffleGenerationId, shuffleBlockId) + } + /** * Put the block locally, using the given storage level. * diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index a5b7ee5762c49..af95cb82829bd 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -60,6 +60,7 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param detectCorrupt whether to detect any corruption in fetched blocks. * @param shuffleMetrics used to report shuffle metrics. + * @param shuffleGenerationId used to fetch shuffle blocks for indeterminate stage.
*/ private[spark] final class ShuffleBlockFetcherIterator( @@ -74,7 +75,8 @@ final class ShuffleBlockFetcherIterator( maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean, detectCorruptUseExtraMemory: Boolean, - shuffleMetrics: ShuffleReadMetricsReporter) + shuffleMetrics: ShuffleReadMetricsReporter, + shuffleGenerationId: Int = -1) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -257,11 +259,11 @@ final class ShuffleBlockFetcherIterator( // already encrypted and compressed over the wire(w.r.t. the related configs), we can just fetch // the data and write it to file directly. if (req.size > maxReqSizeShuffleToMem) { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, this) + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, + shuffleGenerationId, blockIds.toArray, blockFetchingListener, this) } else { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, - blockFetchingListener, null) + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, + shuffleGenerationId, blockIds.toArray, blockFetchingListener, null) } } @@ -342,7 +344,10 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val blockId = iter.next() try { - val buf = blockManager.getBlockData(blockId) + val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] + val buf = blockManager.getShuffleBlockData( + shuffleBlockId.shuffleId, shuffleGenerationId, + shuffleBlockId.mapId, shuffleBlockId.reduceId) shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 88125a6b93ade..ef7cf24e3b4b0 100644 --- 
a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -131,15 +131,17 @@ public void setUp() throws IOException { ); }); - when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); + when(shuffleBlockResolver.getDataFile( + anyInt(), anyInt(), anyInt())).thenReturn(mergedOutputFile); doAnswer(invocationOnMock -> { - partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; - File tmp = (File) invocationOnMock.getArguments()[3]; + partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[3]; + File tmp = (File) invocationOnMock.getArguments()[4]; mergedOutputFile.delete(); tmp.renameTo(mergedOutputFile); return null; }).when(shuffleBlockResolver) - .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); + .writeIndexFileAndCommit( + anyInt(), anyInt(), anyInt(), any(long[].class), any(File.class)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> { TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 7f7f3db65d6ca..aa77f00aad363 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.server.TransportServer import org.apache.spark.network.shuffle.{ExternalBlockHandler, ExternalBlockStoreClient} +import org.apache.spark.rdd.RDD import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.Utils @@ -66,7 +67,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with 
BeforeAndAfterAll wi } // This test ensures that the external shuffle service is actually in use for the other tests. - test("using external shuffle service") { + private def checkResultWithShuffleService(createRDD: (SparkContext => RDD[_])): Unit = { sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) sc.env.blockManager.externalShuffleServiceEnabled should equal(true) sc.env.blockManager.blockStoreClient.getClass should equal(classOf[ExternalBlockStoreClient]) @@ -79,7 +80,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi // Therefore, we should wait until all slaves are up TestUtils.waitUntilExecutorsUp(sc, 2, 60000) - val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) + val rdd = createRDD(sc) rdd.count() rdd.count() @@ -96,6 +97,18 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi e.getMessage should include ("Fetch failure will not retry stage due to testing config") } + test("using external shuffle service") { + val createRDD = (sc: SparkContext) => + sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) + checkResultWithShuffleService(createRDD) + } + + test("using external shuffle service for indeterminate rdd") { + val createIndeterminateRDD = (sc: SparkContext) => + sc.parallelize(0 until 1000, 10).repartition(11).repartition(12) + checkResultWithShuffleService(createIndeterminateRDD) + } + test("SPARK-25888: using external shuffle service fetching disk persisted blocks") { val confWithRddFetchEnabled = conf.clone.set(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, true) sc = new SparkContext("local-cluster[1,1,1024]", "test", confWithRddFetchEnabled) diff --git a/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceDbSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceDbSuite.scala index 9cfb8a647ad89..a133965ed7ade 100644 --- 
a/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceDbSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceDbSuite.scala @@ -59,7 +59,7 @@ class ExternalShuffleServiceDbSuite extends SparkFunSuite { // Write some sort data. dataContext.insertSortShuffleData(0, 0, Array[Array[Byte]](sortBlock0.getBytes(StandardCharsets.UTF_8), - sortBlock1.getBytes(StandardCharsets.UTF_8))) + sortBlock1.getBytes(StandardCharsets.UTF_8)), false) registerExecutor() } diff --git a/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala index d7e4b9166fa04..cceb9717b78d8 100644 --- a/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala @@ -51,6 +51,7 @@ class BlockTransferServiceSuite extends SparkFunSuite with TimeLimits { host: String, port: Int, execId: String, + shuffleGenerationId: Int, blockIds: Array[String], listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit = { @@ -96,7 +97,7 @@ class BlockTransferServiceSuite extends SparkFunSuite with TimeLimits { val e = intercept[SparkException] { failAfter(10.seconds) { blockTransferService.fetchBlockSync( - "localhost-unused", 0, "exec-id-unused", "block-id-unused", null) + "localhost-unused", 0, "exec-id-unused", "test_block-id-unused", null) } } assert(e.getCause.isInstanceOf[IllegalArgumentException]) diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 544d52d48b385..f014857342220 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -122,7 +122,8 @@ class 
NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi val blockString = "Hello, world!" val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap( blockString.getBytes(StandardCharsets.UTF_8))) - when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer) + when(blockManager.getShuffleBlockData(blockId.shuffleId, -1, blockId.mapId, blockId.reduceId)) + .thenReturn(blockBuffer) val securityManager0 = new SecurityManager(conf0) val exec0 = new NettyBlockTransferService(conf0, securityManager0, "localhost", "localhost", 0, @@ -158,7 +159,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi val promise = Promise[ManagedBuffer]() - self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString), + self.fetchBlocks(from.hostName, from.port, execId, -1, Array(blockId.toString), new BlockFetchingListener { override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { promise.failure(exception) diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index 5d67d3358a9ca..46815f066d426 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -121,7 +121,7 @@ class NettyBlockTransferServiceSuite clientFactoryField.setAccessible(true) clientFactoryField.set(service0, clientFactory) - service0.fetchBlocks("localhost", port, "exec1", + service0.fetchBlocks("localhost", port, "exec1", -1, Array("block1"), listener, mock(classOf[DownloadFileManager])) assert(createClientCount === 1) assert(hitExecutorDeadException) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index cff3ebf2fb7e0..f8accd94a5a2c 100644 --- 
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2710,37 +2710,32 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(countSubmittedMapStageAttempts() === 2) } - test("SPARK-23207: retry all the succeeding stages when the map stage is indeterminate") { + private def constructIndeterminateStageRetryScenario(): (Int, Int) = { val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true) - val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) val shuffleId1 = shuffleDep1.shuffleId val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker) - val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2)) val shuffleId2 = shuffleDep2.shuffleId val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker) - submit(finalRdd, Array(0, 1)) + submit(finalRdd, Array(0, 1), properties = new Properties()) // Finish the first shuffle map stage. complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 2)), (Success, makeMapStatus("hostB", 2)))) assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) - // Finish the second shuffle map stage. complete(taskSets(1), Seq( (Success, makeMapStatus("hostC", 2)), (Success, makeMapStatus("hostD", 2)))) assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty)) - // The first task of the final stage failed with fetch failure runEvent(makeCompletionEvent( taskSets(2).tasks(0), FetchFailed(makeBlockManagerId("hostC"), shuffleId2, 0, 0, "ignored"), null)) - val failedStages = scheduler.failedStages.toSeq assert(failedStages.length == 2) // Shuffle blocks of "hostC" is lost, so first task of the `shuffleMapRdd2` needs to retry. 
@@ -2751,8 +2746,36 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(failedStages.collect { case stage: ResultStage => stage }.head.findMissingPartitions() == Seq(0, 1)) - scheduler.resubmitFailedStages() + (shuffleId1, shuffleId2) + } + + test("SPARK-25341: abort stage while using old fetch protocol") { + // reset the test context with using old fetch protocol + afterEach() + val conf = new SparkConf() + conf.set(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL.key, "true") + init(conf) + + val (shuffleId1, _) = constructIndeterminateStageRetryScenario() + // The second task of the `shuffleMapRdd2` failed with fetch failure + runEvent(makeCompletionEvent( + taskSets(3).tasks(0), + Success, + makeMapStatus("hostC", 2))) + runEvent(makeCompletionEvent( + taskSets(3).tasks(1), + FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0, 0, "ignored"), + null)) + + // The job should fail because Spark can't rollback the shuffle map stage while + // using old protocol. + assert(failure != null && failure.getMessage.contains( + "Spark can only do this while using the new shuffle block fetching protocol")) + } + + test("SPARK-25341: retry all the succeeding stages when the map stage is indeterminate") { + val (shuffleId1, shuffleId2) = constructIndeterminateStageRetryScenario() // The first task of the `shuffleMapRdd2` failed with fetch failure runEvent(makeCompletionEvent( @@ -2760,8 +2783,104 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0, 0, "ignored"), null)) - // The job should fail because Spark can't rollback the shuffle map stage. - assert(failure != null && failure.getMessage.contains("Spark cannot rollback")) + val newFailedStages = scheduler.failedStages.toSeq + assert(newFailedStages.map(_.id) == Seq(0, 1)) + + scheduler.resubmitFailedStages() + + // First shuffle map stage resubmitted and reran all tasks. 
+ assert(taskSets(4).stageId == 0) + assert(taskSets(4).stageAttemptId == 1) + assert(taskSets(4).tasks.length == 2) + + // Finish all stage. + complete(taskSets(4), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) + assert(taskSets(4).tasks.head.localProperties.getProperty( + SparkContext.SHUFFLE_GENERATION_ID_PREFIX + shuffleId1.toString) == "1") + + complete(taskSets(5), Seq( + (Success, makeMapStatus("hostC", 2)), + (Success, makeMapStatus("hostD", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty)) + assert(taskSets(5).tasks.head.localProperties.getProperty( + SparkContext.SHUFFLE_GENERATION_ID_PREFIX + shuffleId2.toString) == "2") + + complete(taskSets(6), Seq((Success, 11), (Success, 12))) + + // Job successful ended. + assert(results === Map(0 -> 11, 1 -> 12)) + results.clear() + assertDataStructuresEmpty() + } + + test("SPARK-25341: continuous indeterminate stage roll back") { + // shuffleMapRdd1/2/3 are all indeterminate. + val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true) + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) + val shuffleId1 = shuffleDep1.shuffleId + + val shuffleMapRdd2 = new MyRDD( + sc, 2, List(shuffleDep1), tracker = mapOutputTracker, indeterminate = true) + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2)) + val shuffleId2 = shuffleDep2.shuffleId + + val shuffleMapRdd3 = new MyRDD( + sc, 2, List(shuffleDep2), tracker = mapOutputTracker, indeterminate = true) + val shuffleDep3 = new ShuffleDependency(shuffleMapRdd3, new HashPartitioner(2)) + val shuffleId3 = shuffleDep3.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep3), tracker = mapOutputTracker) + + submit(finalRdd, Array(0, 1), properties = new Properties()) + + // Finish the first 3 shuffle map stages. 
+ complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) + + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostB", 2)), + (Success, makeMapStatus("hostD", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty)) + + // Executor lost on hostB, both of stage 0 and 1 should be reran. + runEvent(makeCompletionEvent( + taskSets(2).tasks(0), + FetchFailed(makeBlockManagerId("hostB"), shuffleId2, 0, 0, "ignored"), + null)) + mapOutputTracker.removeOutputsOnHost("hostB") + + assert(scheduler.failedStages.toSeq.map(_.id) == Seq(1, 2)) + scheduler.resubmitFailedStages() + + def checkAndCompleteRetryStage( + taskSetIndex: Int, + stageId: Int, + shuffleId: Int): Unit = { + assert(taskSets(taskSetIndex).stageId == stageId) + assert(taskSets(taskSetIndex).stageAttemptId == 1) + assert(taskSets(taskSetIndex).tasks.length == 2) + complete(taskSets(taskSetIndex), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + assert(taskSets(taskSetIndex).tasks.head.localProperties.getProperty( + SparkContext.SHUFFLE_GENERATION_ID_PREFIX + shuffleId.toString) == "1") + } + + // Check all indeterminate stage roll back. + checkAndCompleteRetryStage(3, 0, shuffleId1) + checkAndCompleteRetryStage(4, 1, shuffleId2) + checkAndCompleteRetryStage(5, 2, shuffleId3) + + // Result stage success, all job ended. 
+ complete(taskSets(6), Seq((Success, 11), (Success, 12))) + assert(results === Map(0 -> 11, 1 -> 12)) + results.clear() + assertDataStructuresEmpty() } private def assertResultStageFailToRollback(mapRdd: MyRDD): Unit = { diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index b29d32f7b35c5..d3da1bf75ecf1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -57,7 +57,7 @@ object FakeTask { val tasks = Array.tabulate[Task[_]](numTasks) { i => new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil) } - new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) + new TaskSet(tasks, stageId, stageAttemptId, priority = 0, new Properties()) } def createShuffleMapTaskSet( @@ -92,6 +92,6 @@ object FakeTask { val tasks = Array.tabulate[Task[_]](numTasks) { i => new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil, isBarrier = true) } - new TaskSet(tasks, stageId, stageAttempId, priority = 0, null) + new TaskSet(tasks, stageId, stageAttempId, priority = 0, new Properties()) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index cac6285e58417..fbb2c2f8c490d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import java.util.Properties import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.duration._ @@ -202,7 +203,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B config.CPUS_PER_TASK.key -> taskCpus.toString) val numFreeCores = 1 val taskSet = new TaskSet( - Array(new 
NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) + Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), + 0, 0, 0, new Properties()) val multiCoreWorkerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", taskCpus), new WorkerOffer("executor1", "host1", numFreeCores)) taskScheduler.submitTasks(taskSet) @@ -216,7 +218,8 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B // still be processed without error taskScheduler.submitTasks(FakeTask.createTaskSet(1)) val taskSet2 = new TaskSet( - Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 1, 0, 0, null) + Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), + 1, 0, 0, new Properties()) taskScheduler.submitTasks(taskSet2) taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten assert(taskDescriptions.map(_.executorId) === Seq("executor0")) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 6d2ef17a7a790..4b2d4dfd55f35 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -94,8 +94,8 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to // fetch shuffle data. 
- val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) - when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) + when(blockManager.getShuffleBlockData(shuffleId, -1, mapId, reduceId)) + .thenReturn(managedBuffer) managedBuffer } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index b9f81fa0d0a06..2cebb4ff8b34b 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -73,14 +73,15 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte when(dependency.partitioner).thenReturn(new HashPartitioner(7)) when(dependency.serializer).thenReturn(new JavaSerializer(conf)) when(taskContext.taskMetrics()).thenReturn(taskMetrics) - when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) + when(taskContext.getShuffleGenerationId(any[Int])).thenReturn(-1) + when(blockResolver.getDataFile(0, -1, 0)).thenReturn(outputFile) when(blockManager.diskBlockManager).thenReturn(diskBlockManager) when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) when(blockResolver.writeIndexFileAndCommit( - anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) + anyInt, anyInt, anyInt(), any(classOf[Array[Long]]), any(classOf[File]))) .thenAnswer { invocationOnMock => - val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] + val tmp = invocationOnMock.getArguments()(4).asInstanceOf[File] if (tmp != null) { outputFile.delete tmp.renameTo(outputFile) @@ -140,7 +141,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockManager, shuffleHandle, 0, // MapId - 0L, // MapTaskAttemptId + taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics, shuffleExecutorComponents) @@ -167,7 +168,7 
@@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockManager, shuffleHandle, 0, // MapId - 0L, + taskContext, transferConf, taskContext.taskMetrics().shuffleWriteMetrics, shuffleExecutorComponents) @@ -203,7 +204,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockManager, shuffleHandle, 0, // MapId - 0L, + taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics, shuffleExecutorComponents) @@ -225,7 +226,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte blockManager, shuffleHandle, 0, // MapId - 0L, + taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics, shuffleExecutorComponents) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index 27bb06b4e0636..c38753e0c7fd8 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -46,8 +46,11 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa MockitoAnnotations.initMocks(this) when(blockManager.diskBlockManager).thenReturn(diskBlockManager) - when(diskBlockManager.getFile(any[BlockId])).thenAnswer( - (invocation: InvocationOnMock) => new File(tempDir, invocation.getArguments.head.toString)) + when(diskBlockManager.getFile(any[String])).thenAnswer( + (invocation: InvocationOnMock) => { + new File(tempDir, invocation.getArguments.head.asInstanceOf[String]) + } + ) } override def afterEach(): Unit = { @@ -58,10 +61,8 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } } - test("commit shuffle files multiple times") { - val shuffleId = 1 - val mapId = 2 - val idxName = s"shuffle_${shuffleId}_${mapId}_0.index" + private def testWithIndexShuffleBlockResolver( + 
shuffleId: Int, mapId: Int, idxName: String, generationId: Int): Unit = { val resolver = new IndexShuffleBlockResolver(conf, blockManager) val lengths = Array[Long](10, 0, 20) val dataTmp = File.createTempFile("shuffle", null, tempDir) @@ -71,10 +72,10 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out.close() } - resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) + resolver.writeIndexFileAndCommit(shuffleId, generationId, mapId, lengths, dataTmp) val indexFile = new File(tempDir.getAbsolutePath, idxName) - val dataFile = resolver.getDataFile(shuffleId, mapId) + val dataFile = resolver.getDataFile(shuffleId, generationId, mapId) assert(indexFile.exists()) assert(indexFile.length() === (lengths.length + 1) * 8) @@ -91,7 +92,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out2.close() } - resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths2, dataTmp2) + resolver.writeIndexFileAndCommit(shuffleId, generationId, mapId, lengths2, dataTmp2) assert(indexFile.length() === (lengths.length + 1) * 8) assert(lengths2.toSeq === lengths.toSeq) @@ -130,7 +131,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } { out3.close() } - resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths3, dataTmp3) + resolver.writeIndexFileAndCommit(shuffleId, generationId, mapId, lengths3, dataTmp3) assert(indexFile.length() === (lengths3.length + 1) * 8) assert(lengths3.toSeq != lengths.toSeq) assert(dataFile.exists()) @@ -155,4 +156,69 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa indexIn2.close() } } + + test("commit shuffle files multiple times") { + val shuffleId = 1 + val mapId = 2 + val idxName = s"shuffle_${shuffleId}_${mapId}_0.index" + testWithIndexShuffleBlockResolver(shuffleId, mapId, idxName, -1) + } + + test("commit shuffle files with shuffle generation id multiple times") { + val shuffleId = 1 + val 
mapId = 2 + val generationId = 1 + val idxName = s"shuffle_${shuffleId}_${mapId}_0_1.index" + testWithIndexShuffleBlockResolver(shuffleId, mapId, idxName, generationId) + } + + test("commit shuffle files with different shuffle generation id multiple times") { + val shuffleId = 1 + val mapId = 2 + val resolver = new IndexShuffleBlockResolver(conf, blockManager) + // write index file and commit with different generation ids + (1 to 3).foreach { i => + val generationId = i + // Use generation id * 10 as the lengths in index file + val lengths = Array[Long](i * 10) + val dataTmp = File.createTempFile("shuffle", null, tempDir) + val out = new FileOutputStream(dataTmp) + Utils.tryWithSafeFinally { + // Use generation id as the first byte in data file + out.write(Array[Byte](i.toByte)) + out.write(new Array[Byte](29)) + } { + out.close() + } + resolver.writeIndexFileAndCommit(shuffleId, generationId, mapId, lengths, dataTmp) + } + // Check all data files and index files + val firstByte = new Array[Byte](1) + (1 to 3).foreach { i => + val generationId = i + val indexFile = new File( + tempDir.getAbsolutePath, s"shuffle_${shuffleId}_${mapId}_0_${generationId}.index") + val dataFile = resolver.getDataFile(shuffleId, generationId, mapId) + + // Check the dataFile, it should be the new one for each generation id + assert(dataFile.exists()) + assert(dataFile.length() === 30) + val dataIn = new FileInputStream(dataFile) + Utils.tryWithSafeFinally { + dataIn.read(firstByte) + } { + dataIn.close() + } + assert(firstByte(0) === i) + + // The index file should be the new one for each generation id + val indexIn = new DataInputStream(new FileInputStream(indexFile)) + Utils.tryWithSafeFinally { + indexIn.readLong() // the first offset is always 0 + assert(indexIn.readLong() === i * 10, "The index file should be the new one") + } { + indexIn.close() + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala 
b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index 690bcd9905257..40df8d6f0c36b 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -70,7 +70,7 @@ class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with context) writer.write(Iterator.empty) writer.stop(success = true) - val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 1) + val dataFile = shuffleBlockResolver.getDataFile(shuffleId, -1, 1) val writeMetrics = context.taskMetrics().shuffleWriteMetrics assert(!dataFile.exists()) assert(writeMetrics.bytesWritten === 0) @@ -87,7 +87,7 @@ class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with context) writer.write(records.toIterator) writer.stop(success = true) - val dataFile = shuffleBlockResolver.getDataFile(shuffleId, 2) + val dataFile = shuffleBlockResolver.getDataFile(shuffleId, -1, 2) val writeMetrics = context.taskMetrics().shuffleWriteMetrics assert(dataFile.exists()) assert(dataFile.length() === writeMetrics.bytesWritten) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 509d4efcab67a..36f1f839406bc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1650,6 +1650,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE host: String, port: Int, execId: String, + shuffleGenerationId: Int, blockIds: Array[String], listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit = { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 
ed402440e74f1..68082c59678fa 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -49,10 +49,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT /** Creates a mock [[BlockTransferService]] that returns data from the given map. */ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())).thenAnswer( + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())).thenAnswer( (invocation: InvocationOnMock) => { - val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + val blocks = invocation.getArguments()(4).asInstanceOf[Array[String]] + val listener = invocation.getArguments()(5).asInstanceOf[BlockFetchingListener] for (blockId <- blocks) { if (data.contains(BlockId(blockId))) { @@ -87,7 +87,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) localBlocks.foreach { case (blockId, buf) => - doReturn(buf).when(blockManager).getBlockData(meq(blockId)) + val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] + doReturn(buf).when(blockManager).getShuffleBlockData( + meq(shuffleBlockId.shuffleId), meq(-1), + meq(shuffleBlockId.mapId), meq(shuffleBlockId.reduceId)) } // Make sure remote blocks would return @@ -120,7 +123,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT metrics) // 3 local blocks fetched in initialization - verify(blockManager, times(3)).getBlockData(any()) + verify(blockManager, times(3)).getShuffleBlockData(any(), any(), any(), any()) for (i <- 
0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") @@ -144,8 +147,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) - verify(blockManager, times(3)).getBlockData(any()) - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verify(blockManager, times(3)).getShuffleBlockData(any(), any(), any(), any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), any()) } test("release current unexhausted buffer in case the task completes early") { @@ -164,9 +167,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + val listener = invocation.getArguments()(5).asInstanceOf[BlockFetchingListener] Future { // Return the first two blocks, and wait till task completion before returning the 3rd one listener.onBlockFetchSuccess( @@ -232,9 +235,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + val listener = invocation.getArguments()(5).asInstanceOf[BlockFetchingListener] Future { // Return the first block, and then fail. 
listener.onBlockFetchSuccess( @@ -321,9 +324,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + val listener = invocation.getArguments()(5).asInstanceOf[BlockFetchingListener] Future { // Return the first block, and then fail. listener.onBlockFetchSuccess( @@ -361,9 +364,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + val listener = invocation.getArguments()(5).asInstanceOf[BlockFetchingListener] Future { // Return the first block, and then fail. 
listener.onBlockFetchSuccess( @@ -464,7 +467,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) doReturn(localBmId).when(blockManager).blockManagerId - doReturn(managedBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) + doReturn(managedBuffer).when(blockManager).getShuffleBlockData(0, -1, 0, 0) val localBlockLengths = Seq[Tuple2[BlockId, Long]]( ShuffleBlockId(0, 0, 0) -> 10000 ) @@ -516,9 +519,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + val listener = invocation.getArguments()(5).asInstanceOf[BlockFetchingListener] Future { // Return the first block, and then fail. 
listener.onBlockFetchSuccess( @@ -580,10 +583,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) var tempFileManager: DownloadFileManager = null - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - tempFileManager = invocation.getArguments()(5).asInstanceOf[DownloadFileManager] + val listener = invocation.getArguments()(5).asInstanceOf[BlockFetchingListener] + tempFileManager = invocation.getArguments()(6).asInstanceOf[DownloadFileManager] Future { listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) From e7365e33e44e7f90dcb4a2707745293d648790f0 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 1 Jul 2019 23:14:58 -0700 Subject: [PATCH 2/9] refactor OneForOneShuffleBlockFetcher, separate shuffle block fetcher and non-shuffle block fetcher into 2 classes Fix and address comments fix style UT fix and comment address --- .../network/shuffle/BlockStoreClient.java | 13 ++- .../shuffle/ExternalBlockStoreClient.java | 37 ++++++- .../shuffle/OneForOneBlockFetcher.java | 90 +++-------------- .../shuffle/OneForOneDataBlockFetcher.java | 45 +++++++++ .../shuffle/OneForOneShuffleBlockFetcher.java | 96 +++++++++++++++++++ .../network/sasl/SaslIntegrationSuite.java | 4 +- .../ExternalShuffleIntegrationSuite.java | 36 ++++--- .../shuffle/OneForOneBlockFetcherSuite.java | 46 ++++++--- .../spark/network/BlockTransferService.scala | 15 ++- .../netty/NettyBlockTransferService.scala | 35 ++++++- .../shuffle/IndexShuffleBlockResolver.scala | 14 ++- .../spark/shuffle/ShuffleBlockResolver.scala | 6 +- .../apache/spark/storage/BlockManager.scala | 8 +- 
.../network/BlockTransferServiceSuite.scala | 15 ++- .../NettyBlockTransferServiceSuite.scala | 4 +- .../spark/scheduler/DAGSchedulerSuite.scala | 16 +++- .../spark/storage/BlockManagerSuite.scala | 10 ++ 17 files changed, 359 insertions(+), 131 deletions(-) create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneDataBlockFetcher.java create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneShuffleBlockFetcher.java diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java index 008bb5795e44b..294cc2caa28ec 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java @@ -56,7 +56,18 @@ public abstract void fetchBlocks( DownloadFileManager downloadFileManager); /** - * Get the shuffle MetricsSet from BlockStoreClient, this will be used in MetricsSystem to + * Fetch a sequence of non-shuffle blocks from a remote node asynchronously. + */ + public abstract void fetchDataBlocks( + String host, + int port, + String execId, + String[] blockIds, + BlockFetchingListener listener, + DownloadFileManager downloadFileManager); + + /** + * Get the shuffle MetricsSet from BlockStoreClient, this will be used in MetricsSystem to * get the Shuffle related metrics.
*/ public MetricSet shuffleMetrics() { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java index cc86cf6e64700..f2c58b54b8692 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -89,6 +89,17 @@ public void init(String appId) { clientFactory = context.createClientFactory(bootstraps); } + @Override + public void fetchDataBlocks( + String host, + int port, + String execId, + String[] blockIds, + BlockFetchingListener listener, + DownloadFileManager downloadFileManager) { + doFetchBlocks(host, port, execId, -1, blockIds, listener, downloadFileManager, false); + } + @Override public void fetchBlocks( String host, @@ -98,15 +109,33 @@ public void fetchBlocks( String[] blockIds, BlockFetchingListener listener, DownloadFileManager downloadFileManager) { + doFetchBlocks(host, port, execId, shuffleGenerationId, blockIds, listener, + downloadFileManager, true); + } + + private void doFetchBlocks( + String host, + int port, + String execId, + int shuffleGenerationId, + String[] blockIds, + BlockFetchingListener listener, + DownloadFileManager downloadFileManager, + boolean isShuffleBlocks) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = - (blockIds1, listener1) -> { - TransportClient client = clientFactory.createClient(host, port); - new OneForOneBlockFetcher(client, appId, execId, shuffleGenerationId, + (blockIds1, listener1) -> { + TransportClient client = clientFactory.createClient(host, port); + if (isShuffleBlocks) { + new OneForOneShuffleBlockFetcher(client, appId, execId, shuffleGenerationId, + blockIds1, listener1, conf, 
downloadFileManager).start(); + } else { + new OneForOneDataBlockFetcher(client, appId, execId, blockIds1, listener1, conf, downloadFileManager).start(); - }; + } + }; int maxRetries = conf.maxIORetries(); if (maxRetries > 0) { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 53ee4351203c5..ceddda1826362 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -19,11 +19,8 @@ import java.io.IOException; import java.nio.ByteBuffer; -import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; -import com.google.common.primitives.Ints; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -34,53 +31,43 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; -import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks; -import org.apache.spark.network.shuffle.protocol.OpenBlocks; import org.apache.spark.network.shuffle.protocol.StreamHandle; import org.apache.spark.network.util.TransportConf; /** * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and * invokes the BlockFetchingListener appropriately. This class is agnostic to the actual RPC - * handler, as long as there is a single "open blocks" message which returns a ShuffleStreamHandle, + * handler, as long as there is a single "open blocks" message which returns a StreamHandle, * and Java serialization is used. * * Note that this typically corresponds to a * {@link org.apache.spark.network.server.OneForOneStreamManager} on the server side. 
*/ -public class OneForOneBlockFetcher { +public abstract class OneForOneBlockFetcher { private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class); private final TransportClient client; - private final BlockTransferMessage message; - private final String[] blockIds; + protected final String appId; + protected final String execId; + protected final String[] blockIds; private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; - private final TransportConf transportConf; + protected final TransportConf transportConf; private final DownloadFileManager downloadFileManager; private StreamHandle streamHandle = null; - public OneForOneBlockFetcher( - TransportClient client, - String appId, - String execId, - String[] blockIds, - BlockFetchingListener listener, - TransportConf transportConf) { - this(client, appId, execId, -1, blockIds, listener, transportConf, null); - } - - public OneForOneBlockFetcher( + protected OneForOneBlockFetcher( TransportClient client, String appId, String execId, - int shuffleGenerationId, String[] blockIds, BlockFetchingListener listener, TransportConf transportConf, DownloadFileManager downloadFileManager) { this.client = client; + this.appId = appId; + this.execId = execId; this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); @@ -89,65 +76,12 @@ public OneForOneBlockFetcher( if (blockIds.length == 0) { throw new IllegalArgumentException("Zero-sized blockIds array"); } - if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) { - this.message = createFetchShuffleBlocksMsg(appId, execId, shuffleGenerationId, blockIds); - } else { - this.message = new OpenBlocks(appId, execId, blockIds); - } - } - - private boolean isShuffleBlocks(String[] blockIds) { - for (String blockId : blockIds) { - if (!blockId.startsWith("shuffle_")) { - return false; - } - } - return true; } /** - * Analyze the pass in blockIds and create 
FetchShuffleBlocks message. - * The blockIds has been sorted by mapId and reduceId. It's produced in - * org.apache.spark.MapOutputTracker.convertMapStatuses. + * Create the corresponding message for this fetcher. */ - private FetchShuffleBlocks createFetchShuffleBlocksMsg( - String appId, String execId, int shuffleGenerationId, String[] blockIds) { - int shuffleId = splitBlockId(blockIds[0])[0]; - HashMap> mapIdToReduceIds = new HashMap<>(); - for (String blockId : blockIds) { - int[] blockIdParts = splitBlockId(blockId); - if (blockIdParts[0] != shuffleId) { - throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + - ", got:" + blockId); - } - int mapId = blockIdParts[1]; - if (!mapIdToReduceIds.containsKey(mapId)) { - mapIdToReduceIds.put(mapId, new ArrayList<>()); - } - mapIdToReduceIds.get(mapId).add(blockIdParts[2]); - } - int[] mapIds = Ints.toArray(mapIdToReduceIds.keySet()); - int[][] reduceIdArr = new int[mapIds.length][]; - for (int i = 0; i < mapIds.length; i++) { - reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i])); - } - return new FetchShuffleBlocks( - appId, execId, shuffleId, shuffleGenerationId, mapIds, reduceIdArr); - } - - /** Split the shuffleBlockId and return shuffleId, mapId and reduceId. */ - private int[] splitBlockId(String blockId) { - String[] blockIdParts = blockId.split("_"); - if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) { - throw new IllegalArgumentException( - "Unexpected shuffle block id format: " + blockId); - } - return new int[] { - Integer.parseInt(blockIdParts[1]), - Integer.parseInt(blockIdParts[2]), - Integer.parseInt(blockIdParts[3]) - }; - } + public abstract BlockTransferMessage createBlockTransferMessage(); /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ private class ChunkCallback implements ChunkReceivedCallback { @@ -171,7 +105,7 @@ public void onFailure(int chunkIndex, Throwable e) { * {@link StreamHandle}. 
We will send all fetch requests immediately, without throttling. */ public void start() { - client.sendRpc(message.toByteBuffer(), new RpcResponseCallback() { + client.sendRpc(createBlockTransferMessage().toByteBuffer(), new RpcResponseCallback() { @Override public void onSuccess(ByteBuffer response) { try { diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneDataBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneDataBlockFetcher.java new file mode 100644 index 0000000000000..2dab5a9ca2e08 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneDataBlockFetcher.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.util.TransportConf; + +public class OneForOneDataBlockFetcher extends OneForOneBlockFetcher { + public OneForOneDataBlockFetcher( + TransportClient client, + String appId, + String execId, + String[] blockIds, + BlockFetchingListener listener, + TransportConf transportConf, + DownloadFileManager downloadFileManager) { + super(client, appId, execId, blockIds, listener, transportConf, downloadFileManager); + } + + /** + * Create the corresponding message for this fetcher. + * For non shuffle blocks, just use OpenBlocks message for all cases. + */ + @Override + public BlockTransferMessage createBlockTransferMessage() { + return new OpenBlocks(appId, execId, blockIds); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneShuffleBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneShuffleBlockFetcher.java new file mode 100644 index 0000000000000..558aef6076a21 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneShuffleBlockFetcher.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import com.google.common.primitives.Ints; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.util.TransportConf; + +import java.util.ArrayList; +import java.util.HashMap; + +public class OneForOneShuffleBlockFetcher extends OneForOneBlockFetcher { + private final int shuffleGenerationId; + + public OneForOneShuffleBlockFetcher( + TransportClient client, + String appId, + String execId, + int shuffleGenerationId, + String[] blockIds, + BlockFetchingListener listener, + TransportConf transportConf, + DownloadFileManager downloadFileManager) { + super(client, appId, execId, blockIds, listener, transportConf, downloadFileManager); + this.shuffleGenerationId = shuffleGenerationId; + } + + /** + * Create the corresponding message for this fetcher. + * For shuffle blocks, we choose a different message based on whether to use the old protocol. + * If `spark.shuffle.useOldFetchProtocol` is set to true, we use OpenBlocks for shuffle blocks. + * Otherwise, we analyze the passed-in blockIds and create a FetchShuffleBlocks message. + * The blockIds have been sorted by mapId and reduceId. They are produced in + * org.apache.spark.MapOutputTracker.convertMapStatuses.
+ */ + @Override + public BlockTransferMessage createBlockTransferMessage() { + if (transportConf.useOldFetchProtocol()) { + return new OpenBlocks(appId, execId, blockIds); + } else { + int shuffleId = splitBlockId(blockIds[0])[0]; + HashMap> mapIdToReduceIds = new HashMap<>(); + for (String blockId : blockIds) { + int[] blockIdParts = splitBlockId(blockId); + if (blockIdParts[0] != shuffleId) { + throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + + ", got:" + blockId); + } + int mapId = blockIdParts[1]; + if (!mapIdToReduceIds.containsKey(mapId)) { + mapIdToReduceIds.put(mapId, new ArrayList<>()); + } + mapIdToReduceIds.get(mapId).add(blockIdParts[2]); + } + int[] mapIds = Ints.toArray(mapIdToReduceIds.keySet()); + int[][] reduceIdArr = new int[mapIds.length][]; + for (int i = 0; i < mapIds.length; i++) { + reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i])); + } + return new FetchShuffleBlocks( + appId, execId, shuffleId, shuffleGenerationId, mapIds, reduceIdArr); + } + } + + /** Split the shuffleBlockId and return shuffleId, mapId and reduceId. 
*/ + private int[] splitBlockId(String blockId) { + String[] blockIdParts = blockId.split("_"); + if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) { + throw new IllegalArgumentException( + "Unexpected shuffle block id format: " + blockId); + } + return new int[] { + Integer.parseInt(blockIdParts[1]), + Integer.parseInt(blockIdParts[2]), + Integer.parseInt(blockIdParts[3]) + }; + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java index e8e766d3fb3ab..187c0d2e1d6e9 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -48,6 +48,7 @@ import org.apache.spark.network.shuffle.ExternalBlockHandler; import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver; import org.apache.spark.network.shuffle.OneForOneBlockFetcher; +import org.apache.spark.network.shuffle.OneForOneShuffleBlockFetcher; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.shuffle.protocol.OpenBlocks; @@ -203,7 +204,8 @@ public void onBlockFetchFailure(String blockId, Throwable t) { String[] blockIds = { "shuffle_0_1_2", "shuffle_0_3_4" }; OneForOneBlockFetcher fetcher = - new OneForOneBlockFetcher(client1, "app-2", "0", blockIds, listener, conf); + new OneForOneShuffleBlockFetcher( + client1, "app-2", "0", -1, blockIds, listener, conf, null); fetcher.start(); blockFetchLatch.await(); checkSecurityException(exception.get()); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 
d44f57461a59f..cabf982d044eb 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -158,15 +158,20 @@ public void releaseBuffers() { } } - // Fetch a set of blocks from a pre-registered executor. + // Fetch a set of shuffle blocks with default generation id -1 from a pre-registered executor. private FetchResult fetchBlocks(String execId, String[] blockIds) throws Exception { - return fetchBlocks(execId, blockIds, conf, server.getPort(), -1); + return fetchBlocks(execId, blockIds, conf, server.getPort(), -1, true); } - // Fetch a set of blocks from a pre-registered executor. + // Fetch a set of shuffle blocks from a pre-registered executor. private FetchResult fetchBlocks( String execId, String[] blockIds, int shuffleGenerationId) throws Exception { - return fetchBlocks(execId, blockIds, conf, server.getPort(), shuffleGenerationId); + return fetchBlocks(execId, blockIds, conf, server.getPort(), shuffleGenerationId, true); + } + + // Fetch a set of non-shuffle blocks from a pre-registered executor. + private FetchResult fetchDataBlocks(String execId, String[] blockIds) throws Exception { + return fetchBlocks(execId, blockIds, conf, server.getPort(), -1, false); } // Fetch a set of blocks from a pre-registered executor.
Connects to the server on the given port, @@ -176,7 +181,8 @@ private FetchResult fetchBlocks( String[] blockIds, TransportConf clientConf, int port, - int shuffleGenerationId) throws Exception { + int shuffleGenerationId, + boolean isShuffleBlocks) throws Exception { final FetchResult res = new FetchResult(); res.successBlocks = Collections.synchronizedSet(new HashSet()); res.failedBlocks = Collections.synchronizedSet(new HashSet()); @@ -187,7 +193,7 @@ private FetchResult fetchBlocks( try (ExternalBlockStoreClient client = new ExternalBlockStoreClient( clientConf, null, false, 5000)) { client.init(APP_ID); - client.fetchBlocks(TestUtils.getLocalHost(), port, execId, shuffleGenerationId, blockIds, + BlockFetchingListener listener = new BlockFetchingListener() { @Override public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { @@ -210,7 +216,13 @@ public void onBlockFetchFailure(String blockId, Throwable exception) { } } } - }, null); + }; + if (isShuffleBlocks) { + client.fetchBlocks(TestUtils.getLocalHost(), port, execId, shuffleGenerationId, blockIds, + listener, null); + } else { + client.fetchDataBlocks(TestUtils.getLocalHost(), port, execId, blockIds, listener, null); + } if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { fail("Timeout getting response from the server"); @@ -271,7 +283,7 @@ public void testRegisterInvalidExecutor() throws Exception { @Test public void testFetchWrongBlockId() throws Exception { registerExecutor("exec-1", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-1", new String[] { "broadcast_1" }); + FetchResult execFetch = fetchDataBlocks("exec-1", new String[] { "broadcast_1" }); assertTrue(execFetch.successBlocks.isEmpty()); assertEquals(Sets.newHashSet("broadcast_1"), execFetch.failedBlocks); } @@ -280,7 +292,7 @@ public void testFetchWrongBlockId() throws Exception { public void testFetchValidRddBlock() throws Exception { registerExecutor("exec-1", 
dataContext0.createExecutorInfo(SORT_MANAGER)); String validBlockId = "rdd_" + RDD_ID +"_" + SPLIT_INDEX_VALID_BLOCK; - FetchResult execFetch = fetchBlocks("exec-1", new String[] { validBlockId }); + FetchResult execFetch = fetchDataBlocks("exec-1", new String[] { validBlockId }); assertTrue(execFetch.failedBlocks.isEmpty()); assertEquals(Sets.newHashSet(validBlockId), execFetch.successBlocks); assertBuffersEqual(new NioManagedBuffer(ByteBuffer.wrap(exec0RddBlockValid)), @@ -291,7 +303,7 @@ public void testFetchValidRddBlock() throws Exception { public void testFetchDeletedRddBlock() throws Exception { registerExecutor("exec-1", dataContext0.createExecutorInfo(SORT_MANAGER)); String missingBlockId = "rdd_" + RDD_ID +"_" + SPLIT_INDEX_MISSING_FILE; - FetchResult execFetch = fetchBlocks("exec-1", new String[] { missingBlockId }); + FetchResult execFetch = fetchDataBlocks("exec-1", new String[] { missingBlockId }); assertTrue(execFetch.successBlocks.isEmpty()); assertEquals(Sets.newHashSet(missingBlockId), execFetch.failedBlocks); } @@ -317,7 +329,7 @@ public void testRemoveRddBlocks() throws Exception { public void testFetchCorruptRddBlock() throws Exception { registerExecutor("exec-1", dataContext0.createExecutorInfo(SORT_MANAGER)); String corruptBlockId = "rdd_" + RDD_ID +"_" + SPLIT_INDEX_CORRUPT_LENGTH; - FetchResult execFetch = fetchBlocks("exec-1", new String[] { corruptBlockId }); + FetchResult execFetch = fetchDataBlocks("exec-1", new String[] { corruptBlockId }); assertTrue(execFetch.successBlocks.isEmpty()); assertEquals(Sets.newHashSet(corruptBlockId), execFetch.failedBlocks); } @@ -355,7 +367,7 @@ public void testFetchNoServer() throws Exception { new MapConfigProvider(ImmutableMap.of("spark.shuffle.io.maxRetries", "0"))); registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); FetchResult execFetch = fetchBlocks("exec-0", - new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, clientConf, 1 /* port */, -1); + new String[]{"shuffle_1_0_0", 
"shuffle_1_0_1"}, clientConf, 1 /* port */, -1, true); assertTrue(execFetch.successBlocks.isEmpty()); assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 45830f77b438b..d2933909dd901 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -60,12 +60,16 @@ public void testFetchOne() { LinkedHashMap blocks = Maps.newLinkedHashMap(); blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + int shuffleGenerationId = 1; BlockFetchingListener listener = fetchBlocks( blocks, blockIds, - new FetchShuffleBlocks("app-id", "exec-id", 0, -1, new int[] { 0 }, new int[][] {{ 0 }}), - conf); + new FetchShuffleBlocks( + "app-id", "exec-id", 0, shuffleGenerationId, new int[] { 0 }, new int[][] {{ 0 }}), + conf, + true, + shuffleGenerationId); verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0")); } @@ -84,7 +88,9 @@ public void testUseOldProtocol() { new HashMap() {{ put("spark.shuffle.useOldFetchProtocol", "true"); }} - ))); + )), + false, + -1); verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0")); } @@ -102,7 +108,9 @@ public void testFetchThreeShuffleBlocks() { blockIds, new FetchShuffleBlocks( "app-id", "exec-id", 0, -1, new int[] { 0 }, new int[][] {{ 0, 1, 2 }}), - conf); + conf, + true, + -1); for (int i = 0; i < 3; i ++) { verify(listener, times(1)).onBlockFetchSuccess( @@ -122,7 +130,9 @@ public void testFetchThree() { blocks, blockIds, new OpenBlocks("app-id", "exec-id", blockIds), - conf); 
+ conf, + false, + -1); for (int i = 0; i < 3; i ++) { verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i)); @@ -141,7 +151,9 @@ public void testFailure() { blocks, blockIds, new OpenBlocks("app-id", "exec-id", blockIds), - conf); + conf, + false, + -1); // Each failure will cause a failure to be invoked in all remaining block fetches. verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); @@ -161,7 +173,9 @@ public void testFailureAndSuccess() { blocks, blockIds, new OpenBlocks("app-id", "exec-id", blockIds), - conf); + conf, + false, + -1); // We may call both success and failure for the same block. verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); @@ -177,7 +191,9 @@ public void testEmptyBlockFetch() { Maps.newLinkedHashMap(), new String[] {}, new OpenBlocks("app-id", "exec-id", new String[] {}), - conf); + conf, + false, + -1); fail(); } catch (IllegalArgumentException e) { assertEquals("Zero-sized blockIds array", e.getMessage()); @@ -196,11 +212,19 @@ private static BlockFetchingListener fetchBlocks( LinkedHashMap blocks, String[] blockIds, BlockTransferMessage expectMessage, - TransportConf transportConf) { + TransportConf transportConf, + boolean useShuffleBlockFetcher, + int shuffleGenerationId) { TransportClient client = mock(TransportClient.class); BlockFetchingListener listener = mock(BlockFetchingListener.class); - OneForOneBlockFetcher fetcher = - new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, transportConf); + OneForOneBlockFetcher fetcher = null; + if (useShuffleBlockFetcher) { + fetcher = new OneForOneShuffleBlockFetcher( + client, "app-id", "exec-id", shuffleGenerationId, blockIds, listener, transportConf, null); + } else { + fetcher = new OneForOneDataBlockFetcher( + client, "app-id", "exec-id", blockIds, listener, transportConf, null); + } // Respond to the "OpenBlocks" message with an appropriate ShuffleStreamHandle with streamId 123 
doAnswer(invocationOnMock -> { diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 259e6a56869a2..4edb6c9823082 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -54,7 +54,7 @@ abstract class BlockTransferService extends BlockStoreClient with Logging { def hostName: String /** - * Fetch a sequence of blocks from a remote node asynchronously, + * Fetch a sequence of shuffle blocks from a remote node asynchronously, * available only after [[init]] is invoked. * * Note that this API takes a sequence so the implementation can batch requests, and does not @@ -70,6 +70,17 @@ abstract class BlockTransferService extends BlockStoreClient with Logging { listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit + /** + * Fetch a sequence of non-shuffle blocks from a remote node asynchronously. + */ + override def fetchDataBlocks( + host: String, + port: Int, + execId: String, + blockIds: Array[String], + listener: BlockFetchingListener, + tempFileManager: DownloadFileManager): Unit + /** * Upload a single block to a remote node, available only after [[init]] is invoked. */ @@ -98,7 +109,7 @@ abstract class BlockTransferService extends BlockStoreClient with Logging { assert(!BlockId.apply(blockId).isShuffle) // A monitor for the thread to wait on. 
val result = Promise[ManagedBuffer]() - fetchBlocks(host, port, execId, -1, Array(blockId), + fetchDataBlocks(host, port, execId, Array(blockId), new BlockFetchingListener { override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { result.failure(exception) diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 656ec4436f46d..e44b20caac02e 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -36,7 +36,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, OneForOneBlockFetcher, RetryingBlockFetcher} +import org.apache.spark.network.shuffle._ import org.apache.spark.network.shuffle.protocol.{UploadBlock, UploadBlockStream} import org.apache.spark.network.util.JavaUtils import org.apache.spark.rpc.RpcEndpointRef @@ -115,13 +115,40 @@ private[spark] class NettyBlockTransferService( listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") + doFetchBlocks(execId, blockIds, listener, () => { + val client = clientFactory.createClient(host, port) + new OneForOneShuffleBlockFetcher( + client, appId, execId, shuffleGenerationId, blockIds, + listener, transportConf, tempFileManager) + }) + } + + override def fetchDataBlocks( + host: String, + port: Int, + execId: String, + blockIds: Array[String], + listener: BlockFetchingListener, + tempFileManager: 
DownloadFileManager): Unit = { + logTrace(s"Fetch blocks from $host:$port (executor id $execId)") + doFetchBlocks(execId, blockIds, listener, () => { + val client = clientFactory.createClient(host, port) + new OneForOneDataBlockFetcher( + client, appId, execId, blockIds, + listener, transportConf, tempFileManager) + }) + } + + private def doFetchBlocks( + execId: String, + blockIds: Array[String], + listener: BlockFetchingListener, + createBlockFetcher: () => OneForOneBlockFetcher): Unit = { try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { - val client = clientFactory.createClient(host, port) try { - new OneForOneBlockFetcher(client, appId, execId, shuffleGenerationId, blockIds, - listener, transportConf, tempFileManager).start() + createBlockFetcher().start() } catch { case e: IOException => Try { diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 3c2356a3e5f6e..f4bd6eefefc8f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -212,10 +212,14 @@ private[spark] class IndexShuffleBlockResolver( } } - override def getBlockData(shuffleGenerationId: Int, blockId: ShuffleBlockId): ManagedBuffer = { + override def getBlockData( + shuffleId: Int, + shuffleGenerationId: Int, + mapId: Int, + reduceId: Int): ManagedBuffer = { // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index - val indexFile = getIndexFile(blockId.shuffleId, shuffleGenerationId, blockId.mapId) + val indexFile = getIndexFile(shuffleId, shuffleGenerationId, mapId) // SPARK-22982: if this FileInputStream's position is seeked forward by 
another piece of code // which is incorrectly using our file descriptor then this code will fetch the wrong offsets @@ -224,20 +228,20 @@ private[spark] class IndexShuffleBlockResolver( // class of issue from re-occurring in the future which is why they are left here even though // SPARK-22982 is fixed. val channel = Files.newByteChannel(indexFile.toPath) - channel.position(blockId.reduceId * 8L) + channel.position(reduceId * 8L) val in = new DataInputStream(Channels.newInputStream(channel)) try { val offset = in.readLong() val nextOffset = in.readLong() val actualPosition = channel.position() - val expectedPosition = blockId.reduceId * 8L + 16 + val expectedPosition = reduceId * 8L + 16 if (actualPosition != expectedPosition) { throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " + s"expected $expectedPosition but actual position was $actualPosition.") } new FileSegmentManagedBuffer( transportConf, - getDataFile(blockId.shuffleId, shuffleGenerationId, blockId.mapId), + getDataFile(shuffleId, shuffleGenerationId, mapId), offset, nextOffset - offset) } finally { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index 1486321f6fb70..b70687e975441 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -34,7 +34,11 @@ trait ShuffleBlockResolver { * Retrieve the data for the specified block. If the data for that block is not available, * throws an unspecified exception. 
*/ - def getBlockData(shuffleGenerationId: Int, blockId: ShuffleBlockId): ManagedBuffer + def getBlockData( + shuffleId: ShuffleId, + shuffleGenerationId: Int, + mapId: Int, + reduceId: Int): ManagedBuffer def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index c5bb6d9e2370d..b8d266d95fc8c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -552,7 +552,9 @@ private[spark] class BlockManager( // to fetch shuffle blocks. After Spark 3.0, all shuffle block fetching will use new // protocol FetchShuffleBlocks and use getShuffleBlockData in BlockManager. // See more details in SPARK-27665. - shuffleManager.shuffleBlockResolver.getBlockData(-1, blockId.asInstanceOf[ShuffleBlockId]) + val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] + shuffleManager.shuffleBlockResolver.getBlockData( + shuffleBlockId.shuffleId, -1, shuffleBlockId.mapId, shuffleBlockId.reduceId) } else { getLocalBytes(blockId) match { case Some(blockData) => @@ -576,8 +578,8 @@ private[spark] class BlockManager( shuffleGenerationId: Int, mapId: Int, reduceId: Int): ManagedBuffer = { - val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) - shuffleManager.shuffleBlockResolver.getBlockData(shuffleGenerationId, shuffleBlockId) + shuffleManager.shuffleBlockResolver.getBlockData( + shuffleId, shuffleGenerationId, mapId, reduceId) } /** diff --git a/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala index cceb9717b78d8..8eafaaa840f5f 100644 --- a/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala @@ -47,11 +47,10 @@ class BlockTransferServiceSuite extends SparkFunSuite with 
TimeLimits { override def hostName: String = "localhost-unused" - override def fetchBlocks( + override def fetchDataBlocks( host: String, port: Int, execId: String, - shuffleGenerationId: Int, blockIds: Array[String], listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit = { @@ -92,6 +91,18 @@ class BlockTransferServiceSuite extends SparkFunSuite with TimeLimits { // This method is unused in this test throw new UnsupportedOperationException("uploadBlock") } + + override def fetchBlocks( + host: String, + port: Int, + execId: String, + shuffleGenerationId: Int, + blockIds: Array[String], + listener: BlockFetchingListener, + tempFileManager: DownloadFileManager): Unit = { + // This method is unused in this test + throw new UnsupportedOperationException("fetchBlocks") + } } val e = intercept[SparkException] { diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index 46815f066d426..f510c82995e63 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -121,8 +121,8 @@ class NettyBlockTransferServiceSuite clientFactoryField.setAccessible(true) clientFactoryField.set(service0, clientFactory) - service0.fetchBlocks("localhost", port, "exec1", -1, - Array("block1"), listener, mock(classOf[DownloadFileManager])) + service0.fetchDataBlocks("localhost", port, "exec1", Array("block1"), + listener, mock(classOf[DownloadFileManager])) assert(createClientCount === 1) assert(hitExecutorDeadException) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index f8accd94a5a2c..1845cb09e9fa6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2712,9 +2712,11 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi private def constructIndeterminateStageRetryScenario(): (Int, Int) = { val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true) + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) val shuffleId1 = shuffleDep1.shuffleId val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker) + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2)) val shuffleId2 = shuffleDep2.shuffleId val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker) @@ -2726,18 +2728,21 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostA", 2)), (Success, makeMapStatus("hostB", 2)))) assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) + // Finish the second shuffle map stage. complete(taskSets(1), Seq( (Success, makeMapStatus("hostC", 2)), (Success, makeMapStatus("hostD", 2)))) assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty)) + // The first task of the final stage failed with fetch failure runEvent(makeCompletionEvent( taskSets(2).tasks(0), FetchFailed(makeBlockManagerId("hostC"), shuffleId2, 0, 0, "ignored"), null)) + val failedStages = scheduler.failedStages.toSeq - assert(failedStages.length == 2) + assert(failedStages.map(_.id) == Seq(1, 2)) // Shuffle blocks of "hostC" is lost, so first task of the `shuffleMapRdd2` needs to retry. 
assert(failedStages.collect { case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 => stage @@ -2746,6 +2751,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(failedStages.collect { case stage: ResultStage => stage }.head.findMissingPartitions() == Seq(0, 1)) + scheduler.resubmitFailedStages() (shuffleId1, shuffleId2) } @@ -2758,13 +2764,13 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi init(conf) val (shuffleId1, _) = constructIndeterminateStageRetryScenario() - // The second task of the `shuffleMapRdd2` failed with fetch failure + // The first task of the `shuffleMapRdd2` failed with fetch failure runEvent(makeCompletionEvent( - taskSets(3).tasks(0), + taskSets(3).tasks(1), Success, makeMapStatus("hostC", 2))) runEvent(makeCompletionEvent( - taskSets(3).tasks(1), + taskSets(3).tasks(0), FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0, 0, "ignored"), null)) @@ -2835,7 +2841,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi submit(finalRdd, Array(0, 1), properties = new Properties()) - // Finish the first 3 shuffle map stages. + // Finish the first 2 shuffle map stages. 
complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 2)), (Success, makeMapStatus("hostB", 2)))) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 36f1f839406bc..6dd75a435117c 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1657,6 +1657,16 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } + override def fetchDataBlocks( + host: String, + port: Int, + execId: String, + blockIds: Array[String], + listener: BlockFetchingListener, + tempFileManager: DownloadFileManager): Unit = { + listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) + } + override def close(): Unit = {} override def hostName: String = { "MockBlockTransferServiceHost" } From 99c2b4ab0d3d2b7b0b16ab1348a190e191b43c92 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 5 Aug 2019 18:55:36 +0800 Subject: [PATCH 3/9] Resolve conflict with SPARK-28209 --- .../spark/shuffle/api/ShuffleExecutorComponents.java | 12 ++++++++---- .../shuffle/sort/BypassMergeSortShuffleWriter.java | 4 ++-- .../sort/io/LocalDiskShuffleExecutorComponents.java | 3 ++- .../sort/io/LocalDiskShuffleMapOutputWriter.java | 8 ++++++-- .../io/LocalDiskShuffleMapOutputWriterSuite.scala | 5 +++-- 5 files changed, 21 insertions(+), 11 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java index 70c112b78911d..09ab3ae527f7d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java @@ -39,16 +39,20 @@ public 
interface ShuffleExecutorComponents { /** * Called once per map task to create a writer that will be responsible for persisting all the * partitioned bytes written by that map task. - * @param shuffleId Unique identifier for the shuffle the map task is a part of + * @param shuffleId Unique identifier for the shuffle the map task is a part of + * @param shuffleGenerationId The shuffle generation ID of the stage that this task belongs to, + * it equals the stage attempt number while the stage is indeterminate + * and -1 on the contrary. * @param mapId Within the shuffle, the identifier of the map task * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task - * with the same (shuffleId, mapId) pair can be distinguished by the - * different values of mapTaskAttemptId. + * with the same (shuffleId, mapId) pair can be distinguished by the + * different values of mapTaskAttemptId. * @param numPartitions The number of partitions that will be written by the map task. Some of -* these partitions may be empty. + * these partitions may be empty. 
*/ ShuffleMapOutputWriter createMapOutputWriter( int shuffleId, + int shuffleGenerationId, int mapId, long mapTaskAttemptId, int numPartitions) throws IOException; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index d167145de00e8..43a7162a24030 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -132,8 +132,8 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { @Override public void write(Iterator> records) throws IOException { assert (partitionWriters == null); - ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents - .createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions); + ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( + shuffleId, shuffleGenerationId, mapId, mapTaskAttemptId, numPartitions); try { if (!records.hasNext()) { partitionLengths = new long[numPartitions]; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java index 02eb710737285..c50bc3cb738c5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java @@ -58,6 +58,7 @@ public void initializeExecutor(String appId, String execId) { @Override public ShuffleMapOutputWriter createMapOutputWriter( int shuffleId, + int shuffleGenerationId, int mapId, long mapTaskAttemptId, int numPartitions) { @@ -66,6 +67,6 @@ public ShuffleMapOutputWriter createMapOutputWriter( "Executor components must be initialized before getting writers."); } return new 
LocalDiskShuffleMapOutputWriter( - shuffleId, mapId, numPartitions, blockResolver, sparkConf); + shuffleId, shuffleGenerationId, mapId, numPartitions, blockResolver, sparkConf); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java index add4634a61fb5..87c4b2ec351b3 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -48,6 +48,7 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter { LoggerFactory.getLogger(LocalDiskShuffleMapOutputWriter.class); private final int shuffleId; + private final int shuffleGenerationId; private final int mapId; private final IndexShuffleBlockResolver blockResolver; private final long[] partitionLengths; @@ -63,18 +64,20 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter { public LocalDiskShuffleMapOutputWriter( int shuffleId, + int shuffleGenerationId, int mapId, int numPartitions, IndexShuffleBlockResolver blockResolver, SparkConf sparkConf) { this.shuffleId = shuffleId; + this.shuffleGenerationId = shuffleGenerationId; this.mapId = mapId; this.blockResolver = blockResolver; this.bufferSize = (int) (long) sparkConf.get( package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; this.partitionLengths = new long[numPartitions]; - this.outputFile = blockResolver.getDataFile(shuffleId, mapId); + this.outputFile = blockResolver.getDataFile(shuffleId, shuffleGenerationId, mapId); this.outputTempFile = null; } @@ -99,7 +102,8 @@ public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws I public void commitAllPartitions() throws IOException { cleanUp(); File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? 
outputTempFile : null; - blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); + blockResolver.writeIndexFileAndCommit( + shuffleId, shuffleGenerationId, mapId, partitionLengths, resolvedTmp); } @Override diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala index 5693b9824523a..bcf1a6774e697 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala @@ -73,9 +73,9 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA conf = new SparkConf() .set("spark.app.id", "example.spark.app") .set("spark.shuffle.unsafe.file.output.buffer", "16k") - when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile) + when(blockResolver.getDataFile(anyInt, anyInt, anyInt)).thenReturn(mergedOutputFile) when(blockResolver.writeIndexFileAndCommit( - anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) + anyInt, anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) .thenAnswer { invocationOnMock => partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] @@ -87,6 +87,7 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA } mapOutputWriter = new LocalDiskShuffleMapOutputWriter( 0, + -1, 0, NUM_PARTITIONS, blockResolver, From 122a8fed056b804407a3b35184454a52c0fa79d1 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Mon, 5 Aug 2019 22:30:55 +0800 Subject: [PATCH 4/9] fix --- .../org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java index 45a593c25a93c..e3d60880765d2 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -39,7 +39,7 @@ public interface ShuffleMapOutputWriter { * for the same partition within any given map task. The partition identifier will be in the * range of precisely 0 (inclusive) to numPartitions (exclusive), where numPartitions was * provided upon the creation of this map output writer via - * {@link ShuffleExecutorComponents#createMapOutputWriter(int, int, long, int)}. + * {@link ShuffleExecutorComponents#createMapOutputWriter(int, int, int, long, int)}. *

* Calls to this method will be invoked with monotonically increasing reducePartitionIds; each * call to this method will be called with a reducePartitionId that is strictly greater than From eb5c58b35427497accb43c47423c23c48810aa78 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Tue, 6 Aug 2019 11:53:02 +0800 Subject: [PATCH 5/9] UT fix --- .../sort/io/LocalDiskShuffleMapOutputWriterSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala index bcf1a6774e697..a862c931bddb2 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala @@ -77,8 +77,8 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA when(blockResolver.writeIndexFileAndCommit( anyInt, anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) .thenAnswer { invocationOnMock => - partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] - val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] + partitionSizesInMergedFile = invocationOnMock.getArguments()(3).asInstanceOf[Array[Long]] + val tmp: File = invocationOnMock.getArguments()(4).asInstanceOf[File] if (tmp != null) { mergedOutputFile.delete() tmp.renameTo(mergedOutputFile) From 67f70a4e253b9d4808219c4e80fc2478ed6d9e59 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 7 Aug 2019 21:40:10 +0800 Subject: [PATCH 6/9] Address comment, fetchBlocks finally change to fetchShuffleBlocks --- .../network/shuffle/BlockStoreClient.java | 4 +-- .../shuffle/ExternalBlockStoreClient.java | 4 +-- .../ExternalShuffleIntegrationSuite.java | 2 +- .../spark/network/BlockTransferService.scala | 34 ++----------------- 
.../netty/NettyBlockTransferService.scala | 4 +-- .../apache/spark/scheduler/DAGScheduler.scala | 2 ++ .../spark/scheduler/ShuffleMapStage.scala | 5 +-- .../storage/ShuffleBlockFetcherIterator.scala | 4 +-- .../network/BlockTransferServiceSuite.scala | 4 +-- .../NettyBlockTransferSecuritySuite.scala | 2 +- .../spark/storage/BlockManagerSuite.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 18 +++++----- 12 files changed, 28 insertions(+), 57 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java index 294cc2caa28ec..11377ca0421c7 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java @@ -29,7 +29,7 @@ public abstract class BlockStoreClient implements Closeable { /** - * Fetch a sequence of blocks from a remote node asynchronously, + * Fetch a sequence of shuffle blocks from a remote node asynchronously, * * Note that this API takes a sequence so the implementation can batch requests, and does not * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as @@ -46,7 +46,7 @@ public abstract class BlockStoreClient implements Closeable { * into temp shuffle files to reduce the memory usage, otherwise, * they will be kept in memory. 
*/ - public abstract void fetchBlocks( + public abstract void fetchShuffleBlocks( String host, int port, String execId, diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java index f2c58b54b8692..43fc83127cc1d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -101,7 +101,7 @@ public void fetchDataBlocks( } @Override - public void fetchBlocks( + public void fetchShuffleBlocks( String host, int port, String execId, @@ -146,7 +146,7 @@ private void doFetchBlocks( blockFetchStarter.createAndStart(blockIds, listener); } } catch (Exception e) { - logger.error("Exception while beginning fetchBlocks", e); + logger.error("Exception while beginning fetchShuffleBlocks", e); for (String blockId : blockIds) { listener.onBlockFetchFailure(blockId, e); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index cabf982d044eb..97cab117a99d3 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -218,7 +218,7 @@ public void onBlockFetchFailure(String blockId, Throwable exception) { } }; if (isShuffleBlocks) { - client.fetchBlocks(TestUtils.getLocalHost(), port, execId, shuffleGenerationId, blockIds, + client.fetchShuffleBlocks(TestUtils.getLocalHost(), port, execId, shuffleGenerationId, blockIds, listener, null); } else { client.fetchDataBlocks(TestUtils.getLocalHost(), port, execId, blockIds, listener, 
null); diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 4edb6c9823082..1a2c581be67a8 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -38,7 +38,7 @@ abstract class BlockTransferService extends BlockStoreClient with Logging { /** * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch - * local blocks or put local blocks. The fetchBlocks method in [[BlockStoreClient]] also + * local blocks or put local blocks. The fetchShuffleBlocks method in [[BlockStoreClient]] also * available only after this is invoked. */ def init(blockDataManager: BlockDataManager): Unit @@ -53,34 +53,6 @@ abstract class BlockTransferService extends BlockStoreClient with Logging { */ def hostName: String - /** - * Fetch a sequence of shuffle blocks from a remote node asynchronously, - * available only after [[init]] is invoked. - * - * Note that this API takes a sequence so the implementation can batch requests, and does not - * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as - * the data of a block is fetched, rather than waiting for all blocks to be fetched. - */ - override def fetchBlocks( - host: String, - port: Int, - execId: String, - shuffleGenerationId: Int, - blockIds: Array[String], - listener: BlockFetchingListener, - tempFileManager: DownloadFileManager): Unit - - /** - * Fetch a sequence of non-shuffle blocks from a remote node asynchronously. - */ - override def fetchDataBlocks( - host: String, - port: Int, - execId: String, - blockIds: Array[String], - listener: BlockFetchingListener, - tempFileManager: DownloadFileManager): Unit - /** * Upload a single block to a remote node, available only after [[init]] is invoked. 
*/ @@ -94,7 +66,7 @@ abstract class BlockTransferService extends BlockStoreClient with Logging { classTag: ClassTag[_]): Future[Unit] /** - * A special case of [[fetchBlocks]], as it fetches only one block and is blocking. + * A special case of [[fetchDataBlocks]], as it fetches only one block and is blocking. * * It is also only available after [[init]] is invoked. */ @@ -104,7 +76,7 @@ abstract class BlockTransferService extends BlockStoreClient with Logging { execId: String, blockId: String, tempFileManager: DownloadFileManager): ManagedBuffer = { - // Make sure ShuffleBlockId will not enter this function, so we call fetchBlocks with the + // Make sure ShuffleBlockId will not enter this function, so we call fetchShuffleBlocks with the // invalid shuffleGenerationId -1 here for this special case of fetching a non-shuffle block. assert(!BlockId.apply(blockId).isShuffle) // A monitor for the thread to wait on. diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index e44b20caac02e..d7c9e54f8a098 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -106,7 +106,7 @@ private[spark] class NettyBlockTransferService( } } - override def fetchBlocks( + override def fetchShuffleBlocks( host: String, port: Int, execId: String, @@ -173,7 +173,7 @@ private[spark] class NettyBlockTransferService( } } catch { case e: Exception => - logError("Exception while beginning fetchBlocks", e) + logError("Exception while beginning fetchShuffleBlocks", e) blockIds.foreach(listener.onBlockFetchFailure(_, e)) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 94ab4087f7157..c34a38f2bd4ab 100644 --- 
a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1101,6 +1101,8 @@ private[spark] class DAGScheduler( logDebug("submitMissingTasks(" + stage + ")") // Before find missing partition, do the intermediate state clean work first. + // The operation here can make sure for the intermediate stage, `findMissingPartitions()` + // returns all partitions every time. stage match { case sms: ShuffleMapStage if stage.isIndeterminate() => mapOutputTracker.unregisterAllMapOutput(sms.shuffleDep.shuffleId) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala index 522fc99e73a85..1b44d0aee3195 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala @@ -87,10 +87,7 @@ private[spark] class ShuffleMapStage( */ def isAvailable: Boolean = numAvailableOutputs == numPartitions - /** - * Returns the sequence of partition ids that are missing (i.e. needs to be computed). - * If the current stage is indeterminate, missing partition is all partitions every time. - */ + /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ override def findMissingPartitions(): Seq[Int] = { mapOutputTrackerMaster .findMissingPartitions(shuffleDep.shuffleId) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index af95cb82829bd..1859a2c718cd3 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -259,10 +259,10 @@ final class ShuffleBlockFetcherIterator( // already encrypted and compressed over the wire(w.r.t. 
the related configs), we can just fetch // the data and write it to file directly. if (req.size > maxReqSizeShuffleToMem) { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, + shuffleClient.fetchShuffleBlocks(address.host, address.port, address.executorId, shuffleGenerationId, blockIds.toArray, blockFetchingListener, this) } else { - shuffleClient.fetchBlocks(address.host, address.port, address.executorId, + shuffleClient.fetchShuffleBlocks(address.host, address.port, address.executorId, shuffleGenerationId, blockIds.toArray, blockFetchingListener, null) } } diff --git a/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala index 8eafaaa840f5f..a5528c49c6628 100644 --- a/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/BlockTransferServiceSuite.scala @@ -92,7 +92,7 @@ class BlockTransferServiceSuite extends SparkFunSuite with TimeLimits { throw new UnsupportedOperationException("uploadBlock") } - override def fetchBlocks( + override def fetchShuffleBlocks( host: String, port: Int, execId: String, @@ -101,7 +101,7 @@ class BlockTransferServiceSuite extends SparkFunSuite with TimeLimits { listener: BlockFetchingListener, tempFileManager: DownloadFileManager): Unit = { // This method is unused in this test - throw new UnsupportedOperationException("fetchBlocks") + throw new UnsupportedOperationException("fetchShuffleBlocks") } } diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index f014857342220..5180d3c864ee5 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -159,7 +159,7 @@ 
class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi val promise = Promise[ManagedBuffer]() - self.fetchBlocks(from.hostName, from.port, execId, -1, Array(blockId.toString), + self.fetchShuffleBlocks(from.hostName, from.port, execId, -1, Array(blockId.toString), new BlockFetchingListener { override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { promise.failure(exception) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 6dd75a435117c..644620f9f1e00 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1646,7 +1646,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE override def init(blockDataManager: BlockDataManager): Unit = {} - override def fetchBlocks( + override def fetchShuffleBlocks( host: String, port: Int, execId: String, diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 68082c59678fa..16f013d2ec321 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -49,7 +49,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT /** Creates a mock [[BlockTransferService]] that returns data from the given map. 
*/ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())).thenAnswer( + when(transfer.fetchShuffleBlocks(any(), any(), any(), any(), any(), any(), any())).thenAnswer( (invocation: InvocationOnMock) => { val blocks = invocation.getArguments()(4).asInstanceOf[Array[String]] val listener = invocation.getArguments()(5).asInstanceOf[BlockFetchingListener] @@ -146,9 +146,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } // 3 local blocks, and 2 remote blocks - // (but from the same block manager so one call to fetchBlocks) + // (but from the same block manager so one call to fetchShuffleBlocks) verify(blockManager, times(3)).getShuffleBlockData(any(), any(), any(), any()) - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), any()) + verify(transfer, times(1)).fetchShuffleBlocks(any(), any(), any(), any(), any(), any(), any()) } test("release current unexhausted buffer in case the task completes early") { @@ -167,7 +167,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) + when(transfer.fetchShuffleBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { val listener = invocation.getArguments()(5).asInstanceOf[BlockFetchingListener] Future { @@ -235,7 +235,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) + when(transfer.fetchShuffleBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer((invocation: 
InvocationOnMock) => { val listener = invocation.getArguments()(5).asInstanceOf[BlockFetchingListener] Future { @@ -324,7 +324,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) + when(transfer.fetchShuffleBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { val listener = invocation.getArguments()(5).asInstanceOf[BlockFetchingListener] Future { @@ -364,7 +364,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) + when(transfer.fetchShuffleBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { val listener = invocation.getArguments()(5).asInstanceOf[BlockFetchingListener] Future { @@ -519,7 +519,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) + when(transfer.fetchShuffleBlocks(any(), any(), any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { val listener = invocation.getArguments()(5).asInstanceOf[BlockFetchingListener] Future { @@ -583,7 +583,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) var tempFileManager: DownloadFileManager = null - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any())) + when(transfer.fetchShuffleBlocks(any(), any(), any(), any(), any(), 
any(), any())) .thenAnswer((invocation: InvocationOnMock) => { val listener = invocation.getArguments()(5).asInstanceOf[BlockFetchingListener] tempFileManager = invocation.getArguments()(6).asInstanceOf[DownloadFileManager] From 0e8ff21cfe0ca32dbd98d1e86de4ae6db2913dc1 Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Wed, 7 Aug 2019 22:05:47 +0800 Subject: [PATCH 7/9] fix style --- .../network/shuffle/ExternalShuffleIntegrationSuite.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 97cab117a99d3..906d44d8027fc 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -218,8 +218,8 @@ public void onBlockFetchFailure(String blockId, Throwable exception) { } }; if (isShuffleBlocks) { - client.fetchShuffleBlocks(TestUtils.getLocalHost(), port, execId, shuffleGenerationId, blockIds, - listener, null); + client.fetchShuffleBlocks( + TestUtils.getLocalHost(), port, execId, shuffleGenerationId, blockIds, listener, null); } else { client.fetchDataBlocks(TestUtils.getLocalHost(), port, execId, blockIds, listener, null); } From 4f9606cb3cfd8c46e67f6c7e56c7233bb5f1c8dc Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Fri, 9 Aug 2019 01:20:51 +0800 Subject: [PATCH 8/9] Forbid speculation for the indeterminate stage rerunning --- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 3 ++- .../src/main/scala/org/apache/spark/scheduler/TaskSet.scala | 3 ++- .../scala/org/apache/spark/scheduler/TaskSetManager.scala | 6 ++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c34a38f2bd4ab..fea559e0e0946 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1257,7 +1257,8 @@ private[spark] class DAGScheduler( logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " + s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})") taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties)) + tasks.toArray, stage.id, stage.latestInfo.attemptNumber, + jobId, properties, stage.isIndeterminate)) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark // the stage as completed here in case there are no tasks to run diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index 517c8991aed78..0b0bc84aa789c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -28,7 +28,8 @@ private[spark] class TaskSet( val stageId: Int, val stageAttemptId: Int, val priority: Int, - val properties: Properties) { + val properties: Properties, + val isIndeterminate: Boolean = false) { val id: String = stageId + "." 
+ stageAttemptId override def toString: String = "TaskSet " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 49bd55e553482..40491bda5e67d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -955,8 +955,10 @@ private[spark] class TaskSetManager( */ override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = { // Can't speculate if we only have one task, and no need to speculate if the task set is a - // zombie or is from a barrier stage. - if (isZombie || isBarrier || numTasks == 1) { + // zombie or is from a barrier stage. Also, for the scenario of indeterminate stage rerunning, + // speculation will cause correctness bugs, see more details in SPARK-23243. + if (isZombie || isBarrier || numTasks == 1 || + (taskSet.stageAttemptId > 0 && taskSet.isIndeterminate)) { return false } var foundTasks = false From 8873f111f5e1dcd6137396eeda06d7595ac2a72e Mon Sep 17 00:00:00 2001 From: Yuanjian Li Date: Fri, 9 Aug 2019 14:16:38 +0800 Subject: [PATCH 9/9] Do not initialize isIndeterminate in normal stage submit --- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 4 ++-- core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala | 2 +- .../scala/org/apache/spark/scheduler/TaskSetManager.scala | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index fea559e0e0946..cb55f6bebb52a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1257,8 +1257,8 @@ private[spark] class DAGScheduler( logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " + s"tasks are 
for partitions ${tasks.take(15).map(_.partitionId)})") taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptNumber, - jobId, properties, stage.isIndeterminate)) + tasks.toArray, stage.id, stage.latestInfo.attemptNumber, jobId, properties, + stage.latestInfo.attemptNumber > 0 && stage.isIndeterminate)) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark // the stage as completed here in case there are no tasks to run diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index 0b0bc84aa789c..080cbde7471aa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -29,7 +29,7 @@ private[spark] class TaskSet( val stageAttemptId: Int, val priority: Int, val properties: Properties, - val isIndeterminate: Boolean = false) { + val isIndeterminateRerun: Boolean = false) { val id: String = stageId + "." + stageAttemptId override def toString: String = "TaskSet " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 40491bda5e67d..d26c133f36cf2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -958,7 +958,7 @@ private[spark] class TaskSetManager( // zombie or is from a barrier stage. Also, for the scenario of indeterminate stage rerunning, // speculation will cause correctness bugs, see more details in SPARK-23243. if (isZombie || isBarrier || numTasks == 1 || - (taskSet.stageAttemptId > 0 && taskSet.isIndeterminate)) { + (taskSet.stageAttemptId > 0 && taskSet.isIndeterminateRerun)) { return false } var foundTasks = false