7 changes: 6 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1062,10 +1062,15 @@ class DataFrame private[sql](
    * @since 1.4.0
    */
   def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = {
+    // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
+    // constituent partitions each time a split is materialized, which could result in
+    // overlapping splits. To prevent this, we explicitly sort each input partition to make the
+    // ordering deterministic.
+    val sorted = Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan)
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
-      new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, logicalPlan))
+      new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted))
     }.toArray
   }
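The splitting arithmetic itself is unchanged by this patch: the weights are normalized to sum to 1, turned into a running total with scanLeft, and each adjacent pair of cumulative values becomes the [lower, upper) probability range handed to a Sample operator. A minimal sketch of that arithmetic in plain Scala (independent of Spark; SplitRanges is just an illustrative name), for weights Array(2, 3):

object SplitRanges extends App {
  val weights = Array(2.0, 3.0)
  val sum = weights.sum // 5.0
  // Normalize, then take the running total: Array(0.0, 0.4, 1.0)
  val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
  // Adjacent pairs become each split's [lower, upper) Sample bounds
  val ranges = normalizedCumWeights.sliding(2).map(x => (x(0), x(1))).toList
  println(ranges) // List((0.0,0.4), (0.4,1.0))
}

Because the ranges partition [0, 1), the splits are disjoint as long as every materialization presents the rows in the same order, which is exactly what the new per-partition Sort guarantees.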

22 changes: 22 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -62,6 +62,28 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("randomSplit on reordered partitions") {
+    // This test ensures that randomSplit does not create overlapping splits, even when the
+    // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering
+    // of rows in each partition.
+    val data =
+      sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
+    val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
+
+    assert(splits.length == 2, "wrong number of splits")
+
+    // Verify that the splits span the entire dataset
+    assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
+
+    // Verify that the splits don't overlap
+    assert(splits(0).intersect(splits(1)).collect().isEmpty)
+
+    // Verify that the results are deterministic across multiple runs
+    val firstRun = splits.toSeq.map(_.collect().toSeq)
+    val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
+    assert(firstRun == secondRun)
+  }
+
   test("pearson correlation") {
     val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
     val corr1 = df.stat.corr("a", "b", "pearson")
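To see why the pre-sort matters: each split is its own DataFrame, so each one re-materializes the underlying plan and re-samples it, and (as the comment in the patch implies) a row's membership in a split is decided by the random draw at its position within a partition. If a partition can yield its rows in a different order on each materialization, "the i-th row" names a different record each time, so one record can be claimed by two splits. A toy simulation of that failure mode in plain Scala (not Spark internals; sampleRange is a hypothetical stand-in that assumes one uniform draw per row in partition order):

import scala.util.Random

object OverlapDemo extends App {
  // Keep the rows whose positional draw falls in [lo, hi)
  def sampleRange(rows: Seq[Int], lo: Double, hi: Double, seed: Long): Seq[Int] = {
    val rng = new Random(seed)
    rows.filter { _ => val x = rng.nextDouble(); x >= lo && x < hi }
  }

  val rows = 1 to 100
  val reordered = new Random(42).shuffle(rows.toList) // a second, reordered materialization

  val splitA = sampleRange(rows, 0.0, 0.4, seed = 1)      // first split
  val splitB = sampleRange(reordered, 0.4, 1.0, seed = 1) // second split sees a new order
  // Membership depends on position, not identity, so the two "disjoint"
  // ranges can still claim the same records:
  println((splitA.toSet intersect splitB.toSet).size) // almost certainly > 0
}

With the patch applied, the per-partition sort pins down the row order, so the same seed produces the same positional draws over the same records on every materialization, which is what the test's overlap and determinism assertions check.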