
Commit 94c67a7

jiangxb1987 authored and sameeragarwal committed
[SPARK-23207][SQL] Shuffle+Repartition on a DataFrame could lead to incorrect answers
## What changes were proposed in this pull request?

Currently shuffle repartition uses RoundRobinPartitioning; the generated result is nondeterministic because the order of the input rows is not determined. The bug can be triggered when there is a repartition call following a shuffle (which leads to nondeterministic row ordering), as in the pattern below:

upstream stage -> repartition stage -> result stage
(-> indicates a shuffle)

When one of the executor processes goes down, some tasks of the repartition stage will be retried and generate an inconsistent ordering, and then some tasks of the result stage will be retried and generate different data.

The following code returns 931532 instead of 1000000:

```
import scala.sys.process._
import org.apache.spark.TaskContext

val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x =>
  x
}.repartition(200).map { x =>
  if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) {
    throw new Exception("pkill -f java".!!)
  }
  x
}
res.distinct().count()
```

In this PR, we propose the most straightforward way to fix this problem: perform a local sort before partitioning. Once the input row ordering is made deterministic, the function from rows to partitions is fully deterministic too.

The downside of this approach is that, with the extra local sort inserted, the performance of repartition() goes down, so we add a new config named `spark.sql.execution.sortBeforeRepartition` to control whether this patch is applied. The patch is enabled by default to be safe-by-default, but users may choose to turn it off manually to avoid the performance regression.

This patch also changes the output row ordering of repartition(), which causes a number of test cases to fail because they compare the results directly.

## How was this patch tested?

Added a unit test in ExchangeSuite.

With this patch (and `spark.sql.execution.sortBeforeRepartition` set to true), the following query returns 1000000:

```
import scala.sys.process._
import org.apache.spark.TaskContext

spark.conf.set("spark.sql.execution.sortBeforeRepartition", "true")

val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x =>
  x
}.repartition(200).map { x =>
  if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) {
    throw new Exception("pkill -f java".!!)
  }
  x
}
res.distinct().count()

res7: Long = 1000000
```

Author: Xingbo Jiang <[email protected]>

Closes #20393 from jiangxb1987/shuffle-repartition.
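To see why the local sort restores determinism, note that round-robin assignment depends only on each row's position in the task's input iterator: two task attempts that receive the same rows in different orders send them to different partitions, while sorting first pins the positions down. A minimal sketch of this idea (hypothetical helper, not Spark's actual partitioner code):

```scala
// Round-robin assignment: the target partition is a function of arrival order, not of the row.
def roundRobin[T](rows: Seq[T], numPartitions: Int): Map[Int, Seq[T]] =
  rows.zipWithIndex
    .groupBy { case (_, i) => i % numPartitions }
    .map { case (p, xs) => p -> xs.map(_._1) }

val runA = Seq("a", "b", "c", "d")
val runB = Seq("b", "a", "c", "d") // same rows, different arrival order (e.g. after a retry)

// Without a sort, the two runs place "a" and "b" in different partitions.
assert(roundRobin(runA, 2) != roundRobin(runB, 2))

// With a local sort first, the row -> partition mapping is identical across retries.
assert(roundRobin(runA.sorted, 2) == roundRobin(runB.sorted, 2))
```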
1 parent a8a3e9b commit 94c67a7


17 files changed, +233 −29 lines changed


core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java

Lines changed: 3 additions & 1 deletion
```diff
@@ -32,6 +32,8 @@ public abstract class RecordComparator {
   public abstract int compare(
       Object leftBaseObject,
       long leftBaseOffset,
+      int leftBaseLength,
       Object rightBaseObject,
-      long rightBaseOffset);
+      long rightBaseOffset,
+      int rightBaseLength);
 }
```

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java

Lines changed: 4 additions & 3 deletions
```diff
@@ -62,12 +62,13 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
     int uaoSize = UnsafeAlignedOffset.getUaoSize();
     if (prefixComparisonResult == 0) {
       final Object baseObject1 = memoryManager.getPage(r1.recordPointer);
-      // skip length
       final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + uaoSize;
+      final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize);
       final Object baseObject2 = memoryManager.getPage(r2.recordPointer);
-      // skip length
       final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + uaoSize;
-      return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2);
+      final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, baseOffset2 - uaoSize);
+      return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2,
+        baseOffset2, baseLength2);
     } else {
       return prefixComparisonResult;
     }
```
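For context, each record in a sorter page is length-prefixed: the pointer handed to the comparator references the payload, and the uaoSize-byte length field sits immediately before it, which is why the change above reads the length at `baseOffset - uaoSize`. A tiny self-contained sketch of that layout, using a plain `ByteBuffer` rather than Spark's memory manager:

```scala
import java.nio.ByteBuffer

// Assumed layout (illustrative): | uaoSize-byte length | payload bytes |
val uaoSize = 4                                   // assume 4-byte record headers
val page = ByteBuffer.allocate(16)
page.putInt(0, 5)                                 // record length stored in the header
(0 until 5).foreach(i => page.put(uaoSize + i, (i + 1).toByte))

val payloadOffset = uaoSize                       // the "baseOffset" given to the comparator
val recordLength = page.getInt(payloadOffset - uaoSize) // read the length just before the payload
assert(recordLength == 5)
```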

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java

Lines changed: 2 additions & 2 deletions
```diff
@@ -35,8 +35,8 @@ final class UnsafeSorterSpillMerger {
         prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix());
       if (prefixComparisonResult == 0) {
         return recordComparator.compare(
-          left.getBaseObject(), left.getBaseOffset(),
-          right.getBaseObject(), right.getBaseOffset());
+          left.getBaseObject(), left.getBaseOffset(), left.getRecordLength(),
+          right.getBaseObject(), right.getBaseOffset(), right.getRecordLength());
       } else {
         return prefixComparisonResult;
       }
```

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 2 additions & 0 deletions
```diff
@@ -414,6 +414,8 @@ abstract class RDD[T: ClassTag](
    *
    * If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
    * which can avoid performing a shuffle.
+   *
+   * TODO Fix the Shuffle+Repartition data loss issue described in SPARK-23207.
    */
   def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope {
     coalesce(numPartitions, shuffle = true)
```
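As the surrounding Scaladoc notes, decreasing the number of partitions is better served by `coalesce`, which avoids a shuffle, while `repartition` always shuffles; the TODO records that this RDD-level path is not changed by this patch. A quick usage illustration (assuming an active `sc`):

```scala
// repartition always shuffles; coalesce(n) with the default shuffle = false only merges
// partitions locally through a narrow dependency.
val rdd = sc.parallelize(1 to 1000, 100)
val fewer = rdd.coalesce(10)       // no shuffle
val more = rdd.repartition(200)    // equivalent to coalesce(200, shuffle = true)
assert(fewer.getNumPartitions == 10 && more.getNumPartitions == 200)
```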

core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java

Lines changed: 3 additions & 1 deletion
```diff
@@ -72,8 +72,10 @@ public class UnsafeExternalSorterSuite {
     public int compare(
         Object leftBaseObject,
         long leftBaseOffset,
+        int leftBaseLength,
         Object rightBaseObject,
-        long rightBaseOffset) {
+        long rightBaseOffset,
+        int rightBaseLength) {
       return 0;
     }
   };
```

core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java

Lines changed: 6 additions & 2 deletions
```diff
@@ -98,8 +98,10 @@ public void testSortingOnlyByIntegerPrefix() throws Exception {
       public int compare(
           Object leftBaseObject,
           long leftBaseOffset,
+          int leftBaseLength,
           Object rightBaseObject,
-          long rightBaseOffset) {
+          long rightBaseOffset,
+          int rightBaseLength) {
         return 0;
       }
     };
@@ -164,8 +166,10 @@ public void freeAfterOOM() {
       public int compare(
           Object leftBaseObject,
           long leftBaseOffset,
+          int leftBaseLength,
           Object rightBaseObject,
-          long rightBaseOffset) {
+          long rightBaseOffset,
+          int rightBaseLength) {
         return 0;
       }
     };
```

mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala

Lines changed: 2 additions & 1 deletion
```diff
@@ -222,7 +222,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     val oldModel = new OldWord2VecModel(word2VecMap)
     val instance = new Word2VecModel("myWord2VecModel", oldModel)
     val newInstance = testDefaultReadWrite(instance)
-    assert(newInstance.getVectors.collect() === instance.getVectors.collect())
+    assert(newInstance.getVectors.collect().sortBy(_.getString(0)) ===
+      instance.getVectors.collect().sortBy(_.getString(0)))
   }

   test("Word2Vec works with input that is non-nullable (NGram)")
```
RecordBinaryComparator.java (new file, package org.apache.spark.sql.execution)

Lines changed: 70 additions & 0 deletions
```java
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.util.collection.unsafe.sort.RecordComparator;

public final class RecordBinaryComparator extends RecordComparator {

  // TODO(jiangxb) Add test suite for this.
  @Override
  public int compare(
      Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) {
    int i = 0;
    int res = 0;

    // If the arrays have different length, the longer one is larger.
    if (leftLen != rightLen) {
      return leftLen - rightLen;
    }

    // The following logic uses `leftLen` as the length for both `leftObj` and `rightObj`, since
    // we have guaranteed `leftLen` == `rightLen`.

    // check if stars align and we can get both offsets to be aligned
    if ((leftOff % 8) == (rightOff % 8)) {
      while ((leftOff + i) % 8 != 0 && i < leftLen) {
        res = (Platform.getByte(leftObj, leftOff + i) & 0xff) -
            (Platform.getByte(rightObj, rightOff + i) & 0xff);
        if (res != 0) return res;
        i += 1;
      }
    }
    // for architectures that support unaligned accesses, chew it up 8 bytes at a time
    if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) {
      while (i <= leftLen - 8) {
        res = (int) ((Platform.getLong(leftObj, leftOff + i) -
            Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE);
        if (res != 0) return res;
        i += 8;
      }
    }
    // this will finish off the unaligned comparisons, or do the entire aligned comparison
    // whichever is needed.
    while (i < leftLen) {
      res = (Platform.getByte(leftObj, leftOff + i) & 0xff) -
          (Platform.getByte(rightObj, rightOff + i) & 0xff);
      if (res != 0) return res;
      i += 1;
    }

    // The two arrays are equal.
    return 0;
  }
}
```
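The class carries a TODO for a dedicated test suite. A minimal, hypothetical sanity check of its contract could look like the sketch below, assuming `Platform.BYTE_ARRAY_OFFSET` as the base offset of an on-heap byte array; it is not part of this patch:

```scala
import org.apache.spark.sql.execution.RecordBinaryComparator
import org.apache.spark.unsafe.Platform

val cmp = new RecordBinaryComparator

// Equal-length records are compared byte by byte.
val left = Array[Byte](1, 2, 3, 4)
val right = Array[Byte](1, 2, 3, 5)
assert(cmp.compare(
  left, Platform.BYTE_ARRAY_OFFSET, left.length,
  right, Platform.BYTE_ARRAY_OFFSET, right.length) < 0) // differs at the last byte

// Records of different lengths compare by length first.
val shorter = Array[Byte](1, 2, 3)
assert(cmp.compare(
  shorter, Platform.BYTE_ARRAY_OFFSET, shorter.length,
  left, Platform.BYTE_ARRAY_OFFSET, left.length) < 0)
```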

sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java

Lines changed: 38 additions & 6 deletions
```diff
@@ -18,7 +18,9 @@
 package org.apache.spark.sql.execution;

 import java.io.IOException;
+import java.util.function.Supplier;

+import org.apache.spark.sql.catalyst.util.TypeUtils;
 import scala.collection.AbstractIterator;
 import scala.collection.Iterator;
 import scala.math.Ordering;
@@ -56,26 +58,50 @@ public abstract static class PrefixComputer {

     public static class Prefix {
       /** Key prefix value, or the null prefix value if isNull = true. **/
-      long value;
+      public long value;

       /** Whether the key is null. */
-      boolean isNull;
+      public boolean isNull;
     }

     /**
      * Computes prefix for the given row. For efficiency, the returned object may be reused in
      * further calls to a given PrefixComputer.
      */
-    abstract Prefix computePrefix(InternalRow row);
+    public abstract Prefix computePrefix(InternalRow row);
   }

-  public UnsafeExternalRowSorter(
+  public static UnsafeExternalRowSorter createWithRecordComparator(
+      StructType schema,
+      Supplier<RecordComparator> recordComparatorSupplier,
+      PrefixComparator prefixComparator,
+      PrefixComputer prefixComputer,
+      long pageSizeBytes,
+      boolean canUseRadixSort) throws IOException {
+    return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator,
+      prefixComputer, pageSizeBytes, canUseRadixSort);
+  }
+
+  public static UnsafeExternalRowSorter create(
       StructType schema,
       Ordering<InternalRow> ordering,
       PrefixComparator prefixComparator,
       PrefixComputer prefixComputer,
       long pageSizeBytes,
       boolean canUseRadixSort) throws IOException {
+    Supplier<RecordComparator> recordComparatorSupplier =
+      () -> new RowComparator(ordering, schema.length());
+    return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator,
+      prefixComputer, pageSizeBytes, canUseRadixSort);
+  }
+
+  private UnsafeExternalRowSorter(
+      StructType schema,
+      Supplier<RecordComparator> recordComparatorSupplier,
+      PrefixComparator prefixComparator,
+      PrefixComputer prefixComputer,
+      long pageSizeBytes,
+      boolean canUseRadixSort) throws IOException {
     this.schema = schema;
     this.prefixComputer = prefixComputer;
     final SparkEnv sparkEnv = SparkEnv.get();
@@ -85,7 +111,7 @@ public UnsafeExternalRowSorter(
       sparkEnv.blockManager(),
       sparkEnv.serializerManager(),
       taskContext,
-      () -> new RowComparator(ordering, schema.length()),
+      recordComparatorSupplier,
       prefixComparator,
       sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize",
         DEFAULT_INITIAL_SORT_BUFFER_SIZE),
@@ -206,7 +232,13 @@ private static final class RowComparator extends RecordComparator {
     }

     @Override
-    public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) {
+    public int compare(
+        Object baseObj1,
+        long baseOff1,
+        int baseLen1,
+        Object baseObj2,
+        long baseOff2,
+        int baseLen2) {
       // Note that since ordering doesn't need the total length of the record, we just pass 0
       // into the row.
       row1.pointTo(baseObj1, baseOff1, 0);
```
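The caller that wires this into the shuffle path is not shown in this excerpt, but the new `createWithRecordComparator` factory is what allows sorting rows by their raw binary content instead of a logical `Ordering`. A hedged sketch of such a caller follows; the constant-zero prefix and `PrefixComparators.LONG` are assumptions for illustration, not the exact code in this commit:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.{RecordBinaryComparator, UnsafeExternalRowSorter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}

// Builds a sorter whose ordering is decided purely by RecordBinaryComparator: every row
// gets the same prefix, so ties always fall through to the binary record comparison.
def buildBinarySorter(schema: StructType, pageSizeBytes: Long): UnsafeExternalRowSorter = {
  val recordComparatorSupplier = new java.util.function.Supplier[RecordComparator] {
    override def get(): RecordComparator = new RecordBinaryComparator()
  }
  val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
    private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
    override def computePrefix(row: InternalRow): UnsafeExternalRowSorter.PrefixComputer.Prefix = {
      result.isNull = false
      result.value = 0L
      result
    }
  }
  UnsafeExternalRowSorter.createWithRecordComparator(
    schema, recordComparatorSupplier, PrefixComparators.LONG, prefixComputer,
    pageSizeBytes, /* canUseRadixSort = */ false)
}
```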

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 14 additions & 0 deletions
```diff
@@ -1145,6 +1145,18 @@ object SQLConf {
     .checkValues(PartitionOverwriteMode.values.map(_.toString))
     .createWithDefault(PartitionOverwriteMode.STATIC.toString)

+  val SORT_BEFORE_REPARTITION =
+    buildConf("spark.sql.execution.sortBeforeRepartition")
+      .internal()
+      .doc("When perform a repartition following a shuffle, the output row ordering would be " +
+        "nondeterministic. If some downstream stages fail and some tasks of the repartition " +
+        "stage retry, these tasks may generate different data, and that can lead to correctness " +
+        "issues. Turn on this config to insert a local sort before actually doing repartition " +
+        "to generate consistent repartition results. The performance of repartition() may go " +
+        "down since we insert extra local sort before it.")
+      .booleanConf
+      .createWithDefault(true)
+
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -1300,6 +1312,8 @@ class SQLConf extends Serializable with Logging {

   def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader)

+  def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION)
+
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
    * identifiers are equal.
```
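Since the config is internal but still settable per session, a job that prefers the old (faster, but retry-unsafe) behaviour can opt out at runtime. A usage illustration, assuming an active `spark` session:

```scala
// Disable the extra local sort before repartition (accepting the SPARK-23207 risk on
// task retry); the default is true, i.e. safe-by-default.
spark.conf.set("spark.sql.execution.sortBeforeRepartition", "false")
println(spark.conf.get("spark.sql.execution.sortBeforeRepartition"))
```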
