From 27288a30ebe4cea98050803c1dfc55cf6275d162 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 6 Jan 2016 18:59:48 -0800 Subject: [PATCH 1/7] Fix DataFrame.randomSplit to avoid creating overlapping splits --- .../org/apache/spark/sql/DataFrame.scala | 9 ++- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 58 +++++++++++++++++++ 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 7cf2818590a78..4952d13f3ed0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1062,10 +1062,17 @@ class DataFrame private[sql]( * @since 1.4.0 */ def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = { + // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its + // constituent partitions each time a split is materialized which could result in + // overlapping splits. To prevent this, we explicitly sort each input partition to make the + // ordering deterministic. + val logicalPlanWithLocalSort = + Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan) val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, logicalPlan)) + new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, + logicalPlanWithLocalSort)) }.toArray } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 8932ce9503a3d..4eec93ae815be 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -129,6 +129,19 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } + test("SPARK-12662 fix DataFrame.randomSplit to avoid creating overlapping splits") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_12662.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + unusedJar.toString) + runSparkSubmit(args) + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { @@ -372,3 +385,48 @@ object SPARK_11009 extends QueryTest { } } } + +/** + * This object is used to test SPARK-12662: https://issues.apache.org/jira/browse/SPARK-12662. + * This test ensures that [[org.apache.spark.sql.DataFrame.randomSplit]] does not create overlapping + * splits even when the underlying dataframe doesn't guarantee a deterministic ordering of rows in + * each partition. 
+ */ +object SPARK_12662 extends QueryTest { + import org.apache.spark.sql.functions._ + + protected var sqlContext: SQLContext = _ + + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkContext = new SparkContext( + new SparkConf() + .set("spark.sql.shuffle.partitions", "100")) + + val hiveContext = new TestHiveContext(sparkContext) + sqlContext = hiveContext + + try { + val n = 600 + val data = sqlContext.range(n).toDF("id").repartition(200, col("id")) + val splits = data.randomSplit(Array[Double](1, 2, 3), seed = 1) + assert(splits.length == 3, "wrong number of splits") + + assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == + data.sort(col("id")).collect().toList, "incomplete or wrong split") + + for (id <- splits.indices) { + assert(splits(id).intersect(splits((id + 1) % splits.length)).collect().isEmpty, + s"split $id overlaps with split ${(id + 1) % splits.length}") + } + + val s = splits.map(_.count()) + assert(math.abs(s(0) - 100) < 50) // std = 9.13 + assert(math.abs(s(1) - 200) < 50) // std = 11.55 + assert(math.abs(s(2) - 300) < 50) // std = 12.25 + } finally { + sparkContext.stop() + } + } +} From 633683219022a3d7fa512bdbc3ea7ae6349fc7e1 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 6 Jan 2016 22:06:17 -0800 Subject: [PATCH 2/7] Reynold's comments --- .../apache/spark/sql/DataFrameStatSuite.scala | 26 +++++++++ .../spark/sql/hive/HiveSparkSubmitSuite.scala | 58 ------------------- 2 files changed, 26 insertions(+), 58 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index b15af42caa3ab..362c2f956c161 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -62,6 +62,32 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } } + test("randomSplit on reordered partitions") { + val n = 600 + // This test ensures that randomSplit does not create overlapping splits even when the + // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of + // rows in each partition. 
+ val data = + sparkContext.parallelize(1 to n, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id") + for (seed <- 1 to 5) { + val splits = data.randomSplit(Array[Double](1, 2, 3), seed) + assert(splits.length == 3, "wrong number of splits") + + assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == + data.sort($"id").collect().toList, "incomplete or wrong split") + + for (id <- splits.indices) { + assert(splits(id).intersect(splits((id + 1) % splits.length)).collect().isEmpty, + s"split $id overlaps with split ${(id + 1) % splits.length}") + } + + val s = splits.map(_.count()) + assert(math.abs(s(0) - 100) < 50) // std = 9.13 + assert(math.abs(s(1) - 200) < 50) // std = 11.55 + assert(math.abs(s(2) - 300) < 50) // std = 12.25 + } + } + test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") val corr1 = df.stat.corr("a", "b", "pearson") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 4eec93ae815be..8932ce9503a3d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -129,19 +129,6 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } - test("SPARK-12662 fix DataFrame.randomSplit to avoid creating overlapping splits") { - val unusedJar = TestUtils.createJarWithClasses(Seq.empty) - val args = Seq( - "--class", SPARK_12662.getClass.getName.stripSuffix("$"), - "--name", "SparkSQLConfTest", - "--master", "local-cluster[2,1,1024]", - "--conf", "spark.ui.enabled=false", - "--conf", "spark.master.rest.enabled=false", - "--driver-java-options", "-Dderby.system.durability=test", - unusedJar.toString) - runSparkSubmit(args) - } - // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { @@ -385,48 +372,3 @@ object SPARK_11009 extends QueryTest { } } } - -/** - * This object is used to test SPARK-12662: https://issues.apache.org/jira/browse/SPARK-12662. - * This test ensures that [[org.apache.spark.sql.DataFrame.randomSplit]] does not create overlapping - * splits even when the underlying dataframe doesn't guarantee a deterministic ordering of rows in - * each partition. 
- */ -object SPARK_12662 extends QueryTest { - import org.apache.spark.sql.functions._ - - protected var sqlContext: SQLContext = _ - - def main(args: Array[String]): Unit = { - Utils.configTestLog4j("INFO") - - val sparkContext = new SparkContext( - new SparkConf() - .set("spark.sql.shuffle.partitions", "100")) - - val hiveContext = new TestHiveContext(sparkContext) - sqlContext = hiveContext - - try { - val n = 600 - val data = sqlContext.range(n).toDF("id").repartition(200, col("id")) - val splits = data.randomSplit(Array[Double](1, 2, 3), seed = 1) - assert(splits.length == 3, "wrong number of splits") - - assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == - data.sort(col("id")).collect().toList, "incomplete or wrong split") - - for (id <- splits.indices) { - assert(splits(id).intersect(splits((id + 1) % splits.length)).collect().isEmpty, - s"split $id overlaps with split ${(id + 1) % splits.length}") - } - - val s = splits.map(_.count()) - assert(math.abs(s(0) - 100) < 50) // std = 9.13 - assert(math.abs(s(1) - 200) < 50) // std = 11.55 - assert(math.abs(s(2) - 300) < 50) // std = 12.25 - } finally { - sparkContext.stop() - } - } -} From 8c3293cc2d450ceb5d8713fe8d40e345b4bf00fc Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 6 Jan 2016 22:21:17 -0800 Subject: [PATCH 3/7] s/logicalPlanWithLocalSort/sorted --- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 4952d13f3ed0d..60d2f05b8605b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1066,13 +1066,11 @@ class DataFrame private[sql]( // constituent partitions each time a split is materialized which could result in // overlapping splits. To prevent this, we explicitly sort each input partition to make the // ordering deterministic. - val logicalPlanWithLocalSort = - Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan) + val sorted = Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan) val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, - logicalPlanWithLocalSort)) + new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)) }.toArray } From be9630fde2b1e427b5bb0e1a0c51c04cddc862f8 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 6 Jan 2016 22:23:29 -0800 Subject: [PATCH 4/7] test for single seed --- .../apache/spark/sql/DataFrameStatSuite.scala | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 362c2f956c161..cb7eed45b8ce6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -69,23 +69,21 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { // rows in each partition. 
val data = sparkContext.parallelize(1 to n, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id") - for (seed <- 1 to 5) { - val splits = data.randomSplit(Array[Double](1, 2, 3), seed) - assert(splits.length == 3, "wrong number of splits") + val splits = data.randomSplit(Array[Double](1, 2, 3), seed = 1) + assert(splits.length == 3, "wrong number of splits") - assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == - data.sort($"id").collect().toList, "incomplete or wrong split") - - for (id <- splits.indices) { - assert(splits(id).intersect(splits((id + 1) % splits.length)).collect().isEmpty, - s"split $id overlaps with split ${(id + 1) % splits.length}") - } + assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == + data.sort($"id").collect().toList, "incomplete or wrong split") - val s = splits.map(_.count()) - assert(math.abs(s(0) - 100) < 50) // std = 9.13 - assert(math.abs(s(1) - 200) < 50) // std = 11.55 - assert(math.abs(s(2) - 300) < 50) // std = 12.25 + for (id <- splits.indices) { + assert(splits(id).intersect(splits((id + 1) % splits.length)).collect().isEmpty, + s"split $id overlaps with split ${(id + 1) % splits.length}") } + + val s = splits.map(_.count()) + assert(math.abs(s(0) - 100) < 50) // std = 9.13 + assert(math.abs(s(1) - 200) < 50) // std = 11.55 + assert(math.abs(s(2) - 300) < 50) // std = 12.25 } test("pearson correlation") { From 56d9bd7f05d0d294826ae541e0725d92eeeb0b08 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 6 Jan 2016 22:48:34 -0800 Subject: [PATCH 5/7] single seed + check for deterministic results across multiple runs --- .../scala/org/apache/spark/sql/DataFrameStatSuite.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index cb7eed45b8ce6..df6db1e984705 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -80,10 +80,16 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { s"split $id overlaps with split ${(id + 1) % splits.length}") } + // Verify sample sizes val s = splits.map(_.count()) assert(math.abs(s(0) - 100) < 50) // std = 9.13 assert(math.abs(s(1) - 200) < 50) // std = 11.55 assert(math.abs(s(2) - 300) < 50) // std = 12.25 + + // Verify that the results are deterministic across multiple runs + val firstRun = splits.toSeq.map(_.collect().toSeq) + val secondRun = data.randomSplit(Array[Double](1, 2, 3), seed = 1).toSeq.map(_.collect().toSeq) + assert(firstRun == secondRun) } test("pearson correlation") { From 3af6a2db857e9140f0fc5683c3ce8fc6b5f73d64 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 6 Jan 2016 23:51:41 -0800 Subject: [PATCH 6/7] Remove size checks --- .../scala/org/apache/spark/sql/DataFrameStatSuite.scala | 6 ------ 1 file changed, 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index df6db1e984705..12f0ae6139bed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -80,12 +80,6 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { s"split $id overlaps with split ${(id + 1) % splits.length}") } - // Verify sample sizes - val s = splits.map(_.count()) 
- assert(math.abs(s(0) - 100) < 50) // std = 9.13 - assert(math.abs(s(1) - 200) < 50) // std = 11.55 - assert(math.abs(s(2) - 300) < 50) // std = 12.25 - // Verify that the results are deterministic across multiple runs val firstRun = splits.toSeq.map(_.collect().toSeq) val secondRun = data.randomSplit(Array[Double](1, 2, 3), seed = 1).toSeq.map(_.collect().toSeq) From 1b30119d4524bec42bec1b1aaece3f3ab691f0e4 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Thu, 7 Jan 2016 00:32:42 -0800 Subject: [PATCH 7/7] Simplify test --- .../apache/spark/sql/DataFrameStatSuite.scala | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 12f0ae6139bed..63ad6c439a870 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -63,26 +63,24 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("randomSplit on reordered partitions") { - val n = 600 // This test ensures that randomSplit does not create overlapping splits even when the // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of // rows in each partition. val data = - sparkContext.parallelize(1 to n, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id") - val splits = data.randomSplit(Array[Double](1, 2, 3), seed = 1) - assert(splits.length == 3, "wrong number of splits") + sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id") + val splits = data.randomSplit(Array[Double](2, 3), seed = 1) - assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == - data.sort($"id").collect().toList, "incomplete or wrong split") + assert(splits.length == 2, "wrong number of splits") - for (id <- splits.indices) { - assert(splits(id).intersect(splits((id + 1) % splits.length)).collect().isEmpty, - s"split $id overlaps with split ${(id + 1) % splits.length}") - } + // Verify that the splits span the entire dataset + assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) + + // Verify that the splits don't overalap + assert(splits(0).intersect(splits(1)).collect().isEmpty) // Verify that the results are deterministic across multiple runs val firstRun = splits.toSeq.map(_.collect().toSeq) - val secondRun = data.randomSplit(Array[Double](1, 2, 3), seed = 1).toSeq.map(_.collect().toSeq) + val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) assert(firstRun == secondRun) }
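
For context on the change in this series, below is a small, self-contained Scala sketch. It is not Spark code: partitions are modeled as plain sequences, the per-partition seeded draw stands in for Spark's Bernoulli range sampler, and names such as sampleRange and Partitioned are purely illustrative. It shows why sampling adjacent cumulative-weight ranges yields disjoint, exhaustive splits only when every materialization iterates each partition in the same order, which is the guarantee the local Sort added to DataFrame.randomSplit provides.

import scala.util.Random

object RandomSplitSketch {
  // Model of a partitioned dataset: each inner Seq is one partition.
  type Partitioned[T] = Seq[Seq[T]]

  // Keep the rows whose seeded uniform draw falls in [lo, hi). The draw for a
  // row depends only on the seed and the row's position within its partition,
  // mimicking a range sampler that consumes one random number per row.
  def sampleRange[T](data: Partitioned[T], lo: Double, hi: Double, seed: Long): Seq[T] =
    data.zipWithIndex.flatMap { case (partition, partitionId) =>
      val rng = new Random(seed + partitionId)
      partition.filter { _ =>
        val x = rng.nextDouble()
        x >= lo && x < hi
      }
    }

  // Split by normalized cumulative weights, mirroring DataFrame.randomSplit.
  def randomSplit[T](data: Partitioned[T], weights: Array[Double], seed: Long): Array[Seq[T]] = {
    val sum = weights.sum
    val cum = weights.map(_ / sum).scanLeft(0.0)(_ + _)
    cum.sliding(2).map { case Array(lo, hi) => sampleRange(data, lo, hi, seed) }.toArray
  }

  def main(args: Array[String]): Unit = {
    val base: Partitioned[Int] = Seq(1 to 300, 301 to 600).map(_.toSeq)

    // Non-deterministic ordering: every materialization sees a fresh shuffle of
    // each partition, so the same row can fall into the range of more than one
    // split (overlap) or into none (loss).
    def reshuffled: Partitioned[Int] = base.map(p => Random.shuffle(p))
    val left = sampleRange(reshuffled, 0.0, 1.0 / 3, seed = 1)
    val right = sampleRange(reshuffled, 1.0 / 3, 1.0, seed = 1)

    // Deterministic ordering: a local (per-partition) sort, as introduced by
    // this patch, makes every materialization see the same order, so the
    // weight ranges tile the data exactly.
    val locallySorted: Partitioned[Int] = base.map(_.sorted)
    val splits = randomSplit(locallySorted, Array[Double](1, 2), seed = 1)

    println(s"shuffled: overlap=${left.toSet.intersect(right.toSet).size}, " +
      s"covered=${(left ++ right).toSet.size} of 600")
    println(s"sorted:   overlap=${splits(0).toSet.intersect(splits(1).toSet).size}, " +
      s"covered=${(splits(0) ++ splits(1)).toSet.size} of 600")
  }
}

Run as a plain Scala program, the shuffled case typically reports a non-zero overlap and incomplete coverage, while the locally sorted case reports overlap=0 and covered=600 for any fixed seed, which is the invariant the new DataFrameStatSuite test checks against a deliberately reordered input.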