From 495cba55aee0223daef089fc8513962997468f77 Mon Sep 17 00:00:00 2001
From: Xingbo Jiang
Date: Fri, 26 Jan 2018 15:01:03 -0800
Subject: [PATCH 1/6] [SPARK-23207][SQL] Shuffle+Repartition on a DataFrame could lead to incorrect answers

Currently shuffle repartition uses RoundRobinPartitioning, so the generated result is nondeterministic because the ordering of the input rows is not determined. The bug can be triggered when a repartition call follows a shuffle (which leads to nondeterministic row ordering), as in the pattern below:

upstream stage -> repartition stage -> result stage (-> indicates a shuffle)

When one of the executor processes goes down, some tasks of the repartition stage are retried and generate an inconsistent row ordering, so some retried tasks of the result stage generate different data. The following code returns 931532 instead of 1000000:

```
import scala.sys.process._
import org.apache.spark.TaskContext

val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x => x }.repartition(200).map { x =>
  if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) {
    throw new Exception("pkill -f java".!!)
  }
  x
}
res.distinct().count()
```

In this PR, we propose the most straightforward fix: perform a local sort before partitioning. Once the input row ordering is deterministic, the function from rows to partitions is fully deterministic too.

The downside of this approach is that the extra local sort slows down repartition(), so we add a new config named `spark.sql.execution.sortBeforeRepartition` to control whether this patch is applied. The patch is enabled by default to be safe, but users may turn it off manually to avoid the performance regression.

This patch also changes the output row ordering of repartition(), which breaks a number of test cases that compare results directly.

Add a unit test in ExchangeSuite.

With this patch (and `spark.sql.execution.sortBeforeRepartition` set to true), the following query returns 1000000:

```
import scala.sys.process._
import org.apache.spark.TaskContext

spark.conf.set("spark.sql.execution.sortBeforeRepartition", "true")

val res = spark.range(0, 1000 * 1000, 1).repartition(200).map { x => x }.repartition(200).map { x =>
  if (TaskContext.get.attemptNumber == 0 && TaskContext.get.partitionId < 2) {
    throw new Exception("pkill -f java".!!)
  }
  x
}
res.distinct().count()

res7: Long = 1000000
```

Author: Xingbo Jiang

Closes #20393 from jiangxb1987/shuffle-repartition.
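The intuition is easy to check outside Spark with a minimal sketch (plain Scala, hypothetical helper names — a generic `Ordering` stands in for the binary `UnsafeRow` comparison the patch actually uses): round-robin assignment depends only on each row's position in its partition, so a deterministic pre-sort makes the row-to-partition mapping stable across retries:

```
// Hedged sketch, not the patch's implementation.
def roundRobin[T](rows: Seq[T], numPartitions: Int): Map[Int, Seq[T]] =
  rows.zipWithIndex
    .groupBy { case (_, i) => i % numPartitions } // partition id depends only on position
    .map { case (p, pairs) => p -> pairs.map(_._1) }

// Sorting first removes the dependence on arrival order, so a retried task
// reproduces the same row-to-partition mapping.
def deterministicRoundRobin[T: Ordering](rows: Seq[T], numPartitions: Int): Map[Int, Seq[T]] =
  roundRobin(rows.sorted, numPartitions)

// roundRobin(Seq("b", "a", "c"), 2) differs from roundRobin(Seq("a", "b", "c"), 2),
// while deterministicRoundRobin agrees on both inputs.
```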
--- .../unsafe/sort/RecordComparator.java | 4 +- .../unsafe/sort/UnsafeInMemorySorter.java | 7 +- .../unsafe/sort/UnsafeSorterSpillMerger.java | 4 +- .../main/scala/org/apache/spark/rdd/RDD.scala | 2 + .../sort/UnsafeExternalSorterSuite.java | 4 +- .../sort/UnsafeInMemorySorterSuite.java | 8 ++- .../spark/mllib/feature/ChiSqSelector.scala | 2 +- .../spark/ml/feature/Word2VecSuite.scala | 3 +- .../sql/execution/RecordBinaryComparator.java | 70 +++++++++++++++++++ .../execution/UnsafeExternalRowSorter.java | 46 ++++++++++-- .../apache/spark/sql/internal/SQLConf.scala | 14 ++++ .../sql/execution/UnsafeKVExternalSorter.java | 10 ++- .../apache/spark/sql/execution/SortExec.scala | 2 +- .../execution/exchange/ShuffleExchange.scala | 51 +++++++++++++- .../spark/sql/execution/ExchangeSuite.scala | 26 ++++++- .../datasources/parquet/ParquetIOSuite.scala | 6 +- .../streaming/ForeachSinkSuite.scala | 6 +- 17 files changed, 235 insertions(+), 30 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java index 09e4258792204..02b5de8e128c9 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/RecordComparator.java @@ -32,6 +32,8 @@ public abstract class RecordComparator { public abstract int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset); + long rightBaseOffset, + int rightBaseLength); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 869ec908be1fb..839b41db4082b 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -61,12 +61,13 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { int uaoSize = UnsafeAlignedOffset.getUaoSize(); if (prefixComparisonResult == 0) { final Object baseObject1 = memoryManager.getPage(r1.recordPointer); - // skip length final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + uaoSize; + final int baseLength1 = UnsafeAlignedOffset.getSize(baseObject1, baseOffset1 - uaoSize); final Object baseObject2 = memoryManager.getPage(r2.recordPointer); - // skip length final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + uaoSize; - return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2); + final int baseLength2 = UnsafeAlignedOffset.getSize(baseObject2, baseOffset2 - uaoSize); + return recordComparator.compare(baseObject1, baseOffset1, baseLength1, baseObject2, + baseOffset2, baseLength2); } else { return prefixComparisonResult; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index cf4dfde86ca91..ff0dcc259a4ad 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ 
-35,8 +35,8 @@ final class UnsafeSorterSpillMerger { prefixComparator.compare(left.getKeyPrefix(), right.getKeyPrefix()); if (prefixComparisonResult == 0) { return recordComparator.compare( - left.getBaseObject(), left.getBaseOffset(), - right.getBaseObject(), right.getBaseOffset()); + left.getBaseObject(), left.getBaseOffset(), left.getRecordLength(), + right.getBaseObject(), right.getBaseOffset(), right.getRecordLength()); } else { return prefixComparisonResult; } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 63a87e7f09d85..102836d5a5589 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -413,6 +413,8 @@ abstract class RDD[T: ClassTag]( * * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, * which can avoid performing a shuffle. + * + * TODO Fix the Shuffle+Repartition data loss issue described in SPARK-23207. */ def repartition(numPartitions: Int)(implicit ord: Ordering[T] = null): RDD[T] = withScope { coalesce(numPartitions, shuffle = true) diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 8d847daba2cdf..cce01a3275421 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -71,8 +71,10 @@ public class UnsafeExternalSorterSuite { public int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset) { + long rightBaseOffset, + int rightBaseLength) { return 0; } }; diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index 1a3e11efe9787..cfb00307fa883 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -98,8 +98,10 @@ public void testSortingOnlyByIntegerPrefix() throws Exception { public int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset) { + long rightBaseOffset, + int rightBaseLength) { return 0; } }; @@ -164,8 +166,10 @@ public void freeAfterOOM() { public int compare( Object leftBaseObject, long leftBaseOffset, + int leftBaseLength, Object rightBaseObject, - long rightBaseOffset) { + long rightBaseOffset, + int rightBaseLength) { return 0; } }; diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 862be6f37e7e3..015cc9f81e8c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -144,7 +144,7 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { val dataArray = Array.tabulate(model.selectedFeatures.length) { i => Data(model.selectedFeatures(i)) } - spark.createDataFrame(dataArray).repartition(1).write.parquet(Loader.dataPath(path)) + spark.createDataFrame(sc.makeRDD(dataArray, 1)).write.parquet(Loader.dataPath(path)) 
} def load(sc: SparkContext, path: String): ChiSqSelectorModel = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 6183606a7b2ac..10682ba176aca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -222,7 +222,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val oldModel = new OldWord2VecModel(word2VecMap) val instance = new Word2VecModel("myWord2VecModel", oldModel) val newInstance = testDefaultReadWrite(instance) - assert(newInstance.getVectors.collect() === instance.getVectors.collect()) + assert(newInstance.getVectors.collect().sortBy(_.getString(0)) === + instance.getVectors.collect().sortBy(_.getString(0))) } test("Word2Vec works with input that is non-nullable (NGram)") { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java new file mode 100644 index 0000000000000..bb77b5bf6de2a --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.util.collection.unsafe.sort.RecordComparator; + +public final class RecordBinaryComparator extends RecordComparator { + + // TODO(jiangxb) Add test suite for this. + @Override + public int compare( + Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) { + int i = 0; + int res = 0; + + // If the arrays have different length, the longer one is larger. + if (leftLen != rightLen) { + return leftLen - rightLen; + } + + // The following logic uses `leftLen` as the length for both `leftObj` and `rightObj`, since + // we have guaranteed `leftLen` == `rightLen`. 
+ + // check if stars align and we can get both offsets to be aligned + if ((leftOff % 8) == (rightOff % 8)) { + while ((leftOff + i) % 8 != 0 && i < leftLen) { + res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - + (Platform.getByte(rightObj, rightOff + i) & 0xff); + if (res != 0) return res; + i += 1; + } + } + // for architectures that support unaligned accesses, chew it up 8 bytes at a time + if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { + while (i <= leftLen - 8) { + res = (int) ((Platform.getLong(leftObj, leftOff + i) - + Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE); + if (res != 0) return res; + i += 8; + } + } + // this will finish off the unaligned comparisons, or do the entire aligned comparison + // whichever is needed. + while (i < leftLen) { + res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - + (Platform.getByte(rightObj, rightOff + i) & 0xff); + if (res != 0) return res; + i += 1; + } + + // The two arrays are equal. + return 0; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index c29b002a998ca..3ec17c1e61ca6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution; import java.io.IOException; +import java.util.function.Supplier; import scala.collection.Iterator; import scala.math.Ordering; @@ -55,26 +56,50 @@ public abstract static class PrefixComputer { public static class Prefix { /** Key prefix value, or the null prefix value if isNull = true. **/ - long value; + public long value; /** Whether the key is null. */ - boolean isNull; + public boolean isNull; } /** * Computes prefix for the given row. For efficiency, the returned object may be reused in * further calls to a given PrefixComputer. 
*/ - abstract Prefix computePrefix(InternalRow row); + public abstract Prefix computePrefix(InternalRow row); } - public UnsafeExternalRowSorter( + public static UnsafeExternalRowSorter createWithRecordComparator( + StructType schema, + Supplier<RecordComparator> recordComparatorSupplier, + PrefixComparator prefixComparator, + PrefixComputer prefixComputer, + long pageSizeBytes, + boolean canUseRadixSort) throws IOException { + return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, + prefixComputer, pageSizeBytes, canUseRadixSort); + } + + public static UnsafeExternalRowSorter create( StructType schema, Ordering<InternalRow> ordering, PrefixComparator prefixComparator, PrefixComputer prefixComputer, long pageSizeBytes, boolean canUseRadixSort) throws IOException { + Supplier<RecordComparator> recordComparatorSupplier = + () -> new RowComparator(ordering, schema.length()); + return new UnsafeExternalRowSorter(schema, recordComparatorSupplier, prefixComparator, + prefixComputer, pageSizeBytes, canUseRadixSort); + } + + private UnsafeExternalRowSorter( + StructType schema, + Supplier<RecordComparator> recordComparatorSupplier, + PrefixComparator prefixComparator, + PrefixComputer prefixComputer, + long pageSizeBytes, + boolean canUseRadixSort) throws IOException { this.schema = schema; this.prefixComputer = prefixComputer; final SparkEnv sparkEnv = SparkEnv.get(); @@ -84,7 +109,7 @@ public UnsafeExternalRowSorter( sparkEnv.blockManager(), sparkEnv.serializerManager(), taskContext, - new RowComparator(ordering, schema.length()), + recordComparatorSupplier.get(), prefixComparator, sparkEnv.conf().getInt("spark.shuffle.sort.initialBufferSize", DEFAULT_INITIAL_SORT_BUFFER_SIZE), @@ -207,8 +232,15 @@ private static final class RowComparator extends RecordComparator { } @Override - public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { - // TODO: Why are the sizes -1? + public int compare( + Object baseObj1, + long baseOff1, + int baseLen1, + Object baseObj2, + long baseOff2, + int baseLen2) { + // Note that since ordering doesn't need the total length of the record, we just pass 0 + // into the row. row1.pointTo(baseObj1, baseOff1, -1); row2.pointTo(baseObj2, baseOff2, -1); return ordering.compare(row1, row2); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ebabd1a1396b4..9db5acd7262a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -829,6 +829,18 @@ object SQLConf { .regexConf .createWithDefault("(?i)url".r) + val SORT_BEFORE_REPARTITION = + buildConf("spark.sql.execution.sortBeforeRepartition") + .internal() + .doc("When performing a repartition following a shuffle, the output row ordering would be " + + "nondeterministic. If some downstream stages fail and some tasks of the repartition " + + "stage retry, these tasks may generate different data, and that can lead to correctness " + + "issues. Turn on this config to insert a local sort before actually doing repartition " + + "to generate consistent repartition results. 
The performance of repartition() may go " + + "down since we insert extra local sort before it.") + .booleanConf + .createWithDefault(true) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -961,6 +973,8 @@ class SQLConf extends Serializable with Logging { def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) + def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 7d67b87ed915d..7549decf57488 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -252,8 +252,14 @@ private static final class KVComparator extends RecordComparator { } @Override - public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { - // Note that since ordering doesn't need the total length of the record, we just pass -1 + public int compare( + Object baseObj1, + long baseOff1, + int baseLen1, + Object baseObj2, + long baseOff2, + int baseLen2) { + // Note that since ordering doesn't need the total length of the record, we just pass -1 // into the row. row1.pointTo(baseObj1, baseOff1 + 4, -1); row2.pointTo(baseObj2, baseOff2 + 4, -1); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index f98ae82574d20..d225979a33299 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -84,7 +84,7 @@ case class SortExec( } val pageSize = SparkEnv.get.memoryManager.pageSizeBytes - val sorter = new UnsafeExternalRowSorter( + val sorter = UnsafeExternalRowSorter.create( schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort) if (testSpillFrequency > 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index eebe6ad2e7944..b253e51a30e4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.exchange import java.util.Random +import java.util.function.Supplier import org.apache.spark._ import org.apache.spark.rdd.RDD @@ -30,7 +31,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType import org.apache.spark.util.MutablePair +import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} /** * Performs a shuffle that will result in the desired `newPartitioning`. 
@@ -242,14 +246,57 @@ object ShuffleExchange { case RangePartitioning(_, _) | SinglePartition => identity case _ => sys.error(s"Exchange not implemented for $newPartitioning") } + val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = { - if (needToCopyObjectsBeforeShuffle(part, serializer)) { + // [SPARK-23207] Have to make sure the generated RoundRobinPartitioning is deterministic, + // otherwise a retried task may output different rows and thus lead to data loss. + // + // Currently we follow the most straightforward approach: perform a local sort before + // partitioning. + val newRdd = if (SparkEnv.get.conf.get(SQLConf.SORT_BEFORE_REPARTITION) && + newPartitioning.isInstanceOf[RoundRobinPartitioning]) { rdd.mapPartitionsInternal { iter => + val recordComparatorSupplier = new Supplier[RecordComparator] { + override def get: RecordComparator = new RecordBinaryComparator() + } + // The comparator for comparing row hashcodes, which are always integers. + val prefixComparator = PrefixComparators.LONG + val canUseRadixSort = SparkEnv.get.conf.get(SQLConf.RADIX_SORT_ENABLED) + // The prefix computer generates the row hashcode as the prefix, which decreases the + // probability of equal prefixes when input rows choose column values from a + // limited range. + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix + override def computePrefix(row: InternalRow): + UnsafeExternalRowSorter.PrefixComputer.Prefix = { + // The hashcode generated from the binary form of an [[UnsafeRow]] should not be null. + result.isNull = false + result.value = row.hashCode() + result + } + } + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes + + val sorter = UnsafeExternalRowSorter.createWithRecordComparator( + StructType.fromAttributes(outputAttributes), + recordComparatorSupplier, + prefixComparator, + prefixComputer, + pageSize, + canUseRadixSort) + sorter.sort(iter.asInstanceOf[Iterator[UnsafeRow]]) + } + } else { + rdd + } + + if (needToCopyObjectsBeforeShuffle(part, serializer)) { + newRdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) } } } else { - rdd.mapPartitionsInternal { iter => + newRdd.mapPartitionsInternal { iter => val getPartitionKey = getPartitionKeyExtractor() val mutablePair = new MutablePair[Int, InternalRow]() iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 59eaf4d1c29b7..abd3e6ca37840 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import scala.util.Random + +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class ExchangeSuite
extends SparkPlanTest with SharedSQLContext { @@ -101,4 +104,25 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(exchange4.sameResult(exchange5)) assert(exchange5 sameResult exchange4) } + + test("SPARK-23207: Make repartition() generate consistent output") { + def assertConsistency(ds: Dataset[java.lang.Long]): Unit = { + ds.persist() + + val exchange = ds.mapPartitions { iter => + Random.shuffle(iter) + }.repartition(111) + val exchange2 = ds.repartition(111) + + assert(exchange.rdd.collectPartitions() === exchange2.rdd.collectPartitions()) + } + + withSQLConf(SQLConf.SORT_BEFORE_REPARTITION.key -> "true") { + // repartition() should generate consistent output. + assertConsistency(spark.range(10000)) + + // case when input contains duplicated rows. + assertConsistency(spark.range(10000).map(i => Random.nextInt(1000).toLong)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 94a2f9a00b3f3..34f00aaea4093 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -661,7 +661,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val v = (row.getInt(0), row.getString(1)) result += v } - assert(data == result) + assert(data.toSet == result.toSet) } finally { reader.close() } @@ -677,7 +677,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val row = reader.getCurrentValue.asInstanceOf[InternalRow] result += row.getString(0) } - assert(data.map(_._2) == result) + assert(data.map(_._2).toSet == result.toSet) } finally { reader.close() } @@ -694,7 +694,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { val v = (row.getString(0), row.getInt(1)) result += v } - assert(data.map { x => (x._2, x._1) } == result) + assert(data.map { x => (x._2, x._1) }.toSet == result.toSet) } finally { reader.close() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 9137d650e906b..1248c670df45c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -52,13 +52,13 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf var expectedEventsForPartition0 = Seq( ForeachSinkSuite.Open(partition = 0, version = 0), - ForeachSinkSuite.Process(value = 1), + ForeachSinkSuite.Process(value = 2), ForeachSinkSuite.Process(value = 3), ForeachSinkSuite.Close(None) ) var expectedEventsForPartition1 = Seq( ForeachSinkSuite.Open(partition = 1, version = 0), - ForeachSinkSuite.Process(value = 2), + ForeachSinkSuite.Process(value = 1), ForeachSinkSuite.Process(value = 4), ForeachSinkSuite.Close(None) ) @@ -162,7 +162,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf val allEvents = ForeachSinkSuite.allEvents() assert(allEvents.size === 1) assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) + assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 2)) 
// `close` should be called with the error val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close]

From 8d2d5585b2c2832cd4d88b3851607ce15180cca5 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Sun, 12 Aug 2018 17:03:10 -0700
Subject: [PATCH 2/6] Fix comment

--- .../org/apache/spark/sql/execution/UnsafeExternalRowSorter.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 3ec17c1e61ca6..54cec601f614e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -239,7 +239,7 @@ public int compare( Object baseObj2, long baseOff2, int baseLen2) { - // Note that since ordering doesn't need the total length of the record, we just pass 0 + // Note that since ordering doesn't need the total length of the record, we just pass -1 // into the row. row1.pointTo(baseObj1, baseOff1, -1); row2.pointTo(baseObj2, baseOff2, -1);

From 81f57febfa5f81cd41ac7803eb9d8931df88fc5c Mon Sep 17 00:00:00 2001
From: Xingbo Jiang
Date: Tue, 30 Jan 2018 11:40:42 +0800
Subject: [PATCH 3/6] [SPARK-23207][SQL][FOLLOW-UP] Don't perform local sort for DataFrame.repartition(1)

In `ShuffleExchangeExec`, we don't need to insert an extra local sort before round-robin partitioning if the new partitioning has only 1 partition, because in that case all output rows go to the same partition.

Tested with the existing test cases.

Author: Xingbo Jiang

Closes #20426 from jiangxb1987/repartition1.

--- .../apache/spark/sql/execution/exchange/ShuffleExchange.scala | 4 ++++ .../spark/sql/execution/streaming/ForeachSinkSuite.scala | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index b253e51a30e4f..c0ba5135e5ed8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -253,7 +253,11 @@ object ShuffleExchange { // // Currently we follow the most straightforward approach: perform a local sort before // partitioning. + // + // Note that we don't perform a local sort if the new partitioning has only 1 partition, + // because in that case all output rows go to the same partition.
val newRdd = if (SparkEnv.get.conf.get(SQLConf.SORT_BEFORE_REPARTITION) && + newPartitioning.numPartitions > 1 && newPartitioning.isInstanceOf[RoundRobinPartitioning]) { rdd.mapPartitionsInternal { iter => val recordComparatorSupplier = new Supplier[RecordComparator] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala index 1248c670df45c..41434e6d8b974 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala @@ -162,7 +162,7 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf val allEvents = ForeachSinkSuite.allEvents() assert(allEvents.size === 1) assert(allEvents(0)(0) === ForeachSinkSuite.Open(partition = 0, version = 0)) - assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 2)) + assert(allEvents(0)(1) === ForeachSinkSuite.Process(value = 1)) // `close` should be called with the error val errorEvent = allEvents(0)(2).asInstanceOf[ForeachSinkSuite.Close] From bab8e68bc292a3a71ee378a839fd540dcf0a72bd Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Fri, 29 Dec 2017 10:08:03 -0800 Subject: [PATCH 4/6] [SPARK-22905][ML][FOLLOWUP] Fix GaussianMixtureModel save ## What changes were proposed in this pull request? make sure model data is stored in order. WeichenXu123 ## How was this patch tested? existing tests Author: Zheng RuiFeng Closes #20113 from zhengruifeng/gmm_save. --- .../apache/spark/mllib/clustering/GaussianMixtureModel.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index afbe4f978b286..1933d5499c3bf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -154,7 +154,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { val dataArray = Array.tabulate(weights.length) { i => Data(weights(i), gaussians(i).mu, gaussians(i).sigma) } - spark.createDataFrame(dataArray).repartition(1).write.parquet(Loader.dataPath(path)) + spark.createDataFrame(sc.makeRDD(dataArray, 1)).write.parquet(Loader.dataPath(path)) } def load(sc: SparkContext, path: String): GaussianMixtureModel = { From 754a454764b3e04c563846615dd42d51355089db Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 28 Jun 2018 14:19:50 +0800 Subject: [PATCH 5/6] [SPARK-24564][TEST] Add test suite for RecordBinaryComparator ## What changes were proposed in this pull request? Add a new test suite to test RecordBinaryComparator. ## How was this patch tested? New test suite. Author: Xingbo Jiang Closes #21570 from jiangxb1987/rbc-test. 
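The ordering contract the new suite exercises can be modeled in a few lines (a hedged sketch over byte arrays with a hypothetical `compareRecords` helper — the real comparator reads raw memory through `Platform` and chews through eight bytes at a time where alignment allows): records of different lengths order by length, and equal-length records order by their bytes compared as unsigned values:

```
// Simplified model of RecordBinaryComparator's ordering, not the class under test.
def compareRecords(left: Array[Byte], right: Array[Byte]): Int = {
  // Records of different lengths: the longer one is considered larger.
  if (left.length != right.length) return left.length - right.length
  var i = 0
  while (i < left.length) {
    val v1 = left(i) & 0xff // compare as unsigned bytes
    val v2 = right(i) & 0xff
    if (v1 != v2) return if (v1 > v2) 1 else -1
    i += 1
  }
  0 // binary-equal records
}
```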
--- .../spark/memory/TestMemoryConsumer.java | 10 + .../sort/RecordBinaryComparatorSuite.java | 256 ++++++++++++++++++ 2 files changed, 266 insertions(+) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java diff --git a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java index db91329c94cb6..0bbaea6b834b8 100644 --- a/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java +++ b/core/src/test/java/org/apache/spark/memory/TestMemoryConsumer.java @@ -17,6 +17,10 @@ package org.apache.spark.memory; +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.unsafe.memory.MemoryBlock; + import java.io.IOException; public class TestMemoryConsumer extends MemoryConsumer { @@ -43,6 +47,12 @@ void free(long size) { used -= size; taskMemoryManager.releaseExecutionMemory(size, this); } + + @VisibleForTesting + public void freePage(MemoryBlock page) { + used -= page.size(); + taskMemoryManager.freePage(page, this); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java new file mode 100644 index 0000000000000..a19ddbdbadba2 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.execution.sort; + +import org.apache.spark.SparkConf; +import org.apache.spark.memory.TaskMemoryManager; +import org.apache.spark.memory.TestMemoryConsumer; +import org.apache.spark.memory.TestMemoryManager; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.execution.RecordBinaryComparator; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.UnsafeAlignedOffset; +import org.apache.spark.unsafe.array.LongArray; +import org.apache.spark.unsafe.memory.MemoryBlock; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.util.collection.unsafe.sort.*; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Test the RecordBinaryComparator, which compares two UnsafeRows by their binary form. 
+ */ +public class RecordBinaryComparatorSuite { + + private final TaskMemoryManager memoryManager = new TaskMemoryManager( + new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")), 0); + private final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + + private final int uaoSize = UnsafeAlignedOffset.getUaoSize(); + + private MemoryBlock dataPage; + private long pageCursor; + + private LongArray array; + private int pos; + + @Before + public void beforeEach() { + // Only compare between two input rows. + array = consumer.allocateArray(2); + pos = 0; + + dataPage = memoryManager.allocatePage(4096, consumer); + pageCursor = dataPage.getBaseOffset(); + } + + @After + public void afterEach() { + consumer.freePage(dataPage); + dataPage = null; + pageCursor = 0; + + consumer.freeArray(array); + array = null; + pos = 0; + } + + private void insertRow(UnsafeRow row) { + Object recordBase = row.getBaseObject(); + long recordOffset = row.getBaseOffset(); + int recordLength = row.getSizeInBytes(); + + Object baseObject = dataPage.getBaseObject(); + assert(pageCursor + recordLength <= dataPage.getBaseOffset() + dataPage.size()); + long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, pageCursor); + UnsafeAlignedOffset.putSize(baseObject, pageCursor, recordLength); + pageCursor += uaoSize; + Platform.copyMemory(recordBase, recordOffset, baseObject, pageCursor, recordLength); + pageCursor += recordLength; + + assert(pos < 2); + array.set(pos, recordAddress); + pos++; + } + + private int compare(int index1, int index2) { + Object baseObject = dataPage.getBaseObject(); + + long recordAddress1 = array.get(index1); + long baseOffset1 = memoryManager.getOffsetInPage(recordAddress1) + uaoSize; + int recordLength1 = UnsafeAlignedOffset.getSize(baseObject, baseOffset1 - uaoSize); + + long recordAddress2 = array.get(index2); + long baseOffset2 = memoryManager.getOffsetInPage(recordAddress2) + uaoSize; + int recordLength2 = UnsafeAlignedOffset.getSize(baseObject, baseOffset2 - uaoSize); + + return binaryComparator.compare(baseObject, baseOffset1, recordLength1, baseObject, + baseOffset2, recordLength2); + } + + private final RecordComparator binaryComparator = new RecordBinaryComparator(); + + // Compute the most compact size for UnsafeRow's backing data. + private int computeSizeInBytes(int originalSize) { + // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall + // always be 8. + return 8 + (originalSize + 7) / 8 * 8; + } + + // Compute the relative offset of variable-length values. + private long relativeOffset(int numFields) { + // All the UnsafeRows in this suite contains less than 64 columns, so the bitSetSize shall + // always be 8. 
+ return 8 + numFields * 8L; + } + + @Test + public void testBinaryComparatorForSingleColumnRow() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setInt(0, 11); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setInt(0, 42); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorForMultipleColumnRow() throws Exception { + int numFields = 5; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row1.setDouble(i, i * 3.14); + } + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row2.setDouble(i, 198.7 / (i + 1)); + } + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorForArrayColumn() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + UnsafeArrayData arrayData1 = UnsafeArrayData.fromPrimitiveArray(new int[]{11, 42, -1}); + row1.pointTo(data1, computeSizeInBytes(numFields * 8 + arrayData1.getSizeInBytes())); + row1.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData1.getSizeInBytes()); + Platform.copyMemory(arrayData1.getBaseObject(), arrayData1.getBaseOffset(), data1, + row1.getBaseOffset() + relativeOffset(numFields), arrayData1.getSizeInBytes()); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + UnsafeArrayData arrayData2 = UnsafeArrayData.fromPrimitiveArray(new int[]{22}); + row2.pointTo(data2, computeSizeInBytes(numFields * 8 + arrayData2.getSizeInBytes())); + row2.setLong(0, (relativeOffset(numFields) << 32) | (long) arrayData2.getSizeInBytes()); + Platform.copyMemory(arrayData2.getBaseObject(), arrayData2.getBaseOffset(), data2, + row2.getBaseOffset() + relativeOffset(numFields), arrayData2.getSizeInBytes()); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorForMixedColumns() throws Exception { + int numFields = 4; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + UTF8String str1 = UTF8String.fromString("Milk tea"); + row1.pointTo(data1, computeSizeInBytes(numFields * 8 + str1.numBytes())); + row1.setInt(0, 11); + row1.setDouble(1, 3.14); + row1.setInt(2, -1); + row1.setLong(3, (relativeOffset(numFields) << 32) | (long) str1.numBytes()); + Platform.copyMemory(str1.getBaseObject(), str1.getBaseOffset(), data1, + row1.getBaseOffset() + relativeOffset(numFields), str1.numBytes()); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + UTF8String str2 = UTF8String.fromString("Java"); + row2.pointTo(data2, computeSizeInBytes(numFields * 8 + str2.numBytes())); + row2.setInt(0, 11); + row2.setDouble(1, 3.14); + row2.setInt(2, -1); + row2.setLong(3, (relativeOffset(numFields) << 32) | (long) str2.numBytes()); + Platform.copyMemory(str2.getBaseObject(), str2.getBaseOffset(), data2, + row2.getBaseOffset() + relativeOffset(numFields), str2.numBytes()); + + 
insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } + + @Test + public void testBinaryComparatorForNullColumns() throws Exception { + int numFields = 3; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields; i++) { + row1.setNullAt(i); + } + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + for (int i = 0; i < numFields - 1; i++) { + row2.setNullAt(i); + } + row2.setDouble(numFields - 1, 3.14); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 0) == 0); + assert(compare(0, 1) > 0); + } +}

From 2edad85c39f0fbfea8c8361e9ee420abc6fe4202 Mon Sep 17 00:00:00 2001
From: Xingbo Jiang
Date: Mon, 20 Aug 2018 23:13:31 -0700
Subject: [PATCH 6/6] [SPARK-25114][CORE] Fix RecordBinaryComparator when subtraction between two words is divisible by Integer.MAX_VALUE.

## What changes were proposed in this pull request?

https://github.com/apache/spark/pull/22079#discussion_r209705612

With this code, it is possible for two records to be unequal and yet compare as equal, if the difference between the two long words is a multiple of Integer.MAX_VALUE; the subtraction can also overflow a long. This PR fixes the issue.

## How was this patch tested?

Add new test cases in `RecordBinaryComparatorSuite`.

Closes #22101 from jiangxb1987/fix-rbc.

Authored-by: Xingbo Jiang
Signed-off-by: Xiao Li

--- .../sql/execution/RecordBinaryComparator.java | 26 ++++---- .../sort/RecordBinaryComparatorSuite.java | 66 +++++++++++++++++++ 2 files changed, 81 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java index bb77b5bf6de2a..40c2cc806e87a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/RecordBinaryComparator.java @@ -22,12 +22,10 @@ public final class RecordBinaryComparator extends RecordComparator { - // TODO(jiangxb) Add test suite for this. @Override public int compare( Object leftObj, long leftOff, int leftLen, Object rightObj, long rightOff, int rightLen) { int i = 0; - int res = 0; // If the arrays have different length, the longer one is larger. if (leftLen != rightLen) { @@ -40,27 +38,33 @@ public int compare( // check if stars align and we can get both offsets to be aligned if ((leftOff % 8) == (rightOff % 8)) { while ((leftOff + i) % 8 != 0 && i < leftLen) { - res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return res; + final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; + final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 1; } } // for architectures that support unaligned accesses, chew it up 8 bytes at a time if (Platform.unaligned() || (((leftOff + i) % 8 == 0) && ((rightOff + i) % 8 == 0))) { while (i <= leftLen - 8) { - res = (int) ((Platform.getLong(leftObj, leftOff + i) - - Platform.getLong(rightObj, rightOff + i)) % Integer.MAX_VALUE); - if (res != 0) return res; + final long v1 = Platform.getLong(leftObj, leftOff + i); + final long v2 = Platform.getLong(rightObj, rightOff + i); + if (v1 != v2) { + return v1 > v2 ?
1 : -1; + } i += 8; } } // this will finish off the unaligned comparisons, or do the entire aligned comparison // whichever is needed. while (i < leftLen) { - res = (Platform.getByte(leftObj, leftOff + i) & 0xff) - - (Platform.getByte(rightObj, rightOff + i) & 0xff); - if (res != 0) return res; + final int v1 = Platform.getByte(leftObj, leftOff + i) & 0xff; + final int v2 = Platform.getByte(rightObj, rightOff + i) & 0xff; + if (v1 != v2) { + return v1 > v2 ? 1 : -1; + } i += 1; } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java index a19ddbdbadba2..97f3dc588ecc5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java @@ -253,4 +253,70 @@ public void testBinaryComparatorForNullColumns() throws Exception { assert(compare(0, 0) == 0); assert(compare(0, 1) > 0); } + + @Test + public void testBinaryComparatorWhenSubtractionIsDivisibleByMaxIntValue() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setLong(0, 11); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setLong(0, 11L + Integer.MAX_VALUE); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorWhenSubtractionCanOverflowLongValue() throws Exception { + int numFields = 1; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setLong(0, Long.MIN_VALUE); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setLong(0, 1); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } + + @Test + public void testBinaryComparatorWhenOnlyTheLastColumnDiffers() throws Exception { + int numFields = 4; + + UnsafeRow row1 = new UnsafeRow(numFields); + byte[] data1 = new byte[100]; + row1.pointTo(data1, computeSizeInBytes(numFields * 8)); + row1.setInt(0, 11); + row1.setDouble(1, 3.14); + row1.setInt(2, -1); + row1.setLong(3, 0); + + UnsafeRow row2 = new UnsafeRow(numFields); + byte[] data2 = new byte[100]; + row2.pointTo(data2, computeSizeInBytes(numFields * 8)); + row2.setInt(0, 11); + row2.setDouble(1, 3.14); + row2.setInt(2, -1); + row2.setLong(3, 1); + + insertRow(row1); + insertRow(row2); + + assert(compare(0, 1) < 0); + } }
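The arithmetic behind this last fix is easy to verify in isolation. The sketch below (plain Scala, mirroring the values used by the new test cases; `cmp` is a hypothetical stand-in for the fixed comparison of two 8-byte words) demonstrates both failure modes of the old `(v1 - v2) % Integer.MAX_VALUE` logic:

```
object Spark25114Demo extends App {
  // Failure mode 1: a difference divisible by Integer.MAX_VALUE reduces to 0,
  // so two unequal words compared as equal under the old logic.
  val a = 11L
  val b = 11L + Integer.MAX_VALUE
  assert(((a - b) % Integer.MAX_VALUE).toInt == 0)

  // Failure mode 2: the subtraction itself can overflow and flip the sign.
  val lo = Long.MinValue
  val hi = 1L
  assert(lo - hi > 0) // wraps around to a positive value

  // The fixed comparator compares the words directly, with no arithmetic.
  def cmp(v1: Long, v2: Long): Int = if (v1 == v2) 0 else if (v1 > v2) 1 else -1
  assert(cmp(a, b) < 0)
  assert(cmp(lo, hi) < 0)
}
```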