Commit f725d47

xuanyuanking authored and cloud-fan committed
[SPARK-25341][CORE] Support rolling back a shuffle map stage and re-generate the shuffle files
After the newly added shuffle block fetching protocol in #24565, we can keep this work by extending the FetchShuffleBlocks message.

### What changes were proposed in this pull request?

In this patch, we achieve the indeterminate shuffle rerun by reusing the task attempt id (unique within an application) as the shuffle map id, so that each shuffle write attempt has a different file name. For an indeterminate stage, when the stage is resubmitted, we clear all existing map statuses and rerun all partitions.

All changes are summarized as follows:

- Change the mapId to mapTaskAttemptId in shuffle-related ids.
- Record the mapTaskAttemptId in MapStatus.
- Still keep mapId in ShuffleFetcherIterator for the fetch-failed scenario.
- Add a determinate flag in Stage and use it in DAGScheduler and in the cleanup work for intermediate stages.

### Why are the changes needed?

This is follow-up work for a future improvement noted in #22112: `Currently we can't rollback and rerun a shuffle map stage, and just fail.`

Spark reruns a finished shuffle map stage when it meets fetch failures. Currently, the rerun stage only resubmits tasks for the missing partitions and reuses the output of the other partitions. This logic is fine in most scenarios, but for indeterminate operations (like repartition), multiple shuffle write attempts may write different data, so rerunning only the missing partitions leads to a correctness bug. For the shuffle map stages of indeterminate operations, we therefore need to support rolling back the shuffle map stage and regenerating the shuffle files.

### Does this PR introduce any user-facing change?

Yes. After this PR, an indeterminate stage rerun is handled by rerunning the whole stage. The original behavior was to abort the stage and fail the job.

### How was this patch tested?

- UT: Added UTs for all changed code and newly added functions.
- Manual test: The following job verifies the effect.

```
import scala.sys.process._
import org.apache.spark.TaskContext

val determinateStage0 = sc.parallelize(0 until 1000 * 1000 * 100, 10)
val indeterminateStage1 = determinateStage0.repartition(200)
val indeterminateStage2 = indeterminateStage1.repartition(200)
val indeterminateStage3 = indeterminateStage2.repartition(100)
val indeterminateStage4 = indeterminateStage3.repartition(300)
val fetchFailIndeterminateStage4 = indeterminateStage4.map { x =>
  if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId == 190 &&
      TaskContext.get.stageAttemptNumber == 0) {
    throw new Exception("pkill -f -n java".!!)
  }
  x
}
val indeterminateStage5 = fetchFailIndeterminateStage4.repartition(200)
val finalStage6 = indeterminateStage5.repartition(100).collect().distinct.length
```

It's a simple job with multiple indeterminate stages; it gets a wrong answer on old Spark versions like 2.2/2.3, and is killed after #22112. With this fix, the job can retry all indeterminate stages, as in the screenshot below, and gets the right result.

![image](https://user-images.githubusercontent.com/4833765/63948434-3477de00-caab-11e9-9ed1-75abfe6d16bd.png)

Closes #25620 from xuanyuanking/SPARK-25341-8.27.

Authored-by: Yuanjian Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
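To make the correctness hazard described above concrete, here is a minimal illustration (not code from this commit; the row values and orders are made up) of how round-robin repartitioning depends on the order rows arrive in, which can differ between task attempts because shuffle blocks are fetched in a nondeterministic order:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class IndeterminateRepartition {
  // Round-robin assignment of rows to partitions, as repartition() does conceptually.
  static Map<Integer, List<Integer>> roundRobin(List<Integer> rows, int numPartitions) {
    Map<Integer, List<Integer>> out = new HashMap<>();
    for (int i = 0; i < rows.size(); i++) {
      out.computeIfAbsent(i % numPartitions, k -> new ArrayList<>()).add(rows.get(i));
    }
    return out;
  }

  public static void main(String[] args) {
    // Same six rows, seen in different orders by two attempts of the same map task:
    Map<Integer, List<Integer>> attempt1 = roundRobin(List.of(1, 2, 3, 4, 5, 6), 2);
    Map<Integer, List<Integer>> attempt2 = roundRobin(List.of(2, 1, 3, 4, 6, 5), 2);
    System.out.println(attempt1); // {0=[1, 3, 5], 1=[2, 4, 6]}
    System.out.println(attempt2); // {0=[2, 3, 6], 1=[1, 4, 5]}
    // Mixing partition 0 of attempt1 with partition 1 of attempt2 duplicates rows
    // 1 and 5 and loses rows 2 and 6 -- hence the whole-stage rollback in this PR.
  }
}
```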
1 parent 7a2ea58 commit f725d47

56 files changed, +672 −391 lines changed

common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java

Lines changed: 23 additions & 0 deletions
```diff
@@ -112,4 +112,27 @@ public static int[] decode(ByteBuf buf) {
       return ints;
     }
   }
+
+  /** Long integer arrays are encoded with their length followed by long integers. */
+  public static class LongArrays {
+    public static int encodedLength(long[] longs) {
+      return 4 + 8 * longs.length;
+    }
+
+    public static void encode(ByteBuf buf, long[] longs) {
+      buf.writeInt(longs.length);
+      for (long i : longs) {
+        buf.writeLong(i);
+      }
+    }
+
+    public static long[] decode(ByteBuf buf) {
+      int numLongs = buf.readInt();
+      long[] longs = new long[numLongs];
+      for (int i = 0; i < longs.length; i ++) {
+        longs[i] = buf.readLong();
+      }
+      return longs;
+    }
+  }
 }
```
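As a quick sanity check of the new encoder, a round-trip sketch (not part of the commit; assumes a Netty `Unpooled` heap buffer):

```java
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.apache.spark.network.protocol.Encoders;

public class LongArraysRoundTrip {
  public static void main(String[] args) {
    long[] mapIds = { 0L, 1L, 8_589_934_592L }; // values wider than an int now fit
    ByteBuf buf = Unpooled.buffer(Encoders.LongArrays.encodedLength(mapIds));
    Encoders.LongArrays.encode(buf, mapIds);    // writes a 4-byte length, then 8 bytes per value
    long[] decoded = Encoders.LongArrays.decode(buf);
    assert java.util.Arrays.equals(mapIds, decoded);
  }
}
```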

common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java

Lines changed: 44 additions & 16 deletions
```diff
@@ -106,7 +106,7 @@ protected void handleMessage(
         numBlockIds += ids.length;
       }
       streamId = streamManager.registerStream(client.getClientId(),
-        new ManagedBufferIterator(msg, numBlockIds), client.getChannel());
+        new ShuffleManagedBufferIterator(msg), client.getChannel());
     } else {
       // For the compatibility with the old version, still keep the support for OpenBlocks.
       OpenBlocks msg = (OpenBlocks) msgObj;
@@ -299,21 +299,6 @@ private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) {
       return mapIdAndReduceIds;
     }

-    ManagedBufferIterator(FetchShuffleBlocks msg, int numBlockIds) {
-      final int[] mapIdAndReduceIds = new int[2 * numBlockIds];
-      int idx = 0;
-      for (int i = 0; i < msg.mapIds.length; i++) {
-        for (int reduceId : msg.reduceIds[i]) {
-          mapIdAndReduceIds[idx++] = msg.mapIds[i];
-          mapIdAndReduceIds[idx++] = reduceId;
-        }
-      }
-      assert(idx == 2 * numBlockIds);
-      size = mapIdAndReduceIds.length;
-      blockDataForIndexFn = index -> blockManager.getBlockData(msg.appId, msg.execId,
-        msg.shuffleId, mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
-    }
-
     @Override
     public boolean hasNext() {
       return index < size;
@@ -328,6 +313,49 @@ public ManagedBuffer next() {
     }
   }

+  private class ShuffleManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int mapIdx = 0;
+    private int reduceIdx = 0;
+
+    private final String appId;
+    private final String execId;
+    private final int shuffleId;
+    private final long[] mapIds;
+    private final int[][] reduceIds;
+
+    ShuffleManagedBufferIterator(FetchShuffleBlocks msg) {
+      appId = msg.appId;
+      execId = msg.execId;
+      shuffleId = msg.shuffleId;
+      mapIds = msg.mapIds;
+      reduceIds = msg.reduceIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // mapIds.length must equal to reduceIds.length, and the passed in FetchShuffleBlocks
+      // must have non-empty mapIds and reduceIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(mapIds.length != 0 && mapIds.length == reduceIds.length);
+      return mapIdx < mapIds.length && reduceIdx < reduceIds[mapIdx].length;
+    }
+
+    @Override
+    public ManagedBuffer next() {
+      final ManagedBuffer block = blockManager.getBlockData(
+        appId, execId, shuffleId, mapIds[mapIdx], reduceIds[mapIdx][reduceIdx]);
+      if (reduceIdx < reduceIds[mapIdx].length - 1) {
+        reduceIdx += 1;
+      } else {
+        reduceIdx = 0;
+        mapIdx += 1;
+      }
+      metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);
+      return block;
+    }
+  }
+
   @Override
   public void channelActive(TransportClient client) {
     metrics.activeConnections.inc();
```
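The new iterator yields one buffer per (mapId, reduceId) pair, exhausting all reduce ids of a map id before advancing to the next. A standalone sketch of that traversal (illustration only, with made-up ids):

```java
long[] mapIds = { 10L, 11L };
int[][] reduceIds = { { 0, 1 }, { 2 } };
int mapIdx = 0;
int reduceIdx = 0;
// Mirrors ShuffleManagedBufferIterator#hasNext/#next without the block fetch.
while (mapIdx < mapIds.length && reduceIdx < reduceIds[mapIdx].length) {
  System.out.println("map=" + mapIds[mapIdx] + " reduce=" + reduceIds[mapIdx][reduceIdx]);
  if (reduceIdx < reduceIds[mapIdx].length - 1) {
    reduceIdx += 1;       // next reduce id for the same map output
  } else {
    reduceIdx = 0;        // move on to the next map output
    mapIdx += 1;
  }
}
// Prints: map=10 reduce=0, map=10 reduce=1, map=11 reduce=2
```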

common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java

Lines changed: 2 additions & 2 deletions
```diff
@@ -172,7 +172,7 @@ public ManagedBuffer getBlockData(
       String appId,
       String execId,
       int shuffleId,
-      int mapId,
+      long mapId,
       int reduceId) {
     ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId));
     if (executor == null) {
@@ -296,7 +296,7 @@ private void deleteNonShuffleServiceServedFiles(String[] dirs) {
    * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId.
    */
   private ManagedBuffer getSortBasedShuffleBlockData(
-      ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) {
+      ExecutorShuffleInfo executor, int shuffleId, long mapId, int reduceId) {
     File indexFile = ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir,
       "shuffle_" + shuffleId + "_" + mapId + "_0.index");
```
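Since the map id segment of the file name is now a long, attempt ids beyond `Integer.MAX_VALUE` still resolve to valid file names. A sketch with a hypothetical id:

```java
int shuffleId = 7;
long mapId = 3_000_000_000L; // hypothetical task attempt id, too large for an int
String indexName = "shuffle_" + shuffleId + "_" + mapId + "_0.index";
// -> "shuffle_7_3000000000_0.index"
```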

common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java

Lines changed: 14 additions & 13 deletions
```diff
@@ -24,6 +24,8 @@
 import java.util.HashMap;

 import com.google.common.primitives.Ints;
+import com.google.common.primitives.Longs;
+import org.apache.commons.lang3.tuple.ImmutableTriple;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -111,21 +113,21 @@ private boolean isShuffleBlocks(String[] blockIds) {
    */
   private FetchShuffleBlocks createFetchShuffleBlocksMsg(
       String appId, String execId, String[] blockIds) {
-    int shuffleId = splitBlockId(blockIds[0])[0];
-    HashMap<Integer, ArrayList<Integer>> mapIdToReduceIds = new HashMap<>();
+    int shuffleId = splitBlockId(blockIds[0]).left;
+    HashMap<Long, ArrayList<Integer>> mapIdToReduceIds = new HashMap<>();
     for (String blockId : blockIds) {
-      int[] blockIdParts = splitBlockId(blockId);
-      if (blockIdParts[0] != shuffleId) {
+      ImmutableTriple<Integer, Long, Integer> blockIdParts = splitBlockId(blockId);
+      if (blockIdParts.left != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      int mapId = blockIdParts[1];
+      long mapId = blockIdParts.middle;
       if (!mapIdToReduceIds.containsKey(mapId)) {
         mapIdToReduceIds.put(mapId, new ArrayList<>());
       }
-      mapIdToReduceIds.get(mapId).add(blockIdParts[2]);
+      mapIdToReduceIds.get(mapId).add(blockIdParts.right);
     }
-    int[] mapIds = Ints.toArray(mapIdToReduceIds.keySet());
+    long[] mapIds = Longs.toArray(mapIdToReduceIds.keySet());
     int[][] reduceIdArr = new int[mapIds.length][];
     for (int i = 0; i < mapIds.length; i++) {
       reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i]));
@@ -134,17 +136,16 @@ private FetchShuffleBlocks createFetchShuffleBlocksMsg(
   }

   /** Split the shuffleBlockId and return shuffleId, mapId and reduceId. */
-  private int[] splitBlockId(String blockId) {
+  private ImmutableTriple<Integer, Long, Integer> splitBlockId(String blockId) {
     String[] blockIdParts = blockId.split("_");
     if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
       throw new IllegalArgumentException(
         "Unexpected shuffle block id format: " + blockId);
     }
-    return new int[] {
-      Integer.parseInt(blockIdParts[1]),
-      Integer.parseInt(blockIdParts[2]),
-      Integer.parseInt(blockIdParts[3])
-    };
+    return new ImmutableTriple<>(
+      Integer.parseInt(blockIdParts[1]),
+      Long.parseLong(blockIdParts[2]),
+      Integer.parseInt(blockIdParts[3]));
   }

   /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
```
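For reference, a sketch of the block id format `splitBlockId` parses, with a hypothetical id showing why the map segment needs `Long.parseLong`:

```java
String blockId = "shuffle_3_5000000123_7";  // shuffle_<shuffleId>_<mapId>_<reduceId>
String[] parts = blockId.split("_");
int shuffleId = Integer.parseInt(parts[1]); // 3
long mapId = Long.parseLong(parts[2]);      // 5000000123 would overflow Integer.parseInt
int reduceId = Integer.parseInt(parts[3]);  // 7
```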

common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java

Lines changed: 5 additions & 5 deletions
```diff
@@ -34,14 +34,14 @@ public class FetchShuffleBlocks extends BlockTransferMessage {
   public final int shuffleId;
   // The length of mapIds must equal to reduceIds.size(), for the i-th mapId in mapIds,
   // it corresponds to the i-th int[] in reduceIds, which contains all reduce id for this map id.
-  public final int[] mapIds;
+  public final long[] mapIds;
   public final int[][] reduceIds;

   public FetchShuffleBlocks(
       String appId,
       String execId,
       int shuffleId,
-      int[] mapIds,
+      long[] mapIds,
       int[][] reduceIds) {
     this.appId = appId;
     this.execId = execId;
@@ -98,7 +98,7 @@ public int encodedLength() {
     return Encoders.Strings.encodedLength(appId)
       + Encoders.Strings.encodedLength(execId)
       + 4 /* encoded length of shuffleId */
-      + Encoders.IntArrays.encodedLength(mapIds)
+      + Encoders.LongArrays.encodedLength(mapIds)
       + 4 /* encoded length of reduceIds.size() */
       + encodedLengthOfReduceIds;
   }
@@ -108,7 +108,7 @@ public void encode(ByteBuf buf) {
     Encoders.Strings.encode(buf, appId);
     Encoders.Strings.encode(buf, execId);
     buf.writeInt(shuffleId);
-    Encoders.IntArrays.encode(buf, mapIds);
+    Encoders.LongArrays.encode(buf, mapIds);
     buf.writeInt(reduceIds.length);
     for (int[] ids: reduceIds) {
       Encoders.IntArrays.encode(buf, ids);
@@ -119,7 +119,7 @@ public static FetchShuffleBlocks decode(ByteBuf buf) {
     String appId = Encoders.Strings.decode(buf);
     String execId = Encoders.Strings.decode(buf);
     int shuffleId = buf.readInt();
-    int[] mapIds = Encoders.IntArrays.decode(buf);
+    long[] mapIds = Encoders.LongArrays.decode(buf);
     int reduceIdsSize = buf.readInt();
     int[][] reduceIds = new int[reduceIdsSize][];
     for (int i = 0; i < reduceIdsSize; i++) {
```
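Working through `encodedLength()` for a small message (values hypothetical; this assumes Spark's `Encoders.Strings` writes a 4-byte length followed by the UTF-8 bytes):

```java
String appId = "app-1";                  // 4 + 5 = 9 bytes
String execId = "exec-2";                // 4 + 6 = 10 bytes
long[] mapIds = { 0L, 1L };              // 4 + 2 * 8 = 20 bytes (LongArrays)
int[][] reduceIds = { { 0, 1 }, { 2 } }; // 4 + (4 + 2 * 4) + (4 + 1 * 4) = 24 bytes
int total = 9 + 10 + 4 /* shuffleId */ + 20 + 24; // = 67 bytes on the wire
```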

common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java

Lines changed: 1 addition & 1 deletion
```diff
@@ -29,7 +29,7 @@ public class BlockTransferMessagesSuite {
   public void serializeOpenShuffleBlocks() {
     checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" }));
     checkSerializeDeserialize(new FetchShuffleBlocks(
-      "app-1", "exec-2", 0, new int[] {0, 1},
+      "app-1", "exec-2", 0, new long[] {0, 1},
       new int[][] {{ 0, 1 }, { 0, 1, 2 }}));
     checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo(
       new String[] { "/local1", "/local2" }, 32, "MyShuffleManager")));
```

common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java

Lines changed: 1 addition & 1 deletion
```diff
@@ -101,7 +101,7 @@ public void testFetchShuffleBlocks() {
     when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(blockMarkers[1]);

     FetchShuffleBlocks fetchShuffleBlocks = new FetchShuffleBlocks(
-      "app0", "exec1", 0, new int[] { 0 }, new int[][] {{ 0, 1 }});
+      "app0", "exec1", 0, new long[] { 0 }, new int[][] {{ 0, 1 }});
     checkOpenBlocksReceive(fetchShuffleBlocks, blockMarkers);

     verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0);
```

common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java

Lines changed: 2 additions & 2 deletions
```diff
@@ -64,7 +64,7 @@ public void testFetchOne() {
     BlockFetchingListener listener = fetchBlocks(
       blocks,
       blockIds,
-      new FetchShuffleBlocks("app-id", "exec-id", 0, new int[] { 0 }, new int[][] {{ 0 }}),
+      new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0 }}),
       conf);

     verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0"));
@@ -100,7 +100,7 @@ public void testFetchThreeShuffleBlocks() {
     BlockFetchingListener listener = fetchBlocks(
       blocks,
       blockIds,
-      new FetchShuffleBlocks("app-id", "exec-id", 0, new int[] { 0 }, new int[][] {{ 0, 1, 2 }}),
+      new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0, 1, 2 }}),
       conf);

     for (int i = 0; i < 3; i ++) {
```

core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java

Lines changed: 4 additions & 12 deletions
```diff
@@ -42,17 +42,13 @@ public interface ShuffleExecutorComponents {
    * partitioned bytes written by that map task.
    *
    * @param shuffleId Unique identifier for the shuffle the map task is a part of
-   * @param mapId Within the shuffle, the identifier of the map task
-   * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task
-   *                         with the same (shuffleId, mapId) pair can be distinguished by the
-   *                         different values of mapTaskAttemptId.
+   * @param mapId An ID of the map task. The ID is unique within this Spark application.
    * @param numPartitions The number of partitions that will be written by the map task. Some of
    *                      these partitions may be empty.
    */
   ShuffleMapOutputWriter createMapOutputWriter(
       int shuffleId,
-      int mapId,
-      long mapTaskAttemptId,
+      long mapId,
       int numPartitions) throws IOException;

   /**
@@ -64,15 +60,11 @@ ShuffleMapOutputWriter createMapOutputWriter(
    * preserving an optimization in the local disk shuffle storage implementation.
    *
    * @param shuffleId Unique identifier for the shuffle the map task is a part of
-   * @param mapId Within the shuffle, the identifier of the map task
-   * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task
-   *                         with the same (shuffleId, mapId) pair can be distinguished by the
-   *                         different values of mapTaskAttemptId.
+   * @param mapId An ID of the map task. The ID is unique within this Spark application.
    */
   default Optional<SingleSpillShuffleMapOutputWriter> createSingleFileMapOutputWriter(
       int shuffleId,
-      int mapId,
-      long mapTaskAttemptId) throws IOException {
+      long mapId) throws IOException {
     return Optional.empty();
   }
 }
```
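A hypothetical call site under the new signature (this is a sketch, not code from this commit; `components`, `context`, `shuffleId`, and `numPartitions` are assumed to exist in scope):

```java
// The per-application-unique task attempt id now serves as the map id, so no
// separate mapTaskAttemptId argument is needed.
ShuffleMapOutputWriter writer = components.createMapOutputWriter(
    shuffleId,
    context.taskAttemptId(), // unique within the application
    numPartitions);
```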

core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java

Lines changed: 1 addition & 1 deletion
```diff
@@ -39,7 +39,7 @@ public interface ShuffleMapOutputWriter {
    * for the same partition within any given map task. The partition identifier will be in the
    * range of precisely 0 (inclusive) to numPartitions (exclusive), where numPartitions was
    * provided upon the creation of this map output writer via
-   * {@link ShuffleExecutorComponents#createMapOutputWriter(int, int, long, int)}.
+   * {@link ShuffleExecutorComponents#createMapOutputWriter(int, long, int)}.
    * <p>
    * Calls to this method will be invoked with monotonically increasing reducePartitionIds; each
    * call to this method will be called with a reducePartitionId that is strictly greater than
```
