diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java index 736059fdd1f57..490915f6de4b3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java @@ -112,4 +112,27 @@ public static int[] decode(ByteBuf buf) { return ints; } } + + /** Long integer arrays are encoded with their length followed by long integers. */ + public static class LongArrays { + public static int encodedLength(long[] longs) { + return 4 + 8 * longs.length; + } + + public static void encode(ByteBuf buf, long[] longs) { + buf.writeInt(longs.length); + for (long i : longs) { + buf.writeLong(i); + } + } + + public static long[] decode(ByteBuf buf) { + int numLongs = buf.readInt(); + long[] longs = new long[numLongs]; + for (int i = 0; i < longs.length; i ++) { + longs[i] = buf.readLong(); + } + return longs; + } + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java index 037e5cf7e5222..2d7a72315cf23 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java @@ -106,7 +106,7 @@ protected void handleMessage( numBlockIds += ids.length; } streamId = streamManager.registerStream(client.getClientId(), - new ManagedBufferIterator(msg, numBlockIds), client.getChannel()); + new ShuffleManagedBufferIterator(msg), client.getChannel()); } else { // For the compatibility with the old version, still keep the support for OpenBlocks. OpenBlocks msg = (OpenBlocks) msgObj; @@ -299,21 +299,6 @@ private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) { return mapIdAndReduceIds; } - ManagedBufferIterator(FetchShuffleBlocks msg, int numBlockIds) { - final int[] mapIdAndReduceIds = new int[2 * numBlockIds]; - int idx = 0; - for (int i = 0; i < msg.mapIds.length; i++) { - for (int reduceId : msg.reduceIds[i]) { - mapIdAndReduceIds[idx++] = msg.mapIds[i]; - mapIdAndReduceIds[idx++] = reduceId; - } - } - assert(idx == 2 * numBlockIds); - size = mapIdAndReduceIds.length; - blockDataForIndexFn = index -> blockManager.getBlockData(msg.appId, msg.execId, - msg.shuffleId, mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]); - } - @Override public boolean hasNext() { return index < size; @@ -328,6 +313,49 @@ public ManagedBuffer next() { } } + private class ShuffleManagedBufferIterator implements Iterator { + + private int mapIdx = 0; + private int reduceIdx = 0; + + private final String appId; + private final String execId; + private final int shuffleId; + private final long[] mapIds; + private final int[][] reduceIds; + + ShuffleManagedBufferIterator(FetchShuffleBlocks msg) { + appId = msg.appId; + execId = msg.execId; + shuffleId = msg.shuffleId; + mapIds = msg.mapIds; + reduceIds = msg.reduceIds; + } + + @Override + public boolean hasNext() { + // mapIds.length must equal to reduceIds.length, and the passed in FetchShuffleBlocks + // must have non-empty mapIds and reduceIds, see the checking logic in + // OneForOneBlockFetcher. 
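Stepping back to the Encoders.LongArrays helper added above: the wire format is simply a 4-byte length prefix followed by 8 bytes per value. Below is a minimal round-trip sketch, not part of the patch, assuming Netty's Unpooled heap buffers (which network-common already uses) are on the classpath:

```scala
import io.netty.buffer.Unpooled
import org.apache.spark.network.protocol.Encoders

object LongArraysRoundTrip {
  def main(args: Array[String]): Unit = {
    // Map task ids can now exceed the Int range, hence the long[] encoder.
    val mapIds: Array[Long] = Array(3L, 7L, 1L << 33)
    // encodedLength = 4 (length prefix) + 8 * mapIds.length
    val buf = Unpooled.buffer(Encoders.LongArrays.encodedLength(mapIds))
    Encoders.LongArrays.encode(buf, mapIds)
    assert(Encoders.LongArrays.decode(buf).sameElements(mapIds))
  }
}
```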
+ assert(mapIds.length != 0 && mapIds.length == reduceIds.length); + return mapIdx < mapIds.length && reduceIdx < reduceIds[mapIdx].length; + } + + @Override + public ManagedBuffer next() { + final ManagedBuffer block = blockManager.getBlockData( + appId, execId, shuffleId, mapIds[mapIdx], reduceIds[mapIdx][reduceIdx]); + if (reduceIdx < reduceIds[mapIdx].length - 1) { + reduceIdx += 1; + } else { + reduceIdx = 0; + mapIdx += 1; + } + metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0); + return block; + } + } + @Override public void channelActive(TransportClient client) { metrics.activeConnections.inc(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 50f16fc700f12..8b0d1e145a813 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -172,7 +172,7 @@ public ManagedBuffer getBlockData( String appId, String execId, int shuffleId, - int mapId, + long mapId, int reduceId) { ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); if (executor == null) { @@ -296,7 +296,7 @@ private void deleteNonShuffleServiceServedFiles(String[] dirs) { * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId. */ private ManagedBuffer getSortBasedShuffleBlockData( - ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) { + ExecutorShuffleInfo executor, int shuffleId, long mapId, int reduceId) { File indexFile = ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, "shuffle_" + shuffleId + "_" + mapId + "_0.index"); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index cc11e92067375..52854c86be3e6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -24,6 +24,8 @@ import java.util.HashMap; import com.google.common.primitives.Ints; +import com.google.common.primitives.Longs; +import org.apache.commons.lang3.tuple.ImmutableTriple; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -111,21 +113,21 @@ private boolean isShuffleBlocks(String[] blockIds) { */ private FetchShuffleBlocks createFetchShuffleBlocksMsg( String appId, String execId, String[] blockIds) { - int shuffleId = splitBlockId(blockIds[0])[0]; - HashMap> mapIdToReduceIds = new HashMap<>(); + int shuffleId = splitBlockId(blockIds[0]).left; + HashMap> mapIdToReduceIds = new HashMap<>(); for (String blockId : blockIds) { - int[] blockIdParts = splitBlockId(blockId); - if (blockIdParts[0] != shuffleId) { + ImmutableTriple blockIdParts = splitBlockId(blockId); + if (blockIdParts.left != shuffleId) { throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockId); } - int mapId = blockIdParts[1]; + long mapId = blockIdParts.middle; if (!mapIdToReduceIds.containsKey(mapId)) { mapIdToReduceIds.put(mapId, new ArrayList<>()); } - mapIdToReduceIds.get(mapId).add(blockIdParts[2]); + mapIdToReduceIds.get(mapId).add(blockIdParts.right); } - int[] mapIds = 
Ints.toArray(mapIdToReduceIds.keySet()); + long[] mapIds = Longs.toArray(mapIdToReduceIds.keySet()); int[][] reduceIdArr = new int[mapIds.length][]; for (int i = 0; i < mapIds.length; i++) { reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i])); @@ -134,17 +136,16 @@ private FetchShuffleBlocks createFetchShuffleBlocksMsg( } /** Split the shuffleBlockId and return shuffleId, mapId and reduceId. */ - private int[] splitBlockId(String blockId) { + private ImmutableTriple splitBlockId(String blockId) { String[] blockIdParts = blockId.split("_"); if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) { throw new IllegalArgumentException( "Unexpected shuffle block id format: " + blockId); } - return new int[] { - Integer.parseInt(blockIdParts[1]), - Integer.parseInt(blockIdParts[2]), - Integer.parseInt(blockIdParts[3]) - }; + return new ImmutableTriple<>( + Integer.parseInt(blockIdParts[1]), + Long.parseLong(blockIdParts[2]), + Integer.parseInt(blockIdParts[3])); } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java index 466eeb3e048a8..faa960d414bcc 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java @@ -34,14 +34,14 @@ public class FetchShuffleBlocks extends BlockTransferMessage { public final int shuffleId; // The length of mapIds must equal to reduceIds.size(), for the i-th mapId in mapIds, // it corresponds to the i-th int[] in reduceIds, which contains all reduce id for this map id. 
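To make the mapIds/reduceIds layout concrete, here is a small self-contained sketch (the object name and the inline grouping are illustrative, not part of the patch) that mirrors what createFetchShuffleBlocksMsg now produces once map ids are parsed as longs:

```scala
import scala.collection.mutable

object FetchShuffleBlocksLayout {
  def main(args: Array[String]): Unit = {
    val blockIds = Seq("shuffle_0_100_0", "shuffle_0_100_1", "shuffle_0_8589934592_2")
    val grouped = mutable.LinkedHashMap.empty[Long, mutable.ArrayBuffer[Int]]
    for (id <- blockIds) {
      val parts = id.split("_") // "shuffle", shuffleId, mapId, reduceId
      grouped.getOrElseUpdate(parts(2).toLong, mutable.ArrayBuffer.empty) += parts(3).toInt
    }
    val mapIds: Array[Long] = grouped.keys.toArray                    // Array(100L, 8589934592L)
    val reduceIds: Array[Array[Int]] = mapIds.map(grouped(_).toArray) // Array(Array(0, 1), Array(2))
    assert(mapIds.length == reduceIds.length) // the invariant the block handler asserts on
  }
}
```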
- public final int[] mapIds; + public final long[] mapIds; public final int[][] reduceIds; public FetchShuffleBlocks( String appId, String execId, int shuffleId, - int[] mapIds, + long[] mapIds, int[][] reduceIds) { this.appId = appId; this.execId = execId; @@ -98,7 +98,7 @@ public int encodedLength() { return Encoders.Strings.encodedLength(appId) + Encoders.Strings.encodedLength(execId) + 4 /* encoded length of shuffleId */ - + Encoders.IntArrays.encodedLength(mapIds) + + Encoders.LongArrays.encodedLength(mapIds) + 4 /* encoded length of reduceIds.size() */ + encodedLengthOfReduceIds; } @@ -108,7 +108,7 @@ public void encode(ByteBuf buf) { Encoders.Strings.encode(buf, appId); Encoders.Strings.encode(buf, execId); buf.writeInt(shuffleId); - Encoders.IntArrays.encode(buf, mapIds); + Encoders.LongArrays.encode(buf, mapIds); buf.writeInt(reduceIds.length); for (int[] ids: reduceIds) { Encoders.IntArrays.encode(buf, ids); @@ -119,7 +119,7 @@ public static FetchShuffleBlocks decode(ByteBuf buf) { String appId = Encoders.Strings.decode(buf); String execId = Encoders.Strings.decode(buf); int shuffleId = buf.readInt(); - int[] mapIds = Encoders.IntArrays.decode(buf); + long[] mapIds = Encoders.LongArrays.decode(buf); int reduceIdsSize = buf.readInt(); int[][] reduceIds = new int[reduceIdsSize][]; for (int i = 0; i < reduceIdsSize; i++) { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index 649c471dc1679..ba40f4a45ac8f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -29,7 +29,7 @@ public class BlockTransferMessagesSuite { public void serializeOpenShuffleBlocks() { checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" })); checkSerializeDeserialize(new FetchShuffleBlocks( - "app-1", "exec-2", 0, new int[] {0, 1}, + "app-1", "exec-2", 0, new long[] {0, 1}, new int[][] {{ 0, 1 }, { 0, 1, 2 }})); checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"))); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 9c623a70424b6..6a5d04b6f417b 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -101,7 +101,7 @@ public void testFetchShuffleBlocks() { when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(blockMarkers[1]); FetchShuffleBlocks fetchShuffleBlocks = new FetchShuffleBlocks( - "app0", "exec1", 0, new int[] { 0 }, new int[][] {{ 0, 1 }}); + "app0", "exec1", 0, new long[] { 0 }, new int[][] {{ 0, 1 }}); checkOpenBlocksReceive(fetchShuffleBlocks, blockMarkers); verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 66633cc7a3595..26a11672b8068 
100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -64,7 +64,7 @@ public void testFetchOne() { BlockFetchingListener listener = fetchBlocks( blocks, blockIds, - new FetchShuffleBlocks("app-id", "exec-id", 0, new int[] { 0 }, new int[][] {{ 0 }}), + new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0 }}), conf); verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0")); @@ -100,7 +100,7 @@ public void testFetchThreeShuffleBlocks() { BlockFetchingListener listener = fetchBlocks( blocks, blockIds, - new FetchShuffleBlocks("app-id", "exec-id", 0, new int[] { 0 }, new int[][] {{ 0, 1, 2 }}), + new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0, 1, 2 }}), conf); for (int i = 0; i < 3; i ++) { diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java index 804119cd06fa6..d30f3dad3c940 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java @@ -42,17 +42,13 @@ public interface ShuffleExecutorComponents { * partitioned bytes written by that map task. * * @param shuffleId Unique identifier for the shuffle the map task is a part of - * @param mapId Within the shuffle, the identifier of the map task - * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task - * with the same (shuffleId, mapId) pair can be distinguished by the - * different values of mapTaskAttemptId. + * @param mapId An ID of the map task. The ID is unique within this Spark application. * @param numPartitions The number of partitions that will be written by the map task. Some of * these partitions may be empty. */ ShuffleMapOutputWriter createMapOutputWriter( int shuffleId, - int mapId, - long mapTaskAttemptId, + long mapId, int numPartitions) throws IOException; /** @@ -64,15 +60,11 @@ ShuffleMapOutputWriter createMapOutputWriter( * preserving an optimization in the local disk shuffle storage implementation. * * @param shuffleId Unique identifier for the shuffle the map task is a part of - * @param mapId Within the shuffle, the identifier of the map task - * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task - * with the same (shuffleId, mapId) pair can be distinguished by the - * different values of mapTaskAttemptId. + * @param mapId An ID of the map task. The ID is unique within this Spark application. */ default Optional createSingleFileMapOutputWriter( int shuffleId, - int mapId, - long mapTaskAttemptId) throws IOException { + long mapId) throws IOException { return Optional.empty(); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java index 7fac00b7fbc3f..21abe9a57cd25 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -39,7 +39,7 @@ public interface ShuffleMapOutputWriter { * for the same partition within any given map task. 
The partition identifier will be in the * range of precisely 0 (inclusive) to numPartitions (exclusive), where numPartitions was * provided upon the creation of this map output writer via - * {@link ShuffleExecutorComponents#createMapOutputWriter(int, int, long, int)}. + * {@link ShuffleExecutorComponents#createMapOutputWriter(int, long, int)}. *

* Calls to this method will be invoked with monotonically increasing reducePartitionIds; each * call to this method will be called with a reducePartitionId that is strictly greater than diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index f75e932860f90..dc157eaa3b253 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -85,8 +85,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final Partitioner partitioner; private final ShuffleWriteMetricsReporter writeMetrics; private final int shuffleId; - private final int mapId; - private final long mapTaskAttemptId; + private final long mapId; private final Serializer serializer; private final ShuffleExecutorComponents shuffleExecutorComponents; @@ -106,8 +105,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { BypassMergeSortShuffleWriter( BlockManager blockManager, BypassMergeSortShuffleHandle handle, - int mapId, - long mapTaskAttemptId, + long mapId, SparkConf conf, ShuffleWriteMetricsReporter writeMetrics, ShuffleExecutorComponents shuffleExecutorComponents) { @@ -117,7 +115,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.blockManager = blockManager; final ShuffleDependency dep = handle.dependency(); this.mapId = mapId; - this.mapTaskAttemptId = mapTaskAttemptId; this.shuffleId = dep.shuffleId(); this.partitioner = dep.partitioner(); this.numPartitions = partitioner.numPartitions(); @@ -130,11 +127,12 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { public void write(Iterator> records) throws IOException { assert (partitionWriters == null); ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents - .createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions); + .createMapOutputWriter(shuffleId, mapId, numPartitions); try { if (!records.hasNext()) { partitionLengths = mapOutputWriter.commitAllPartitions(); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, mapId); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -167,7 +165,8 @@ public void write(Iterator> records) throws IOException { } partitionLengths = writePartitionedData(mapOutputWriter); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, mapId); } catch (Exception e) { try { mapOutputWriter.abort(e); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 4d11abd36985e..d09282e61a9c7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -78,7 +78,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final ShuffleWriteMetricsReporter writeMetrics; private final ShuffleExecutorComponents shuffleExecutorComponents; private final int shuffleId; - private final int mapId; + private final long mapId; private final TaskContext taskContext; private final SparkConf sparkConf; private final 
boolean transferToEnabled; @@ -109,7 +109,7 @@ public UnsafeShuffleWriter( BlockManager blockManager, TaskMemoryManager memoryManager, SerializedShuffleHandle handle, - int mapId, + long mapId, TaskContext taskContext, SparkConf sparkConf, ShuffleWriteMetricsReporter writeMetrics, @@ -228,7 +228,8 @@ void closeAndWriteOutput() throws IOException { } } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), partitionLengths, mapId); } @VisibleForTesting @@ -264,16 +265,11 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException { long[] partitionLengths; if (spills.length == 0) { final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents - .createMapOutputWriter( - shuffleId, - mapId, - taskContext.taskAttemptId(), - partitioner.numPartitions()); + .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions()); return mapWriter.commitAllPartitions(); } else if (spills.length == 1) { Optional maybeSingleFileWriter = - shuffleExecutorComponents.createSingleFileMapOutputWriter( - shuffleId, mapId, taskContext.taskAttemptId()); + shuffleExecutorComponents.createSingleFileMapOutputWriter(shuffleId, mapId); if (maybeSingleFileWriter.isPresent()) { // Here, we don't need to perform any metrics updates because the bytes written to this // output file would have already been counted as shuffle bytes written. @@ -298,11 +294,7 @@ private long[] mergeSpillsUsingStandardWriter(SpillInfo[] spills) throws IOExcep CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); final ShuffleMapOutputWriter mapWriter = shuffleExecutorComponents - .createMapOutputWriter( - shuffleId, - mapId, - taskContext.taskAttemptId(), - partitioner.numPartitions()); + .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions()); try { // There are multiple spills to merge, so none of these spill files' lengths were counted // towards our shuffle write count or shuffle write time. 
If we use the slow merge path, diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java index 47aa2e39fe29b..a0c7d3c248d48 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java @@ -61,8 +61,7 @@ public void initializeExecutor(String appId, String execId) { @Override public ShuffleMapOutputWriter createMapOutputWriter( int shuffleId, - int mapId, - long mapTaskAttemptId, + long mapId, int numPartitions) { if (blockResolver == null) { throw new IllegalStateException( @@ -75,8 +74,7 @@ public ShuffleMapOutputWriter createMapOutputWriter( @Override public Optional createSingleFileMapOutputWriter( int shuffleId, - int mapId, - long mapTaskAttemptId) { + long mapId) { if (blockResolver == null) { throw new IllegalStateException( "Executor components must be initialized before getting writers."); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java index 444cdc4270ecd..a6529fd76188a 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -48,7 +48,7 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter { LoggerFactory.getLogger(LocalDiskShuffleMapOutputWriter.class); private final int shuffleId; - private final int mapId; + private final long mapId; private final IndexShuffleBlockResolver blockResolver; private final long[] partitionLengths; private final int bufferSize; @@ -64,7 +64,7 @@ public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter { public LocalDiskShuffleMapOutputWriter( int shuffleId, - int mapId, + long mapId, int numPartitions, IndexShuffleBlockResolver blockResolver, SparkConf sparkConf) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java index 6b0a797a61b52..c8b41992a8919 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskSingleSpillMapOutputWriter.java @@ -29,12 +29,12 @@ public class LocalDiskSingleSpillMapOutputWriter implements SingleSpillShuffleMapOutputWriter { private final int shuffleId; - private final int mapId; + private final long mapId; private final IndexShuffleBlockResolver blockResolver; public LocalDiskSingleSpillMapOutputWriter( int shuffleId, - int mapId, + long mapId, IndexShuffleBlockResolver blockResolver) { this.shuffleId = shuffleId; this.mapId = mapId; diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index fb051a8c0db8e..f0ac9acd90156 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -93,7 +93,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val shuffleId: Int = _rdd.context.newShuffleId() val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle( - shuffleId, 
_rdd.partitions.length, this) + shuffleId, this) _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index d878fc527791a..53329f0a937bd 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -41,7 +41,7 @@ import org.apache.spark.util._ * Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping for a single * ShuffleMapStage. * - * This class maintains a mapping from mapIds to `MapStatus`. It also maintains a cache of + * This class maintains a mapping from map index to `MapStatus`. It also maintains a cache of * serialized map statuses in order to speed up tasks' requests for map output statuses. * * All public methods of this class are thread-safe. @@ -88,12 +88,12 @@ private class ShuffleStatus(numPartitions: Int) { * Register a map output. If there is already a registered location for the map output then it * will be replaced by the new location. */ - def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized { - if (mapStatuses(mapId) == null) { + def addMapOutput(mapIndex: Int, status: MapStatus): Unit = synchronized { + if (mapStatuses(mapIndex) == null) { _numAvailableOutputs += 1 invalidateSerializedMapOutputStatusCache() } - mapStatuses(mapId) = status + mapStatuses(mapIndex) = status } /** @@ -101,10 +101,10 @@ private class ShuffleStatus(numPartitions: Int) { * This is a no-op if there is no registered map output or if the registered output is from a * different block manager. */ - def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized { - if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) { + def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = synchronized { + if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) { _numAvailableOutputs -= 1 - mapStatuses(mapId) = null + mapStatuses(mapIndex) = null invalidateSerializedMapOutputStatusCache() } } @@ -131,10 +131,10 @@ private class ShuffleStatus(numPartitions: Int) { * remove outputs which are served by an external shuffle server (if one exists). */ def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized { - for (mapId <- 0 until mapStatuses.length) { - if (mapStatuses(mapId) != null && f(mapStatuses(mapId).location)) { + for (mapIndex <- 0 until mapStatuses.length) { + if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) { _numAvailableOutputs -= 1 - mapStatuses(mapId) = null + mapStatuses(mapIndex) = null invalidateSerializedMapOutputStatusCache() } } @@ -282,8 +282,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging // For testing def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { - getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1) + : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1, useOldFetchProtocol = false) } /** @@ -292,11 +292,15 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * endPartition is excluded from the range). 
* * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, - * and the second item is a sequence of (shuffle block id, shuffle block size) tuples - * describing the shuffle blocks that are stored at that block manager. + * and the second item is a sequence of (shuffle block id, shuffle block size, map index) + * tuples describing the shuffle blocks that are stored at that block manager. */ - def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] + def getMapSizesByExecutorId( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + useOldFetchProtocol: Boolean) + : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] /** * Deletes map output status information for the specified shuffle stage. @@ -418,15 +422,15 @@ private[spark] class MapOutputTrackerMaster( } } - def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) { - shuffleStatuses(shuffleId).addMapOutput(mapId, status) + def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus) { + shuffleStatuses(shuffleId).addMapOutput(mapIndex, status) } /** Unregister map output information of the given shuffle, mapper and block manager */ - def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { + def unregisterMapOutput(shuffleId: Int, mapIndex: Int, bmAddress: BlockManagerId) { shuffleStatuses.get(shuffleId) match { case Some(shuffleStatus) => - shuffleStatus.removeMapOutput(mapId, bmAddress) + shuffleStatus.removeMapOutput(mapIndex, bmAddress) incrementEpoch() case None => throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID") @@ -645,13 +649,18 @@ private[spark] class MapOutputTrackerMaster( // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. - def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + def getMapSizesByExecutorId( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + useOldFetchProtocol: Boolean) + : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") shuffleStatuses.get(shuffleId) match { case Some (shuffleStatus) => shuffleStatus.withMapStatuses { statuses => - MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + MapOutputTracker.convertMapStatuses( + shuffleId, startPartition, endPartition, statuses, useOldFetchProtocol) } case None => Iterator.empty @@ -685,12 +694,17 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr private val fetchingLock = new KeyLock[Int] // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. 
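The extra Int in the returned tuples is the map index within the stage, which fetch-failure handling still needs now that the block id itself carries the (long) map task id. A shape-only sketch of a consumer (illustrative; blocksByAddress is just a parameter here, not a real tracker call):

```scala
import org.apache.spark.storage.{BlockId, BlockManagerId}

// Illustrative consumer of the (blockId, size, mapIndex) triples returned by
// MapOutputTracker.getMapSizesByExecutorId after this change.
object MapSizesShape {
  def totalBytes(blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Long = {
    var total = 0L
    for ((_, blocks) <- blocksByAddress; (_, size, _) <- blocks) {
      total += size
    }
    total
  }
}
```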
- override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) - : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + override def getMapSizesByExecutorId( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + useOldFetchProtocol: Boolean) + : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition") val statuses = getStatuses(shuffleId) try { - MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses) + MapOutputTracker.convertMapStatuses( + shuffleId, startPartition, endPartition, statuses, useOldFetchProtocol) } catch { case e: MetadataFetchFailedException => // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: @@ -832,19 +846,21 @@ private[spark] object MapOutputTracker extends Logging { * @param shuffleId Identifier for the shuffle * @param startPartition Start of map output partition ID range (included in range) * @param endPartition End of map output partition ID range (excluded from range) - * @param statuses List of map statuses, indexed by map ID. + * @param statuses List of map statuses, indexed by map partition index. + * @param useOldFetchProtocol Whether to use the old shuffle fetch protocol. * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, - * and the second item is a sequence of (shuffle block ID, shuffle block size) tuples - * describing the shuffle blocks that are stored at that block manager. + * and the second item is a sequence of (shuffle block id, shuffle block size, map index) + * tuples describing the shuffle blocks that are stored at that block manager. */ def convertMapStatuses( shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = { + statuses: Array[MapStatus], + useOldFetchProtocol: Boolean): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { assert (statuses != null) - val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]] - for ((status, mapId) <- statuses.iterator.zipWithIndex) { + val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]] + for ((status, mapIndex) <- statuses.iterator.zipWithIndex) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" logError(errorMessage) @@ -853,8 +869,15 @@ private[spark] object MapOutputTracker extends Logging { for (part <- startPartition until endPartition) { val size = status.getSizeForBlock(part) if (size != 0) { - splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += - ((ShuffleBlockId(shuffleId, mapId, part), size)) + if (useOldFetchProtocol) { + // While we use the old shuffle fetch protocol, we use mapIndex as mapId in the + // ShuffleBlockId. 
+ splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, mapIndex, part), size, mapIndex)) + } else { + splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) += + ((ShuffleBlockId(shuffleId, status.mapTaskId, part), size, mapIndex)) + } } } } diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 19f71a1dec296..b13028f868072 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -83,14 +83,15 @@ case object Resubmitted extends TaskFailedReason { case class FetchFailed( bmAddress: BlockManagerId, // Note that bmAddress can be null shuffleId: Int, - mapId: Int, + mapId: Long, + mapIndex: Int, reduceId: Int, message: String) extends TaskFailedReason { override def toErrorString: String = { val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString - s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " + - s"message=\n$message\n)" + s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapIndex=$mapIndex, " + + s"mapId=$mapId, reduceId=$reduceId, message=\n$message\n)" } /** diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 158a4b7cfa55a..2155dc6d3aa60 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1047,6 +1047,14 @@ package object config { .checkValue(v => v > 0, "The value should be a positive integer.") .createWithDefault(2000) + private[spark] val SHUFFLE_USE_OLD_FETCH_PROTOCOL = + ConfigBuilder("spark.shuffle.useOldFetchProtocol") + .doc("Whether to use the old protocol while doing the shuffle block fetching. " + + "It is only enabled while we need the compatibility in the scenario of new Spark " + + "version job fetching shuffle blocks from old version external shuffle service.") + .booleanConf + .createWithDefault(false) + private[spark] val MEMORY_MAP_LIMIT_FOR_TESTS = ConfigBuilder("spark.storage.memoryMapLimitForTests") .internal() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 894234f70e05a..c9101e983bef7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1103,7 +1103,16 @@ private[spark] class DAGScheduler( private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") - // First figure out the indexes of partition ids to compute. + // Before find missing partition, do the intermediate state clean work first. + // The operation here can make sure for the partially completed intermediate stage, + // `findMissingPartitions()` returns all partitions every time. + stage match { + case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => + mapOutputTracker.unregisterAllMapOutput(sms.shuffleDep.shuffleId) + case _ => + } + + // Figure out the indexes of partition ids to compute. val partitionsToCompute: Seq[Int] = stage.findMissingPartitions() // Use the scheduling pool, job group, description, etc. 
from an ActiveJob associated @@ -1498,7 +1507,7 @@ private[spark] class DAGScheduler( } } - case FetchFailed(bmAddress, shuffleId, mapId, _, failureMessage) => + case FetchFailed(bmAddress, shuffleId, _, mapIndex, _, failureMessage) => val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleIdToMapStage(shuffleId) @@ -1529,9 +1538,9 @@ private[spark] class DAGScheduler( // Mark all the map as broken in the map stage, to ensure retry all the tasks on // resubmitted stage attempt. mapOutputTracker.unregisterAllMapOutput(shuffleId) - } else if (mapId != -1) { + } else if (mapIndex != -1) { // Mark the map whose fetch failed as broken in the map stage - mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress) + mapOutputTracker.unregisterMapOutput(shuffleId, mapIndex, bmAddress) } if (failedStage.rdd.isBarrier()) { @@ -1573,7 +1582,7 @@ private[spark] class DAGScheduler( // Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is // guaranteed to be determinate, so the input data of the reducers will not change // even if the map tasks are re-tried. - if (mapStage.rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE) { + if (mapStage.isIndeterminate) { // It's a little tricky to find all the succeeding stages of `mapStage`, because // each stage only know its parents not children. Here we traverse the stages from // the leaf nodes (the result stages of active jobs), and rollback all the stages @@ -1601,15 +1610,22 @@ private[spark] class DAGScheduler( activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil)) + // The stages will be rolled back after checking + val rollingBackStages = HashSet[Stage](mapStage) stagesToRollback.foreach { case mapStage: ShuffleMapStage => val numMissingPartitions = mapStage.findMissingPartitions().length if (numMissingPartitions < mapStage.numTasks) { - // TODO: support to rollback shuffle files. - // Currently the shuffle writing is "first write wins", so we can't re-run a - // shuffle map stage and overwrite existing shuffle files. We have to finish - // SPARK-8029 first. - abortStage(mapStage, generateErrorMessage(mapStage), None) + if (sc.getConf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) { + val reason = "A shuffle map stage with indeterminate output was failed " + + "and retried. However, Spark can only do this while using the new " + + "shuffle block fetching protocol. Please check the config " + + "'spark.shuffle.useOldFetchProtocol', see more detail in " + + "SPARK-27665 and SPARK-25341." + abortStage(mapStage, reason, None) + } else { + rollingBackStages += mapStage + } } case resultStage: ResultStage if resultStage.activeJob.isDefined => @@ -1621,6 +1637,9 @@ private[spark] class DAGScheduler( case _ => } + logInfo(s"The shuffle map stage $mapStage with indeterminate output was failed, " + + s"we will roll back and rerun below stages which include itself and all its " + + s"indeterminate child stages: $rollingBackStages") } // We expect one executor failure to trigger many FetchFailures in rapid succession, diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 64f0a060a247c..c9d37c985d211 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -43,6 +43,11 @@ private[spark] sealed trait MapStatus { * necessary for correctness, since block fetchers are allowed to skip zero-size blocks. 
*/ def getSizeForBlock(reduceId: Int): Long + + /** + * The unique ID of this shuffle map task, we use taskContext.taskAttemptId to fill this. + */ + def mapTaskId: Long } @@ -56,11 +61,14 @@ private[spark] object MapStatus { .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get) - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { + def apply( + loc: BlockManagerId, + uncompressedSizes: Array[Long], + mapTaskId: Long): MapStatus = { if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { - HighlyCompressedMapStatus(loc, uncompressedSizes) + HighlyCompressedMapStatus(loc, uncompressedSizes, mapTaskId) } else { - new CompressedMapStatus(loc, uncompressedSizes) + new CompressedMapStatus(loc, uncompressedSizes, mapTaskId) } } @@ -100,16 +108,19 @@ private[spark] object MapStatus { * * @param loc location where the task is being executed. * @param compressedSizes size of the blocks, indexed by reduce partition id. + * @param _mapTaskId unique task id for the task */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, - private[this] var compressedSizes: Array[Byte]) + private[this] var compressedSizes: Array[Byte], + private[this] var _mapTaskId: Long) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + // For deserialization only + protected def this() = this(null, null.asInstanceOf[Array[Byte]], -1) - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, uncompressedSizes.map(MapStatus.compressSize)) + def this(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskId: Long) { + this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId) } override def location: BlockManagerId = loc @@ -118,10 +129,13 @@ private[spark] class CompressedMapStatus( MapStatus.decompressSize(compressedSizes(reduceId)) } + override def mapTaskId: Long = _mapTaskId + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) + out.writeLong(_mapTaskId) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -129,6 +143,7 @@ private[spark] class CompressedMapStatus( val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) + _mapTaskId = in.readLong() } } @@ -142,20 +157,23 @@ private[spark] class CompressedMapStatus( * @param emptyBlocks a bitmap tracking which blocks are empty * @param avgSize average size of the non-empty and non-huge blocks * @param hugeBlockSizes sizes of huge blocks by their reduceId. 
+ * @param _mapTaskId unique task id for the task */ private[spark] class HighlyCompressedMapStatus private ( private[this] var loc: BlockManagerId, private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte]) + private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte], + private[this] var _mapTaskId: Long) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization - require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, + require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 + || numNonEmptyBlocks == 0 || _mapTaskId > 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null) // For deserialization only + protected def this() = this(null, -1, null, -1, null, -1) // For deserialization only override def location: BlockManagerId = loc @@ -171,6 +189,8 @@ private[spark] class HighlyCompressedMapStatus private ( } } + override def mapTaskId: Long = _mapTaskId + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) emptyBlocks.writeExternal(out) @@ -180,6 +200,7 @@ private[spark] class HighlyCompressedMapStatus private ( out.writeInt(kv._1) out.writeByte(kv._2) } + out.writeLong(_mapTaskId) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -195,11 +216,15 @@ private[spark] class HighlyCompressedMapStatus private ( hugeBlockSizesImpl(block) = size } hugeBlockSizes = hugeBlockSizesImpl + _mapTaskId = in.readLong() } } private[spark] object HighlyCompressedMapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + def apply( + loc: BlockManagerId, + uncompressedSizes: Array[Long], + mapTaskId: Long): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. 
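Since MapStatus is private[spark], the following test-style sketch is placed in the org.apache.spark.scheduler package purely for illustration; it shows that both encodings now carry the mapper's task attempt id:

```scala
package org.apache.spark.scheduler

import org.apache.spark.storage.BlockManagerId

// Illustrative only: MapStatus.apply still chooses CompressedMapStatus vs.
// HighlyCompressedMapStatus by partition count, and now records the map task id.
object MapStatusSketch {
  def main(args: Array[String]): Unit = {
    val loc = BlockManagerId("exec-1", "host-1", 7337)
    val fewPartitions = MapStatus(loc, Array.fill(10)(100L), mapTaskId = 42L)
    val manyPartitions = MapStatus(loc, Array.fill(5000)(100L), mapTaskId = 42L)
    assert(fewPartitions.mapTaskId == 42L && manyPartitions.mapTaskId == 42L)
  }
}
```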
var i = 0 @@ -240,6 +265,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizes) + hugeBlockSizes, mapTaskId) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 710f5eb211dde..06e5d8ab0302a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -91,7 +91,7 @@ private[spark] class ShuffleMapTask( val rdd = rddAndDep._1 val dep = rddAndDep._2 - dep.shuffleWriterProcessor.write(rdd, dep, partitionId, context, partition) + dep.shuffleWriterProcessor.write(rdd, dep, context, partition) } override def preferredLocations: Seq[TaskLocation] = preferredLocs diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 26cca334d3bd5..a9f72eae71368 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.HashSet import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.util.CallSite /** @@ -116,4 +116,8 @@ private[scheduler] abstract class Stage( /** Returns the sequence of partition ids that are missing (i.e. needs to be computed). */ def findMissingPartitions(): Seq[Int] + + def isIndeterminate: Boolean = { + rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE + } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala index 04e4cf88d7063..6fe183c078089 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala @@ -24,6 +24,5 @@ import org.apache.spark.ShuffleDependency */ private[spark] class BaseShuffleHandle[K, V, C]( shuffleId: Int, - val numMaps: Int, val dependency: ShuffleDependency[K, V, C]) extends ShuffleHandle(shuffleId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 4329824b1b627..8a0e84d901c2f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -47,7 +47,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( context, blockManager.blockStoreClient, blockManager, - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition, + SparkEnv.get.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)), serializerManager.wrapStream, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 265a8acfa8d61..6509a04dc4893 100644 --- 
a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -35,7 +35,8 @@ import org.apache.spark.util.Utils private[spark] class FetchFailedException( bmAddress: BlockManagerId, shuffleId: Int, - mapId: Int, + mapId: Long, + mapIndex: Int, reduceId: Int, message: String, cause: Throwable = null) @@ -44,10 +45,11 @@ private[spark] class FetchFailedException( def this( bmAddress: BlockManagerId, shuffleId: Int, - mapId: Int, + mapTaskId: Long, + mapIndex: Int, reduceId: Int, cause: Throwable) { - this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) + this(bmAddress, shuffleId, mapTaskId, mapIndex, reduceId, cause.getMessage, cause) } // SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code @@ -56,8 +58,8 @@ private[spark] class FetchFailedException( // because the TaskContext is not defined in some test cases. Option(TaskContext.get()).map(_.setFetchFailed(this)) - def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, - Utils.exceptionString(this)) + def toTaskFailedReason: TaskFailedReason = FetchFailed( + bmAddress, shuffleId, mapId, mapIndex, reduceId, Utils.exceptionString(this)) } /** @@ -67,4 +69,4 @@ private[spark] class MetadataFetchFailedException( shuffleId: Int, reduceId: Int, message: String) - extends FetchFailedException(null, shuffleId, -1, reduceId, message) + extends FetchFailedException(null, shuffleId, -1L, -1, reduceId, message) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d3f1c7ec1bbee..332164a7be3e7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -51,18 +51,18 @@ private[spark] class IndexShuffleBlockResolver( private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") - def getDataFile(shuffleId: Int, mapId: Int): File = { + def getDataFile(shuffleId: Int, mapId: Long): File = { blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } - private def getIndexFile(shuffleId: Int, mapId: Int): File = { + private def getIndexFile(shuffleId: Int, mapId: Long): File = { blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) } /** * Remove data file and index file that contain the output data from one map. 
*/ - def removeDataByMap(shuffleId: Int, mapId: Int): Unit = { + def removeDataByMap(shuffleId: Int, mapId: Long): Unit = { var file = getDataFile(shuffleId, mapId) if (file.exists()) { if (!file.delete()) { @@ -135,7 +135,7 @@ private[spark] class IndexShuffleBlockResolver( */ def writeIndexFileAndCommit( shuffleId: Int, - mapId: Int, + mapId: Long, lengths: Array[Long], dataTmp: File): Unit = { val indexFile = getIndexFile(shuffleId, mapId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index 18a743fbfa6fc..a717ef242ea7c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -34,13 +34,12 @@ private[spark] trait ShuffleManager { */ def registerShuffle[K, V, C]( shuffleId: Int, - numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle /** Get a writer for a given partition. Called on executors by map tasks. */ def getWriter[K, V]( handle: ShuffleHandle, - mapId: Int, + mapId: Long, context: TaskContext, metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala index 5b0c7e9f2b0b4..f222200a7816c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala @@ -44,7 +44,6 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging { def write( rdd: RDD[_], dep: ShuffleDependency[_, _, _], - partitionId: Int, context: TaskContext, partition: Partition): MapStatus = { var writer: ShuffleWriter[Any, Any] = null @@ -52,7 +51,7 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging { val manager = SparkEnv.get.shuffleManager writer = manager.getWriter[Any, Any]( dep.shuffleHandle, - partitionId, + context.taskAttemptId(), context, createMetricsReporter(context)) writer.write( diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index cbdc2c886dd9f..d96bcb3d073df 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -24,6 +24,7 @@ import org.apache.spark.internal.{config, Logging} import org.apache.spark.shuffle._ import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleExecutorComponents} import org.apache.spark.util.Utils +import org.apache.spark.util.collection.OpenHashSet /** * In sort-based shuffle, incoming records are sorted according to their target partition ids, then @@ -79,9 +80,9 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager } /** - * A mapping from shuffle ids to the number of mappers producing output for those shuffles. + * A mapping from shuffle ids to the task ids of mappers producing output for those shuffles. 
*/ - private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + private[this] val taskIdMapsForShuffle = new ConcurrentHashMap[Int, OpenHashSet[Long]]() private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) @@ -92,7 +93,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager */ override def registerShuffle[K, V, C]( shuffleId: Int, - numMaps: Int, dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { if (SortShuffleWriter.shouldBypassMergeSort(conf, dependency)) { // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't @@ -101,14 +101,14 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager // together the spilled files, which would happen with the normal code path. The downside is // having multiple files open at a time and thus more memory allocated to buffers. new BypassMergeSortShuffleHandle[K, V]( - shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) { // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient: new SerializedShuffleHandle[K, V]( - shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) + shuffleId, dependency.asInstanceOf[ShuffleDependency[K, V, V]]) } else { // Otherwise, buffer map outputs in a deserialized form: - new BaseShuffleHandle(shuffleId, numMaps, dependency) + new BaseShuffleHandle(shuffleId, dependency) } } @@ -130,11 +130,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager /** Get a writer for a given partition. Called on executors by map tasks. */ override def getWriter[K, V]( handle: ShuffleHandle, - mapId: Int, + mapId: Long, context: TaskContext, metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { - numMapsForShuffle.putIfAbsent( - handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) + val mapTaskIds = taskIdMapsForShuffle.computeIfAbsent( + handle.shuffleId, _ => new OpenHashSet[Long](16)) + mapTaskIds.synchronized { mapTaskIds.add(context.taskAttemptId()) } val env = SparkEnv.get handle match { case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => @@ -152,7 +153,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager env.blockManager, bypassMergeSortHandle, mapId, - context.taskAttemptId(), env.conf, metrics, shuffleExecutorComponents) @@ -164,9 +164,9 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager /** Remove a shuffle's metadata from the ShuffleManager. 
*/ override def unregisterShuffle(shuffleId: Int): Boolean = { - Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps => - (0 until numMaps).foreach { mapId => - shuffleBlockResolver.removeDataByMap(shuffleId, mapId) + Option(taskIdMapsForShuffle.remove(shuffleId)).foreach { mapTaskIds => + mapTaskIds.iterator.foreach { mapTaskId => + shuffleBlockResolver.removeDataByMap(shuffleId, mapTaskId) } } true @@ -231,9 +231,8 @@ private[spark] object SortShuffleManager extends Logging { */ private[spark] class SerializedShuffleHandle[K, V]( shuffleId: Int, - numMaps: Int, dependency: ShuffleDependency[K, V, V]) - extends BaseShuffleHandle(shuffleId, numMaps, dependency) { + extends BaseShuffleHandle(shuffleId, dependency) { } /** @@ -242,7 +241,6 @@ private[spark] class SerializedShuffleHandle[K, V]( */ private[spark] class BypassMergeSortShuffleHandle[K, V]( shuffleId: Int, - numMaps: Int, dependency: ShuffleDependency[K, V, V]) - extends BaseShuffleHandle(shuffleId, numMaps, dependency) { + extends BaseShuffleHandle(shuffleId, dependency) { } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index a781b16252432..a391bdf2db44e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.collection.ExternalSorter private[spark] class SortShuffleWriter[K, V, C]( shuffleBlockResolver: IndexShuffleBlockResolver, handle: BaseShuffleHandle[K, V, C], - mapId: Int, + mapId: Long, context: TaskContext, shuffleExecutorComponents: ShuffleExecutorComponents) extends ShuffleWriter[K, V] with Logging { @@ -65,10 +65,10 @@ private[spark] class SortShuffleWriter[K, V, C]( // because it just opens a single file, so is typically too fast to measure accurately // (see SPARK-3570). val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter( - dep.shuffleId, mapId, context.taskAttemptId(), dep.partitioner.numPartitions) + dep.shuffleId, mapId, dep.partitioner.numPartitions) sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter) val partitionLengths = mapOutputWriter.commitAllPartitions() - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 7ac2c71c18eb3..9c5b7f64e7abe 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -52,17 +52,17 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { // Format of the shuffle block ids (including data and index) should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getBlockData(). 
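// Illustrative sketch (not part of the patch): the bookkeeping pattern behind the
// taskIdMapsForShuffle change above. Each shuffle id maps to the set of map task attempt ids that
// wrote output for it, so unregisterShuffle can delete exactly the files that were written instead
// of iterating 0 until numMaps. A plain scala.collection.mutable.HashSet stands in for Spark's
// OpenHashSet, and removeDataByMap is passed in as a function so the example stays self-contained.
import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable

object TaskIdTrackingSketch {
  private val taskIdsForShuffle = new ConcurrentHashMap[Int, mutable.HashSet[Long]]()

  /** On the getWriter path: remember which map task attempt produced output for this shuffle. */
  def recordWriter(shuffleId: Int, mapTaskAttemptId: Long): Unit = {
    val ids = taskIdsForShuffle.computeIfAbsent(shuffleId, _ => new mutable.HashSet[Long])
    ids.synchronized { ids += mapTaskAttemptId }
  }

  /** On the unregisterShuffle path: clean up the per-attempt output files that were recorded. */
  def unregister(shuffleId: Int)(removeDataByMap: (Int, Long) => Unit): Unit = {
    Option(taskIdsForShuffle.remove(shuffleId)).foreach { ids =>
      ids.foreach(removeDataByMap(shuffleId, _))
    }
  }
}
// Example: recordWriter(0, 12345L); unregister(0)((s, m) => println(s"delete shuffle_${s}_${m}_0.data"))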
@DeveloperApi -case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { +case class ShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } @DeveloperApi -case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { +case class ShuffleDataBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" } @DeveloperApi -case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { +case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" } @@ -117,11 +117,11 @@ object BlockId { case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => - ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) + ShuffleBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt) case SHUFFLE_DATA(shuffleId, mapId, reduceId) => - ShuffleDataBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) + ShuffleDataBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt) case SHUFFLE_INDEX(shuffleId, mapId, reduceId) => - ShuffleIndexBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) + ShuffleIndexBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt) case BROADCAST(broadcastId, field) => BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) case TASKRESULT(taskId) => diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 8a6c199423506..5fce358fae37f 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -49,9 +49,10 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} * @param shuffleClient [[BlockStoreClient]] for fetching remote blocks * @param blockManager [[BlockManager]] for reading local blocks * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. - * For each block we also require the size (in bytes as a long field) in - * order to throttle the memory usage. Note that zero-sized blocks are - * already excluded, which happened in + * For each block we also require two pieces of info: 1. the size (in bytes as a long + * field) in order to throttle the memory usage; 2. the mapIndex for this + * block, which indicates the index in the map stage. + * Note that zero-sized blocks are already excluded, which happened in * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. 
@@ -67,7 +68,7 @@ final class ShuffleBlockFetcherIterator( context: TaskContext, shuffleClient: BlockStoreClient, blockManager: BlockManager, - blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])], + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, @@ -97,7 +98,7 @@ final class ShuffleBlockFetcherIterator( private[this] val startTimeNs = System.nanoTime() /** Local blocks to fetch, excluding zero-sized blocks. */ - private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[BlockId]() + private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]() /** Remote blocks to fetch, excluding zero-sized blocks. */ private[this] val remoteBlocks = new HashSet[BlockId]() @@ -199,7 +200,7 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, address, _, buf, _) => + case SuccessFetchResult(_, _, address, _, buf, _) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) if (buf.isInstanceOf[FileSegmentManagedBuffer]) { @@ -224,10 +225,12 @@ final class ShuffleBlockFetcherIterator( bytesInFlight += req.size reqsInFlight += 1 - // so we can look up the size of each blockID - val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap - val remainingBlocks = new HashSet[String]() ++= sizeMap.keys - val blockIds = req.blocks.map(_._1.toString) + // so we can look up the block info of each blockID + val infoMap = req.blocks.map { + case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, (size, mapIndex)) + }.toMap + val remainingBlocks = new HashSet[String]() ++= infoMap.keys + val blockIds = req.blocks.map(_.blockId.toString) val address = req.address val blockFetchingListener = new BlockFetchingListener { @@ -240,8 +243,8 @@ final class ShuffleBlockFetcherIterator( // This needs to be released after use. buf.retain() remainingBlocks -= blockId - results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf, - remainingBlocks.isEmpty)) + results.put(new SuccessFetchResult(BlockId(blockId), infoMap(blockId)._2, + address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty)) logDebug("remainingBlocks: " + remainingBlocks) } } @@ -250,7 +253,7 @@ final class ShuffleBlockFetcherIterator( override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - results.put(new FailureFetchResult(BlockId(blockId), address, e)) + results.put(new FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e)) } } @@ -283,28 +286,28 @@ final class ShuffleBlockFetcherIterator( for ((address, blockInfos) <- blocksByAddress) { if (address.executorId == blockManager.blockManagerId.executorId) { blockInfos.find(_._2 <= 0) match { - case Some((blockId, size)) if size < 0 => + case Some((blockId, size, _)) if size < 0 => throw new BlockException(blockId, "Negative block size " + size) - case Some((blockId, size)) if size == 0 => + case Some((blockId, size, _)) if size == 0 => throw new BlockException(blockId, "Zero-sized blocks should be excluded.") case None => // do nothing. 
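// Illustrative sketch (not part of the patch): each element of blocksByAddress now carries a
// (blockId, size, mapIndex) triple, and remote blocks are wrapped into FetchBlockInfo and grouped
// into size-bounded fetch requests, as in the hunks above. BlockId and BlockManagerId are modelled
// as Strings and the grouping threshold is a plain parameter, so this compiles without Spark classes.
object FetchGroupingSketch {
  final case class FetchBlockInfoSketch(blockId: String, size: Long, mapIndex: Int)
  final case class FetchRequestSketch(address: String, blocks: Seq[FetchBlockInfoSketch]) {
    val size: Long = blocks.map(_.size).sum
  }

  /** Split one address' (blockId, size, mapIndex) triples into requests of roughly maxBytesPerRequest. */
  def group(
      address: String,
      blockInfos: Seq[(String, Long, Int)],
      maxBytesPerRequest: Long): Seq[FetchRequestSketch] = {
    val requests = Seq.newBuilder[FetchRequestSketch]
    var curBlocks = Vector.empty[FetchBlockInfoSketch]
    var curRequestSize = 0L
    for ((blockId, size, mapIndex) <- blockInfos) {
      require(size > 0, s"Zero-sized or negative blocks should be excluded: $blockId")
      curBlocks :+= FetchBlockInfoSketch(blockId, size, mapIndex)
      curRequestSize += size
      if (curRequestSize >= maxBytesPerRequest) {
        requests += FetchRequestSketch(address, curBlocks)
        curBlocks = Vector.empty
        curRequestSize = 0L
      }
    }
    if (curBlocks.nonEmpty) requests += FetchRequestSketch(address, curBlocks)
    requests.result()
  }
}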
} - localBlocks ++= blockInfos.map(_._1) + localBlocks ++= blockInfos.map(info => (info._1, info._3)) localBlockBytes += blockInfos.map(_._2).sum numBlocksToFetch += localBlocks.size } else { val iterator = blockInfos.iterator var curRequestSize = 0L - var curBlocks = new ArrayBuffer[(BlockId, Long)] + var curBlocks = new ArrayBuffer[FetchBlockInfo] while (iterator.hasNext) { - val (blockId, size) = iterator.next() + val (blockId, size, mapIndex) = iterator.next() remoteBlockBytes += size if (size < 0) { throw new BlockException(blockId, "Negative block size " + size) } else if (size == 0) { throw new BlockException(blockId, "Zero-sized blocks should be excluded.") } else { - curBlocks += ((blockId, size)) + curBlocks += FetchBlockInfo(blockId, size, mapIndex) remoteBlocks += blockId numBlocksToFetch += 1 curRequestSize += size @@ -315,7 +318,7 @@ final class ShuffleBlockFetcherIterator( remoteRequests += new FetchRequest(address, curBlocks) logDebug(s"Creating fetch request of $curRequestSize at $address " + s"with ${curBlocks.size} blocks") - curBlocks = new ArrayBuffer[(BlockId, Long)] + curBlocks = new ArrayBuffer[FetchBlockInfo] curRequestSize = 0 } } @@ -341,13 +344,13 @@ final class ShuffleBlockFetcherIterator( logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") val iter = localBlocks.iterator while (iter.hasNext) { - val blockId = iter.next() + val (blockId, mapIndex) = iter.next() try { val buf = blockManager.getBlockData(blockId) shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, + results.put(new SuccessFetchResult(blockId, mapIndex, blockManager.blockManagerId, buf.size(), buf, false)) } catch { // If we see an exception, stop immediately. @@ -360,7 +363,7 @@ final class ShuffleBlockFetcherIterator( logError("Error occurred while fetching local blocks, " + ce.getMessage) case ex: Exception => logError("Error occurred while fetching local blocks", ex) } - results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e)) + results.put(new FailureFetchResult(blockId, mapIndex, blockManager.blockManagerId, e)) return } } @@ -420,7 +423,7 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(fetchWaitTime) result match { - case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) => + case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) => if (address != blockManager.blockManagerId) { numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 shuffleMetrics.incRemoteBytesRead(buf.size) @@ -429,7 +432,7 @@ final class ShuffleBlockFetcherIterator( } shuffleMetrics.incRemoteBlocksFetched(1) } - if (!localBlocks.contains(blockId)) { + if (!localBlocks.contains((blockId, mapIndex))) { bytesInFlight -= size } if (isNetworkReqDone) { @@ -453,7 +456,7 @@ final class ShuffleBlockFetcherIterator( // since the last call. 
val msg = s"Received a zero-size buffer for block $blockId from $address " + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" - throwFetchFailedException(blockId, address, new IOException(msg)) + throwFetchFailedException(blockId, mapIndex, address, new IOException(msg)) } val in = try { @@ -469,7 +472,7 @@ final class ShuffleBlockFetcherIterator( case e: IOException => logError("Failed to create input stream from local block", e) } buf.release() - throwFetchFailedException(blockId, address, e) + throwFetchFailedException(blockId, mapIndex, address, e) } try { input = streamWrapper(blockId, in) @@ -487,11 +490,12 @@ final class ShuffleBlockFetcherIterator( buf.release() if (buf.isInstanceOf[FileSegmentManagedBuffer] || corruptedBlocks.contains(blockId)) { - throwFetchFailedException(blockId, address, e) + throwFetchFailedException(blockId, mapIndex, address, e) } else { logWarning(s"got an corrupted block $blockId from $address, fetch again", e) corruptedBlocks += blockId - fetchRequests += FetchRequest(address, Array((blockId, size))) + fetchRequests += FetchRequest( + address, Array(FetchBlockInfo(blockId, size, mapIndex))) result = null } } finally { @@ -503,8 +507,8 @@ final class ShuffleBlockFetcherIterator( } } - case FailureFetchResult(blockId, address, e) => - throwFetchFailedException(blockId, address, e) + case FailureFetchResult(blockId, mapIndex, address, e) => + throwFetchFailedException(blockId, mapIndex, address, e) } // Send fetch requests up to maxBytesInFlight @@ -517,6 +521,7 @@ final class ShuffleBlockFetcherIterator( input, this, currentResult.blockId, + currentResult.mapIndex, currentResult.address, detectCorrupt && streamCompressedOrEncrypted)) } @@ -583,11 +588,12 @@ final class ShuffleBlockFetcherIterator( private[storage] def throwFetchFailedException( blockId: BlockId, + mapIndex: Int, address: BlockManagerId, e: Throwable) = { blockId match { case ShuffleBlockId(shufId, mapId, reduceId) => - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) + throw new FetchFailedException(address, shufId, mapId, mapIndex, reduceId, e) case _ => throw new SparkException( "Failed to get block " + blockId + ", which is not a shuffle block", e) @@ -604,6 +610,7 @@ private class BufferReleasingInputStream( private[storage] val delegate: InputStream, private val iterator: ShuffleBlockFetcherIterator, private val blockId: BlockId, + private val mapIndex: Int, private val address: BlockManagerId, private val detectCorruption: Boolean) extends InputStream { @@ -615,7 +622,7 @@ private class BufferReleasingInputStream( } catch { case e: IOException if detectCorruption => IOUtils.closeQuietly(this) - iterator.throwFetchFailedException(blockId, address, e) + iterator.throwFetchFailedException(blockId, mapIndex, address, e) } } @@ -637,7 +644,7 @@ private class BufferReleasingInputStream( } catch { case e: IOException if detectCorruption => IOUtils.closeQuietly(this) - iterator.throwFetchFailedException(blockId, address, e) + iterator.throwFetchFailedException(blockId, mapIndex, address, e) } } @@ -649,7 +656,7 @@ private class BufferReleasingInputStream( } catch { case e: IOException if detectCorruption => IOUtils.closeQuietly(this) - iterator.throwFetchFailedException(blockId, address, e) + iterator.throwFetchFailedException(blockId, mapIndex, address, e) } } @@ -659,7 +666,7 @@ private class BufferReleasingInputStream( } catch { case e: IOException if detectCorruption => IOUtils.closeQuietly(this) - 
iterator.throwFetchFailedException(blockId, address, e) + iterator.throwFetchFailedException(blockId, mapIndex, address, e) } } @@ -690,14 +697,25 @@ private class ShuffleFetchCompletionListener(var data: ShuffleBlockFetcherIterat private[storage] object ShuffleBlockFetcherIterator { + /** + * The block information to fetch used in FetchRequest. + * @param blockId block id + * @param size estimated size of the block. Note that this is NOT the exact bytes. + * Size of remote block is used to calculate bytesInFlight. + * @param mapIndex the mapIndex for this block, which indicate the index in the map stage. + */ + private[storage] case class FetchBlockInfo( + blockId: BlockId, + size: Long, + mapIndex: Int) + /** * A request to fetch blocks from a remote BlockManager. * @param address remote BlockManager to fetch from. - * @param blocks Sequence of tuple, where the first element is the block id, - * and the second element is the estimated size, used to calculate bytesInFlight. + * @param blocks Sequence of the information for blocks to fetch from the same address. */ - case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) { - val size = blocks.map(_._2).sum + case class FetchRequest(address: BlockManagerId, blocks: Seq[FetchBlockInfo]) { + val size = blocks.map(_.size).sum } /** @@ -711,6 +729,7 @@ object ShuffleBlockFetcherIterator { /** * Result of a fetch from a remote block successfully. * @param blockId block id + * @param mapIndex the mapIndex for this block, which indicate the index in the map stage. * @param address BlockManager that the block was fetched from. * @param size estimated size of the block. Note that this is NOT the exact bytes. * Size of remote block is used to calculate bytesInFlight. @@ -719,6 +738,7 @@ object ShuffleBlockFetcherIterator { */ private[storage] case class SuccessFetchResult( blockId: BlockId, + mapIndex: Int, address: BlockManagerId, size: Long, buf: ManagedBuffer, @@ -730,11 +750,13 @@ object ShuffleBlockFetcherIterator { /** * Result of a fetch from a remote block unsuccessfully. 
* @param blockId block id + * @param mapIndex the mapIndex for this block, which indicate the index in the map stage * @param address BlockManager that the block was attempted to be fetched from * @param e the failure exception */ private[storage] case class FailureFetchResult( blockId: BlockId, + mapIndex: Int, address: BlockManagerId, e: Throwable) extends FetchResult diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 73ef80980e73f..353590d201bc5 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -420,6 +420,7 @@ private[spark] object JsonProtocol { ("Block Manager Address" -> blockManagerAddress) ~ ("Shuffle ID" -> fetchFailed.shuffleId) ~ ("Map ID" -> fetchFailed.mapId) ~ + ("Map Index" -> fetchFailed.mapIndex) ~ ("Reduce ID" -> fetchFailed.reduceId) ~ ("Message" -> fetchFailed.message) case exceptionFailure: ExceptionFailure => @@ -974,10 +975,11 @@ private[spark] object JsonProtocol { case `fetchFailed` => val blockManagerAddress = blockManagerIdFromJson(json \ "Block Manager Address") val shuffleId = (json \ "Shuffle ID").extract[Int] - val mapId = (json \ "Map ID").extract[Int] + val mapId = (json \ "Map ID").extract[Long] + val mapIndex = (json \ "Map Index").extract[Int] val reduceId = (json \ "Reduce ID").extract[Int] val message = jsonOption(json \ "Message").map(_.extract[String]) - new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId, + new FetchFailed(blockManagerAddress, shuffleId, mapId, mapIndex, reduceId, message.getOrElse("Unknown reason")) case `exceptionFailure` => val className = (json \ "Class Name").extract[String] diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 7a822e137e556..6fecfbaca8416 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -727,7 +727,7 @@ private[spark] class ExternalSorter[K, V, C]( */ def writePartitionedMapOutput( shuffleId: Int, - mapId: Int, + mapId: Long, mapOutputWriter: ShuffleMapOutputWriter): Unit = { var nextPartitionId = 0 if (spills.isEmpty) { diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 1022111897a49..a901ae62e8cd8 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -135,7 +135,7 @@ public void setUp() throws IOException { ); }); - when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile); + when(shuffleBlockResolver.getDataFile(anyInt(), anyLong())).thenReturn(mergedOutputFile); Answer renameTempAnswer = invocationOnMock -> { partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2]; @@ -153,11 +153,11 @@ public void setUp() throws IOException { doAnswer(renameTempAnswer) .when(shuffleBlockResolver) - .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class)); + .writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), any(File.class)); doAnswer(renameTempAnswer) .when(shuffleBlockResolver) - .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), eq(null)); + 
.writeIndexFileAndCommit(anyInt(), anyLong(), any(long[].class), eq(null)); when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> { TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID()); @@ -176,9 +176,9 @@ private UnsafeShuffleWriter createWriter(boolean transferToEnabled) { conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); return new UnsafeShuffleWriter( blockManager, - taskMemoryManager, - new SerializedShuffleHandle<>(0, 1, shuffleDep), - 0, // map id + taskMemoryManager, + new SerializedShuffleHandle<>(0, shuffleDep), + 0L, // map id taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), @@ -536,8 +536,8 @@ public void testPeakMemoryUsed() throws Exception { final UnsafeShuffleWriter writer = new UnsafeShuffleWriter( blockManager, taskMemoryManager, - new SerializedShuffleHandle<>(0, 1, shuffleDep), - 0, // map id + new SerializedShuffleHandle<>(0, shuffleDep), + 0L, // map id taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics(), diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index e7eef8ec5150c..8433a6f52ac7a 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -142,6 +142,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { sid, taskContext.partitionId(), taskContext.partitionId(), + taskContext.partitionId(), "simulated fetch failure") } else { iter diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index d86975964b558..da2ba2165bb0c 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -64,14 +64,15 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(1000L, 10000L))) + Array(1000L, 10000L), 5)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(10000L, 1000L))) + Array(10000L, 1000L), 6)) val statuses = tracker.getMapSizesByExecutorId(10, 0) assert(statuses.toSet === - Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), - (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) - .toSet) + Seq((BlockManagerId("a", "hostA", 1000), + ArrayBuffer((ShuffleBlockId(10, 5, 0), size1000, 0))), + (BlockManagerId("b", "hostB", 1000), + ArrayBuffer((ShuffleBlockId(10, 6, 0), size10000, 1)))).toSet) assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.stop() rpcEnv.shutdown() @@ -86,9 +87,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000))) + Array(compressedSize1000, compressedSize10000), 5)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000), 6)) 
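// Illustrative sketch (not part of the patch): the tests above now expect shuffle block names built
// from the map task attempt id, e.g. ShuffleBlockId(10, 5, 0) for the output registered with
// MapStatus(..., 5). The middle component of "shuffle_<shuffleId>_<mapId>_<reduceId>" is therefore
// a Long and must be parsed with toLong, matching the BlockId change earlier in this diff. The case
// class and regex below are simplified stand-ins, not Spark's BlockId hierarchy.
object ShuffleBlockIdSketch {
  final case class Id(shuffleId: Int, mapId: Long, reduceId: Int) {
    def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
  }

  private val ShufflePattern = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r

  def parse(name: String): Option[Id] = name match {
    case ShufflePattern(shuffleId, mapId, reduceId) =>
      Some(Id(shuffleId.toInt, mapId.toLong, reduceId.toInt))
    case _ => None
  }
}
// Id(10, 5L, 0).name == "shuffle_10_5_0"
// parse("shuffle_10_123456789012_0").map(_.mapId) == Some(123456789012L), a value that no longer
// fits in an Int, which is exactly why the mapId field became a Long.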
assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -109,9 +110,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize1000, compressedSize1000))) + Array(compressedSize1000, compressedSize1000, compressedSize1000), 5)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000, compressedSize1000), 6)) assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures @@ -147,10 +148,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L))) + BlockManagerId("a", "hostA", 1000), Array(1000L), 5)) slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) + Seq((BlockManagerId("a", "hostA", 1000), + ArrayBuffer((ShuffleBlockId(10, 5, 0), size1000, 0))))) assert(0 == masterTracker.getNumCachedSerializedBroadcast) val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch @@ -184,7 +186,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // Message size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) + BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 5)) val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) @@ -218,11 +220,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { // on hostB with output size 3 tracker.registerShuffle(10, 3) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), 5)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), 6)) tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(3L))) + Array(3L), 7)) // When the threshold is 50%, only host A should be returned as a preferred location // as it has 4 out of 7 bytes of output. 
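// Illustrative sketch (not part of the patch): the assertion above exercises the tracker's
// preferred-location heuristic, where a host qualifies if it holds at least a given fraction of a
// reducer's total map output bytes. With hostA holding 2 + 2 bytes and hostB holding 3 bytes for
// the reduce partition, only hostA reaches the 50% threshold (4 of 7 bytes). A simplified
// restatement of that rule, not the tracker's actual implementation:
object PreferredLocationSketch {
  final case class MapOutput(host: String, bytesForReducer: Long)

  def preferredHosts(outputs: Seq[MapOutput], fractionThreshold: Double): Set[String] = {
    val total = outputs.map(_.bytesForReducer).sum.toDouble
    outputs
      .groupBy(_.host)
      .map { case (host, perHost) => host -> perHost.map(_.bytesForReducer).sum }
      .collect { case (host, bytes) if bytes >= fractionThreshold * total => host }
      .toSet
  }

  // preferredHosts(Seq(MapOutput("hostA", 2), MapOutput("hostA", 2), MapOutput("hostB", 3)), 0.5)
  // returns Set("hostA"): 4 of the 7 bytes live on hostA, while hostB only holds 3.
}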
@@ -262,7 +264,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5)) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) @@ -311,16 +313,18 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(size0, size1000, size0, size10000))) + Array(size0, size1000, size0, size10000), 5)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(size10000, size0, size1000, size0))) + Array(size10000, size0, size1000, size0), 6)) assert(tracker.containsShuffle(10)) - assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === + assert(tracker.getMapSizesByExecutorId(10, 0, 4, false).toSeq === Seq( (BlockManagerId("a", "hostA", 1000), - Seq((ShuffleBlockId(10, 0, 1), size1000), (ShuffleBlockId(10, 0, 3), size10000))), + Seq((ShuffleBlockId(10, 5, 1), size1000, 0), + (ShuffleBlockId(10, 5, 3), size10000, 0))), (BlockManagerId("b", "hostB", 1000), - Seq((ShuffleBlockId(10, 1, 0), size10000), (ShuffleBlockId(10, 1, 2), size1000))) + Seq((ShuffleBlockId(10, 6, 0), size10000, 1), + (ShuffleBlockId(10, 6, 2), size1000, 1))) ) ) diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 923c9c90447fd..c75b56315547c 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -360,7 +360,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val metricsSystem = sc.env.metricsSystem val shuffleMapRdd = new MyRDD(sc, 1, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1)) - val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep) + val shuffleHandle = manager.registerShuffle(0, shuffleDep) mapTrackerMaster.registerShuffle(0, 1) // first attempt -- its successful diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 786f55c96a3e8..ac54e5ef10fe9 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -705,7 +705,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu if (context.stageAttemptNumber == 0) { if (context.partitionId == 0) { // Make the first task in the first stage attempt fail. 
- throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 0, 0, 0, + throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 0, 0L, 0, 0, new java.io.IOException("fake")) } else { // Make the second task in the first stage attempt sleep to generate a zombie task diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index ac7e4b51ebc2b..3faab52d6510c 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -528,7 +528,8 @@ class FetchFailureThrowingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) { throw new FetchFailedException( bmAddress = BlockManagerId("1", "hostA", 1234), shuffleId = 0, - mapId = 0, + mapId = 0L, + mapIndex = 0, reduceId = 0, message = "fake fetch failure" ) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 7cb7eceec615b..f6c0bf61f6d9e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -487,18 +487,22 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // map stage1 completes successfully, with one task on each executor complete(taskSets(0), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus( + BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), mapTaskId = 5)), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), - (Success, makeMapStatus("hostB", 1)) + MapStatus( + BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), mapTaskId = 6)), + (Success, makeMapStatus("hostB", 1, mapTaskId = 7)) )) // map stage2 completes successfully, with one task on each executor complete(taskSets(1), Seq( (Success, - MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))), + MapStatus( + BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2), mapTaskId = 8)), (Success, - MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))), - (Success, makeMapStatus("hostB", 1)) + MapStatus( + BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2), mapTaskId = 9)), + (Success, makeMapStatus("hostB", 1, mapTaskId = 10)) )) // make sure our test setup is correct val initialMapStatus1 = mapOutputTracker.shuffleStatuses(firstShuffleId).mapStatuses @@ -506,16 +510,19 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(initialMapStatus1.count(_ != null) === 3) assert(initialMapStatus1.map{_.location.executorId}.toSet === Set("exec-hostA1", "exec-hostA2", "exec-hostB")) + assert(initialMapStatus1.map{_.mapTaskId}.toSet === Set(5, 6, 7)) val initialMapStatus2 = mapOutputTracker.shuffleStatuses(secondShuffleId).mapStatuses // val initialMapStatus1 = mapOutputTracker.mapStatuses.get(0).get assert(initialMapStatus2.count(_ != null) === 3) assert(initialMapStatus2.map{_.location.executorId}.toSet === Set("exec-hostA1", "exec-hostA2", "exec-hostB")) + assert(initialMapStatus2.map{_.mapTaskId}.toSet === Set(8, 9, 10)) // reduce stage fails with a fetch failure from one host complete(taskSets(2), Seq( - (FetchFailed(BlockManagerId("exec-hostA2", "hostA", 12345), firstShuffleId, 0, 0, 
"ignored"), + (FetchFailed(BlockManagerId("exec-hostA2", "hostA", 12345), + firstShuffleId, 0L, 0, 0, "ignored"), null) )) @@ -757,7 +764,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // the 2nd ResultTask failed complete(taskSets(1), Seq( (Success, 42), - (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"), null))) // this will get called // blockManagerMaster.removeExecutor("exec-hostA") // ask the scheduler to try it again @@ -904,7 +911,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi val stageAttempt = taskSets.last checkStageId(stageId, attemptIdx, stageAttempt) complete(stageAttempt, stageAttempt.tasks.zipWithIndex.map { case (task, idx) => - (FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0, idx, "ignored"), null) + (FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0L, 0, idx, "ignored"), null) }.toSeq) } @@ -1137,14 +1144,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(makeCompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"), null)) assert(sparkListener.failedStages.contains(1)) // The second ResultTask fails, with a fetch failure for the output from the second mapper. runEvent(makeCompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1L, 1, 1, "ignored"), null)) // The SparkListener should not receive redundant failure events. assert(sparkListener.failedStages.size === 1) @@ -1164,7 +1171,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(makeCompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"), null)) assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq(0, 1))) @@ -1266,7 +1273,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(makeCompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"), null)) assert(sparkListener.failedStages.contains(1)) @@ -1279,7 +1286,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The second ResultTask fails, with a fetch failure for the output from the second mapper. runEvent(makeCompletionEvent( taskSets(1).tasks(1), - FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"), + FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1L, 1, 1, "ignored"), null)) // Another ResubmitFailedStages event should not result in another attempt for the map @@ -1325,7 +1332,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The first result task fails, with a fetch failure for the output from the first mapper. 
runEvent(makeCompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"), null)) // Trigger resubmission of the failed map stage and finish the re-started map task. @@ -1340,7 +1347,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // A late FetchFailed arrives from the second task in the original reduce stage. runEvent(makeCompletionEvent( taskSets(1).tasks(1), - FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1, 1, "ignored"), + FetchFailed(makeBlockManagerId("hostB"), shuffleId, 1L, 1, 1, "ignored"), null)) // Running ResubmitFailedStages shouldn't result in any more attempts for the map stage, because @@ -1535,7 +1542,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi runEvent(ExecutorLost("exec-hostA", ExecutorKilled)) runEvent(makeCompletionEvent( taskSets(1).tasks(0), - FetchFailed(null, firstShuffleId, 2, 0, "Fetch failed"), + FetchFailed(null, firstShuffleId, 2L, 2, 0, "Fetch failed"), null)) // so we resubmit stage 0, which completes happily @@ -1794,7 +1801,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // lets say there is a fetch failure in this task set, which makes us go back and // run stage 0, attempt 1 complete(taskSets(1), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDep1.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeBlockManagerId("hostA"), + shuffleDep1.shuffleId, 0L, 0, 0, "ignored"), null))) scheduler.resubmitFailedStages() // stage 0, attempt 1 should have the properties of job2 @@ -1875,7 +1883,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostC", 1)))) // fail the third stage because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeBlockManagerId("hostA"), + shuffleDepTwo.shuffleId, 0L, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // have DAGScheduler try again @@ -1906,7 +1915,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi (Success, makeMapStatus("hostB", 1)))) // pretend stage 2 failed because hostA went down complete(taskSets(2), Seq( - (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeBlockManagerId("hostA"), + shuffleDepTwo.shuffleId, 0L, 0, 0, "ignored"), null))) // TODO assert this: // blockManagerMaster.removeExecutor("exec-hostA") // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun. 
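// Illustrative sketch (not part of the patch): throughout these test updates, FetchFailed (and
// FetchFailedException) gains an extra argument because a fetch failure now carries two different
// identifiers for the lost map output: mapId, the Long task attempt id that names the shuffle files
// (shuffle_<shuffleId>_<mapId>_*), and mapIndex, the Int position of the map partition within the
// stage, which the scheduler uses to work out what to recompute. A simplified model of that
// payload, not Spark's actual TaskEndReason class:
final case class FetchFailedSketch(
    host: String,
    shuffleId: Int,
    mapId: Long,   // task attempt id of the mapper whose output is missing; names the block files
    mapIndex: Int, // index of that map partition within the stage; identifies what must be rerun
    reduceId: Int,
    message: String)
// Example mirroring the tests:
// FetchFailedSketch("hostA", shuffleId = 0, mapId = 0L, mapIndex = 0, reduceId = 0, message = "ignored")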
@@ -2267,7 +2277,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi submit(reduceRdd, Array(0, 1)) complete(taskSets(1), Seq( (Success, 42), - (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"), null))) // Ask the scheduler to try it again; TaskSet 2 will rerun the map task that we couldn't fetch // from, then TaskSet 3 will run the reduce stage scheduler.resubmitFailedStages() @@ -2326,7 +2336,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(taskSets(1).stageId === 1) complete(taskSets(1), Seq( (Success, makeMapStatus("hostA", rdd2.partitions.length)), - (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0L, 0, 0, "ignored"), null))) scheduler.resubmitFailedStages() assert(listener2.results.size === 0) // Second stage listener should not have a result yet @@ -2352,7 +2362,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(taskSets(4).stageId === 2) complete(taskSets(4), Seq( (Success, 52), - (FetchFailed(makeBlockManagerId("hostD"), dep2.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeBlockManagerId("hostD"), dep2.shuffleId, 0L, 0, 0, "ignored"), null))) scheduler.resubmitFailedStages() // TaskSet 5 will rerun stage 1's lost task, then TaskSet 6 will rerun stage 2 @@ -2390,7 +2400,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(taskSets(1).stageId === 1) complete(taskSets(1), Seq( (Success, makeMapStatus("hostC", rdd2.partitions.length)), - (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0, 0, "ignored"), null))) + (FetchFailed(makeBlockManagerId("hostA"), dep1.shuffleId, 0L, 0, 0, "ignored"), null))) scheduler.resubmitFailedStages() // Stage1 listener should not have a result yet assert(listener2.results.size === 0) @@ -2525,7 +2535,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi rdd1.map { case (x, _) if (x == 1) => throw new FetchFailedException( - BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test") + BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0L, 0, 0, "test") case (x, _) => x }.count() } @@ -2538,7 +2548,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi rdd1.map { case (x, _) if (x == 1) && FailThisAttempt._fail.getAndSet(false) => throw new FetchFailedException( - BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0, 0, "test") + BlockManagerId("1", "1", 1), shuffleHandle.shuffleId, 0L, 0, 0, "test") } } @@ -2592,7 +2602,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(taskSets(1).stageId === 1 && taskSets(1).stageAttemptId === 0) runEvent(makeCompletionEvent( taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleIdA, 0, 0, + FetchFailed(makeBlockManagerId("hostA"), shuffleIdA, 0L, 0, 0, "Fetch failure of task: stageId=1, stageAttempt=0, partitionId=0"), result = null)) @@ -2745,7 +2755,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(countSubmittedMapStageAttempts() === 2) } - test("SPARK-23207: retry all the succeeding stages when the map stage is indeterminate") { + private def constructIndeterminateStageFetchFailed(): (Int, Int) = { val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true) 
val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) @@ -2773,14 +2783,140 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The first task of the final stage failed with fetch failure runEvent(makeCompletionEvent( taskSets(2).tasks(0), - FetchFailed(makeBlockManagerId("hostC"), shuffleId2, 0, 0, "ignored"), + FetchFailed(makeBlockManagerId("hostC"), shuffleId2, 0L, 0, 0, "ignored"), null)) + (shuffleId1, shuffleId2) + } + + test("SPARK-25341: abort stage while using old fetch protocol") { + // reset the test context with using old fetch protocol + afterEach() + val conf = new SparkConf() + conf.set(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL.key, "true") + init(conf) + // Construct the scenario of indeterminate stage fetch failed. + constructIndeterminateStageFetchFailed() + // The job should fail because Spark can't rollback the shuffle map stage while + // using old protocol. + assert(failure != null && failure.getMessage.contains( + "Spark can only do this while using the new shuffle block fetching protocol")) + } + + test("SPARK-25341: retry all the succeeding stages when the map stage is indeterminate") { + val (shuffleId1, shuffleId2) = constructIndeterminateStageFetchFailed() + + // Check status for all failedStages + val failedStages = scheduler.failedStages.toSeq + assert(failedStages.map(_.id) == Seq(1, 2)) + // Shuffle blocks of "hostC" is lost, so first task of the `shuffleMapRdd2` needs to retry. + assert(failedStages.collect { + case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 => stage + }.head.findMissingPartitions() == Seq(0)) + // The result stage is still waiting for its 2 tasks to complete + assert(failedStages.collect { + case stage: ResultStage => stage + }.head.findMissingPartitions() == Seq(0, 1)) + + scheduler.resubmitFailedStages() + + // The first task of the `shuffleMapRdd2` failed with fetch failure + runEvent(makeCompletionEvent( + taskSets(3).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, "ignored"), + null)) + + val newFailedStages = scheduler.failedStages.toSeq + assert(newFailedStages.map(_.id) == Seq(0, 1)) + + scheduler.resubmitFailedStages() + + // First shuffle map stage resubmitted and reran all tasks. + assert(taskSets(4).stageId == 0) + assert(taskSets(4).stageAttemptId == 1) + assert(taskSets(4).tasks.length == 2) + + // Finish all stage. + complete(taskSets(4), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) + + complete(taskSets(5), Seq( + (Success, makeMapStatus("hostC", 2)), + (Success, makeMapStatus("hostD", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty)) + + complete(taskSets(6), Seq((Success, 11), (Success, 12))) + + // Job successful ended. + assert(results === Map(0 -> 11, 1 -> 12)) + results.clear() + assertDataStructuresEmpty() + } + + test("SPARK-25341: continuous indeterminate stage roll back") { + // shuffleMapRdd1/2/3 are all indeterminate. 
+ val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true) + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) + val shuffleId1 = shuffleDep1.shuffleId + + val shuffleMapRdd2 = new MyRDD( + sc, 2, List(shuffleDep1), tracker = mapOutputTracker, indeterminate = true) + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2)) + val shuffleId2 = shuffleDep2.shuffleId + + val shuffleMapRdd3 = new MyRDD( + sc, 2, List(shuffleDep2), tracker = mapOutputTracker, indeterminate = true) + val shuffleDep3 = new ShuffleDependency(shuffleMapRdd3, new HashPartitioner(2)) + val shuffleId3 = shuffleDep3.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep3), tracker = mapOutputTracker) + + submit(finalRdd, Array(0, 1), properties = new Properties()) + + // Finish the first 2 shuffle map stages. + complete(taskSets(0), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) + + complete(taskSets(1), Seq( + (Success, makeMapStatus("hostB", 2)), + (Success, makeMapStatus("hostD", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty)) + + // Executor lost on hostB, both of stage 0 and 1 should be reran. + runEvent(makeCompletionEvent( + taskSets(2).tasks(0), + FetchFailed(makeBlockManagerId("hostB"), shuffleId2, 0L, 0, 0, "ignored"), + null)) + mapOutputTracker.removeOutputsOnHost("hostB") + + assert(scheduler.failedStages.toSeq.map(_.id) == Seq(1, 2)) + scheduler.resubmitFailedStages() + + def checkAndCompleteRetryStage( + taskSetIndex: Int, + stageId: Int, + shuffleId: Int): Unit = { + assert(taskSets(taskSetIndex).stageId == stageId) + assert(taskSets(taskSetIndex).stageAttemptId == 1) + assert(taskSets(taskSetIndex).tasks.length == 2) + complete(taskSets(taskSetIndex), Seq( + (Success, makeMapStatus("hostA", 2)), + (Success, makeMapStatus("hostB", 2)))) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + } + + // Check all indeterminate stage roll back. + checkAndCompleteRetryStage(3, 0, shuffleId1) + checkAndCompleteRetryStage(4, 1, shuffleId2) + checkAndCompleteRetryStage(5, 2, shuffleId3) - // The second shuffle map stage need to rerun, the job will abort for the indeterminate - // stage rerun. - // TODO: After we support re-generate shuffle file(SPARK-25341), this test will be extended. - assert(failure != null && failure.getMessage - .contains("Spark cannot rollback the ShuffleMapStage 1")) + // Result stage success, all job ended. + complete(taskSets(6), Seq((Success, 11), (Success, 12))) + assert(results === Map(0 -> 11, 1 -> 12)) + results.clear() + assertDataStructuresEmpty() } test("SPARK-29042: Sampled RDD with unordered input should be indeterminate") { @@ -2813,7 +2949,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // Fail the second task with FetchFailed. runEvent(makeCompletionEvent( taskSets.last.tasks(1), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"), null)) // The job should fail because Spark can't rollback the result stage. @@ -2856,7 +2992,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // Fail the second task with FetchFailed. 
runEvent(makeCompletionEvent( taskSets.last.tasks(1), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"), null)) assert(failure == null, "job should not fail") @@ -2903,33 +3039,6 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi assert(latch.await(10, TimeUnit.SECONDS)) } - test("SPARK-28699: abort stage if parent stage is indeterminate stage") { - val shuffleMapRdd = new MyRDD(sc, 2, Nil, indeterminate = true) - - val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2)) - val shuffleId = shuffleDep.shuffleId - val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) - - submit(finalRdd, Array(0, 1)) - - // Finish the first shuffle map stage. - complete(taskSets(0), Seq( - (Success, makeMapStatus("hostA", 2)), - (Success, makeMapStatus("hostB", 2)))) - assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) - - runEvent(makeCompletionEvent( - taskSets(1).tasks(0), - FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), - null)) - - // Shuffle blocks of "hostA" is lost, so first task of the `shuffleMapRdd` needs to retry. - // The result stage is still waiting for its 2 tasks to complete. - // Because of shuffleMapRdd is indeterminate, this job will be abort. - assert(failure != null && failure.getMessage - .contains("Spark cannot rollback the ShuffleMapStage 0")) - } - test("Completions in zombie tasksets update status of non-zombie taskset") { val parts = 4 val shuffleMapRdd = new MyRDD(sc, parts, Nil) @@ -2946,7 +3055,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi // The second task of the shuffle map stage failed with FetchFailed. 
runEvent(makeCompletionEvent( taskSets(0).tasks(1), - FetchFailed(makeBlockManagerId("hostB"), shuffleDep.shuffleId, 0, 0, "ignored"), + FetchFailed(makeBlockManagerId("hostB"), shuffleDep.shuffleId, 0L, 0, 0, "ignored"), null)) scheduler.resubmitFailedStages() @@ -3036,8 +3145,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi } object DAGSchedulerSuite { - def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus = - MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes)) + def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2, mapTaskId: Long = -1): MapStatus = + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), mapTaskId) def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index c1e7fb9a1db16..700d9ebd76c0c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -61,7 +61,7 @@ class MapStatusSuite extends SparkFunSuite { stddev <- Seq(0.0, 0.01, 0.5, 1.0) ) { val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean) - val status = MapStatus(BlockManagerId("a", "b", 10), sizes) + val status = MapStatus(BlockManagerId("a", "b", 10), sizes, -1) val status1 = compressAndDecompressMapStatus(status) for (i <- 0 until numSizes) { if (sizes(i) != 0) { @@ -75,7 +75,7 @@ class MapStatusSuite extends SparkFunSuite { test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) { val sizes = Array.fill[Long](2001)(150L) - val status = MapStatus(null, sizes) + val status = MapStatus(null, sizes, -1) assert(status.isInstanceOf[HighlyCompressedMapStatus]) assert(status.getSizeForBlock(10) === 150L) assert(status.getSizeForBlock(50) === 150L) @@ -87,10 +87,12 @@ class MapStatusSuite extends SparkFunSuite { val sizes = Array.tabulate[Long](3000) { i => i.toLong } val avg = sizes.sum / sizes.count(_ != 0) val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val mapTaskAttemptId = 5 + val status = MapStatus(loc, sizes, mapTaskAttemptId) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) + assert(status1.mapTaskId == mapTaskAttemptId) for (i <- 0 until 3000) { val estimate = status1.getSizeForBlock(i) if (sizes(i) > 0) { @@ -109,7 +111,7 @@ class MapStatusSuite extends SparkFunSuite { val smallBlockSizes = sizes.filter(n => n > 0 && n < threshold) val avg = smallBlockSizes.sum / smallBlockSizes.length val loc = BlockManagerId("a", "b", 10) - val status = MapStatus(loc, sizes) + val status = MapStatus(loc, sizes, 5) val status1 = compressAndDecompressMapStatus(status) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) assert(status1.location == loc) @@ -165,7 +167,7 @@ class MapStatusSuite extends SparkFunSuite { SparkEnv.set(env) // Value of element in sizes is equal to the corresponding index. 
val sizes = (0L to 2000L).toArray - val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes) + val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes, 5) val arrayStream = new ByteArrayOutputStream(102400) val objectOutputStream = new ObjectOutputStream(arrayStream) assert(status1.isInstanceOf[HighlyCompressedMapStatus]) diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index d6964063c118e..6f80c7c0fe817 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -254,7 +254,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { .reduceByKey { case (_, _) => val ctx = TaskContext.get() if (ctx.stageAttemptNumber() == 0) { - throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 1, 1, 1, + throw new FetchFailedException(SparkEnv.get.blockManager.blockManagerId, 1, 1L, 1, 1, new Exception("Failure for test.")) } else { ctx.stageId() diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 96706536fe53c..4f737c9499ad6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -621,7 +621,7 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor backend.taskSuccess(taskDescription, DAGSchedulerSuite.makeMapStatus("hostA", 10)) case (1, 0, 0) => val fetchFailed = FetchFailed( - DAGSchedulerSuite.makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored") + DAGSchedulerSuite.makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored") backend.taskFailed(taskDescription, fetchFailed) case (1, _, partition) => backend.taskSuccess(taskDescription, 42 + partition) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index c16b552d20891..394a2a9fbf7cb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -176,7 +176,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark if (stageAttemptNumber < 2) { // Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception // will only trigger task resubmission in the same stage. 
- throw new FetchFailedException(null, 0, 0, 0, "Fake") + throw new FetchFailedException(null, 0, 0L, 0, 0, "Fake") } Seq(stageAttemptNumber).iterator }.collect() diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index fedfa083e8d8f..5b1cb08aa4813 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1262,7 +1262,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // now fail those tasks tsmSpy.handleFailedTask(taskDescs(0).taskId, TaskState.FAILED, - FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0, 0, "ignored")) + FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0L, 0, 0, "ignored")) tsmSpy.handleFailedTask(taskDescs(1).taskId, TaskState.FAILED, ExecutorLostFailure(taskDescs(1).executorId, exitCausedByApp = false, reason = None)) tsmSpy.handleFailedTask(taskDescs(2).taskId, TaskState.FAILED, @@ -1302,7 +1302,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // Fail the task with fetch failure tsm.handleFailedTask(taskDescs(0).taskId, TaskState.FAILED, - FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0, 0, "ignored")) + FetchFailed(BlockManagerId(taskDescs(0).executorId, "host1", 12345), 0, 0L, 0, 0, "ignored")) assert(blacklistTracker.isNodeBlacklisted("host1")) } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 2442670b6d3f0..43d7d12a3caed 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -350,8 +350,11 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val ser = new KryoSerializer(conf).newInstance() val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) + var mapTaskId = 0 Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => - ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) + mapTaskId += 1 + ser.serialize(HighlyCompressedMapStatus( + BlockManagerId("exec-1", "host", 1234), blockSizes, mapTaskId)) } } diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 6d2ef17a7a790..d0cbb30fe0232 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -102,12 +102,13 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. val mapOutputTracker = mock(classOf[MapOutputTracker]) - when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn { + when(mapOutputTracker.getMapSizesByExecutorId( + shuffleId, reduceId, reduceId + 1, useOldFetchProtocol = false)).thenReturn { // Test a scenario where all data is local, to avoid creating a bunch of additional mocks // for the code to read data over the network. 
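// ----------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: as in the KryoSerializerSuite hunk
// above, HighlyCompressedMapStatus.apply now takes the map task attempt id as a
// third argument, so anything that builds or serializes a map status has to
// thread that id through. The object name and values are hypothetical; the file
// is placed in org.apache.spark.scheduler because the MapStatus classes are
// private[spark].
// ----------------------------------------------------------------------------
package org.apache.spark.scheduler

import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.storage.BlockManagerId

object HighlyCompressedMapStatusKryoSketch {
  def main(args: Array[String]): Unit = {
    val ser = new KryoSerializer(new SparkConf(false)).newInstance()

    // 5000 dense block sizes, mirroring the suite's dense case.
    val blockSizes = Array.fill[Long](5000)(10L)

    // The third argument is the id of the map task attempt that produced the blocks.
    val status =
      HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes, 42L)
    assert(status.mapTaskId == 42L)

    // Kryo serialization still works with the extra field.
    val bytes = ser.serialize(status)
    assert(bytes.remaining() > 0)
  }
}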
val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) - (shuffleBlockId, byteOutputStream.size().toLong) + (shuffleBlockId, byteOutputStream.size().toLong, mapId) } Seq((localBlockManagerId, shuffleBlockIdsAndSizes)).toIterator } @@ -118,7 +119,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(None) when(dependency.keyOrdering).thenReturn(None) - new BaseShuffleHandle(shuffleId, numMaps, dependency) + new BaseShuffleHandle(shuffleId, dependency) } val serializerManager = new SerializerManager( diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index b9f81fa0d0a06..f8474022867f4 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt} +import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} import org.mockito.Mockito._ import org.scalatest.BeforeAndAfterEach @@ -65,7 +65,6 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte taskMetrics = new TaskMetrics shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int]( shuffleId = 0, - numMaps = 2, dependency = dependency ) val memoryManager = new TestMemoryManager(conf) @@ -78,7 +77,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) when(blockResolver.writeIndexFileAndCommit( - anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) + anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File]))) .thenAnswer { invocationOnMock => val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] if (tmp != null) { @@ -139,8 +138,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, shuffleHandle, - 0, // MapId - 0L, // MapTaskAttemptId + 0L, // MapId conf, taskContext.taskMetrics().shuffleWriteMetrics, shuffleExecutorComponents) @@ -166,8 +164,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, shuffleHandle, - 0, // MapId - 0L, + 0L, // MapId transferConf, taskContext.taskMetrics().shuffleWriteMetrics, shuffleExecutorComponents) @@ -202,8 +199,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, shuffleHandle, - 0, // MapId - 0L, + 0L, // MapId conf, taskContext.taskMetrics().shuffleWriteMetrics, shuffleExecutorComponents) @@ -224,8 +220,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, shuffleHandle, - 0, // MapId - 0L, + 0L, // MapId conf, taskContext.taskMetrics().shuffleWriteMetrics, shuffleExecutorComponents) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala 
b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala index 0dd6040808f9e..4c5694fcf0305 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala @@ -57,7 +57,7 @@ class SortShuffleWriterSuite extends SparkFunSuite with SharedSparkContext with when(dependency.serializer).thenReturn(serializer) when(dependency.aggregator).thenReturn(None) when(dependency.keyOrdering).thenReturn(None) - new BaseShuffleHandle(shuffleId, numMaps = numMaps, dependency) + new BaseShuffleHandle(shuffleId, dependency) } shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( conf, blockManager, shuffleBlockResolver) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala index 5156cc2cc47a6..f92455912f510 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala @@ -23,7 +23,7 @@ import java.nio.file.Files import java.util.Arrays import org.mockito.Answers.RETURNS_SMART_NULLS -import org.mockito.ArgumentMatchers.{any, anyInt} +import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} import org.mockito.Mock import org.mockito.Mockito.when import org.mockito.MockitoAnnotations @@ -73,9 +73,9 @@ class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndA conf = new SparkConf() .set("spark.app.id", "example.spark.app") .set("spark.shuffle.unsafe.file.output.buffer", "16k") - when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile) + when(blockResolver.getDataFile(anyInt, anyLong)).thenReturn(mergedOutputFile) when(blockResolver.writeIndexFileAndCommit( - anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) + anyInt, anyLong, any(classOf[Array[Long]]), any(classOf[File]))) .thenAnswer { invocationOnMock => partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index ed402440e74f1..e5a615c2c2cbb 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -98,9 +98,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val transfer = createMockTransfer(remoteBlocks) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq), - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + (localBmId, localBlocks.keys.map(blockId => (blockId, 1L, 0)).toSeq), + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq) ).toIterator val taskContext = TaskContext.empty() @@ -179,8 +179,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, 
blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -247,8 +247,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)) + .toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -336,8 +337,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)).toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -389,8 +390,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptBuffer1 = mockCorruptBuffer(streamLength, 0) val blockManagerId1 = BlockManagerId("remote-client-1", "remote-client-1", 1) val shuffleBlockId1 = ShuffleBlockId(0, 1, 0) - val blockLengths1 = Seq[Tuple2[BlockId, Long]]( - shuffleBlockId1 -> corruptBuffer1.size() + val blockLengths1 = Seq[Tuple3[BlockId, Long, Int]]( + (shuffleBlockId1, corruptBuffer1.size(), 1) ) val streamNotCorruptTill = 8 * 1024 @@ -398,13 +399,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val corruptBuffer2 = mockCorruptBuffer(streamLength, streamNotCorruptTill) val blockManagerId2 = BlockManagerId("remote-client-2", "remote-client-2", 2) val shuffleBlockId2 = ShuffleBlockId(0, 2, 0) - val blockLengths2 = Seq[Tuple2[BlockId, Long]]( - shuffleBlockId2 -> corruptBuffer2.size() + val blockLengths2 = Seq[Tuple3[BlockId, Long, Int]]( + (shuffleBlockId2, corruptBuffer2.size(), 2) ) val transfer = createMockTransfer( Map(shuffleBlockId1 -> corruptBuffer1, shuffleBlockId2 -> corruptBuffer2)) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( (blockManagerId1, blockLengths1), (blockManagerId2, blockLengths2) ).toIterator @@ -465,11 +466,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val localBmId = BlockManagerId("test-client", "test-client", 1) doReturn(localBmId).when(blockManager).blockManagerId doReturn(managedBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0)) - val localBlockLengths = Seq[Tuple2[BlockId, Long]]( - ShuffleBlockId(0, 0, 0) -> 10000 + val localBlockLengths = Seq[Tuple3[BlockId, Long, Int]]( + (ShuffleBlockId(0, 0, 0), 10000, 0) ) val transfer = createMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> managedBuffer)) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( (localBmId, localBlockLengths) ).toIterator @@ -531,8 +532,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) - val 
blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)).toIterator + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)) + .toIterator val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( @@ -591,7 +593,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) def fetchShuffleBlock( - blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = { + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = { // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. @@ -611,15 +613,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT taskContext.taskMetrics.createTempShuffleReadMetrics()) } - val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)).toIterator + val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L, 0)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. assert(tempFileManager == null) - val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)).toIterator + val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L, 0)).toSeq)).toIterator fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. 
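// ----------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: the ShuffleBlockFetcherIteratorSuite
// hunks above all reflect one type change -- blocks grouped by address are now
// (BlockId, size, mapIndex) triples rather than (BlockId, size) pairs, so a corrupt
// or missing block can be reported together with the index of the map task that
// wrote it. The object name and values below are hypothetical; only the shapes come
// from the hunks above.
// ----------------------------------------------------------------------------
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}

object BlocksByAddressSketch {
  def main(args: Array[String]): Unit = {
    val remoteBmId = BlockManagerId("exec-remote", "remote-host", 7337)

    // One (block to fetch, estimated size in bytes, map index within the stage)
    // triple per shuffle block; note the map id inside ShuffleBlockId is now a Long.
    val blocksForRemote: Seq[(BlockId, Long, Int)] = (0 until 3).map { mapIndex =>
      (ShuffleBlockId(0, mapIndex.toLong, 0), 100L, mapIndex)
    }

    // This is the per-block-manager shape that ShuffleBlockFetcherIterator consumes.
    val blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
      Iterator((remoteBmId, blocksForRemote))

    assert(blocksByAddress.hasNext && blocksForRemote.map(_._2).sum == 300L)
  }
}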
@@ -640,8 +642,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0))) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)) val taskContext = TaskContext.empty() val iterator = new ShuffleBlockFetcherIterator( diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index 1913b8d425519..580af086ba9da 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -316,10 +316,12 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B val env = SparkEnv.get val bmAddress = env.blockManager.blockManagerId val shuffleId = shuffleHandle.shuffleId - val mapId = 0 + val mapId = 0L + val mapIndex = 0 val reduceId = taskContext.partitionId() val message = "Simulated fetch failure" - throw new FetchFailedException(bmAddress, shuffleId, mapId, reduceId, message) + throw new FetchFailedException( + bmAddress, shuffleId, mapId, mapIndex, reduceId, message) } else { x } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index e781c5f71faf4..54625a93679fb 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -179,7 +179,7 @@ class JsonProtocolSuite extends SparkFunSuite { testJobResult(jobFailed) // TaskEndReason - val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19, + val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 16L, 18, 19, "Some exception") val fetchMetadataFailed = new MetadataFetchFailedException(17, 19, "metadata Fetch failed exception").toTaskFailedReason @@ -296,12 +296,12 @@ class JsonProtocolSuite extends SparkFunSuite { test("FetchFailed backwards compatibility") { // FetchFailed in Spark 1.1.0 does not have a "Message" property. 
-    val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19,
+    val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 16L, 18, 19,
       "ignored")
     val oldEvent = JsonProtocol.taskEndReasonToJson(fetchFailed)
       .removeField({ _._1 == "Message" })
-    val expectedFetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19,
-      "Unknown reason")
+    val expectedFetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 16L,
+      18, 19, "Unknown reason")
     assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent))
   }
@@ -732,6 +732,7 @@ private[spark] object JsonProtocolSuite extends Assertions {
       case (r1: FetchFailed, r2: FetchFailed) =>
         assert(r1.shuffleId === r2.shuffleId)
         assert(r1.mapId === r2.mapId)
+        assert(r1.mapIndex === r2.mapIndex)
         assert(r1.reduceId === r2.reduceId)
         assert(r1.bmAddress === r2.bmAddress)
         assert(r1.message === r2.message)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 097f1d2c2a6e1..9f60338df7059 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -417,7 +417,32 @@ object MimaExcludes {
     // [SPARK-25382][SQL][PYSPARK] Remove ImageSchema.readImages in 3.0
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.image.ImageSchema.readImages"),
-    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.image.ImageSchema.readImages")
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.image.ImageSchema.readImages"),
+
+    // [SPARK-25341][CORE] Support rolling back a shuffle map stage and re-generate the shuffle files
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"),
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleIndexBlockId.copy$default$2"),
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleIndexBlockId.copy"),
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleIndexBlockId.this"),
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleDataBlockId.copy$default$2"),
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleDataBlockId.copy"),
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleDataBlockId.this"),
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleBlockId.copy$default$2"),
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleBlockId.copy"),
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleBlockId.this"),
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleIndexBlockId.apply"),
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleDataBlockId.apply"),
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleBlockId.apply"),
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleIndexBlockId.mapId"),
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleDataBlockId.mapId"),
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleBlockId.mapId"),
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.FetchFailed.mapId"),
+    ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.FetchFailed$"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.FetchFailed.apply"),
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.FetchFailed.copy$default$5"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.FetchFailed.copy"),
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.FetchFailed.copy$default$3"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.FetchFailed.this")
   )
   // Exclude rules for 2.4.x
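// ----------------------------------------------------------------------------
// Illustrative sketch, not part of the patch: the JsonProtocolSuite changes above
// check that a FetchFailed written to the event log round-trips its new fields,
// and the MimaExcludes entries acknowledge that widening mapId to Long and adding
// mapIndex is binary-incompatible. A hypothetical standalone round trip (object
// name and values invented); it sits in org.apache.spark.util because
// JsonProtocol is private[spark].
// ----------------------------------------------------------------------------
package org.apache.spark.util

import org.apache.spark.FetchFailed
import org.apache.spark.storage.BlockManagerId

object FetchFailedJsonRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val original = FetchFailed(
      BlockManagerId("exec-1", "host-1", 7337),
      shuffleId = 17, mapId = 16L, mapIndex = 18, reduceId = 19, message = "boom")

    // Serialize to the event-log JSON form and read it back.
    val json = JsonProtocol.taskEndReasonToJson(original)
    val restored = JsonProtocol.taskEndReasonFromJson(json).asInstanceOf[FetchFailed]

    // Both the Long map task id and the new map index survive the round trip.
    assert(restored.mapId == original.mapId)
    assert(restored.mapIndex == original.mapIndex)
  }
}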