From ad483fd0ad37e0df5b2dbd125af706454c7fbeb6 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Thu, 13 Mar 2014 14:24:38 -0400 Subject: [PATCH 1/5] handle the case with empty RDD when take sample --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 9 ++++++++- core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 4 ++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index b50c9963b9d2c..5e7bb4531aef4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -310,6 +310,9 @@ abstract class RDD[T: ClassTag]( * Return a sampled subset of this RDD. */ def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = { + if (fraction < Double.MinValue || fraction > Double.MaxValue) { + throw new Exception("Invalid fraction value:" + fraction) + } if (withReplacement) { new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed) } else { @@ -344,6 +347,10 @@ abstract class RDD[T: ClassTag]( throw new IllegalArgumentException("Negative number of elements requested") } + if (initialCount == 0) { + return new Array[T](0) + } + if (initialCount > Integer.MAX_VALUE - 1) { maxSelected = Integer.MAX_VALUE - 1 } else { @@ -362,7 +369,7 @@ abstract class RDD[T: ClassTag]( var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() // If the first sample didn't turn out large enough, keep trying to take samples; - // this shouldn't happen often because we use a big multiplier for thei initial size + // this shouldn't happen often because we use a big multiplier for the initial size while (samples.length < total) { samples = this.sample(withReplacement, fraction, rand.nextInt()).collect() } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 60bcada55245b..7124cfc742e79 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -457,6 +457,10 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("takeSample") { val data = sc.parallelize(1 to 100, 2) + val emptySet = data.filter(_ => false) + + val sample = emptySet.takeSample(false, 20, 1) + assert(sample.size === 0) for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements From a40e8fb632fb48951cace8f281cd3f30ded017e1 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Thu, 13 Mar 2014 15:04:02 -0400 Subject: [PATCH 2/5] replace if with require --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 5e7bb4531aef4..36d7f6c428987 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -310,9 +310,8 @@ abstract class RDD[T: ClassTag]( * Return a sampled subset of this RDD. */ def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = { - if (fraction < Double.MinValue || fraction > Double.MaxValue) { - throw new Exception("Invalid fraction value:" + fraction) - } + require(fraction >= 0 && fraction <= Double.MaxValue, + "Invalid fraction value: " + fraction) if (withReplacement) { new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed) } else { From 810948d49b7aa54adff57ca91766e0cc048a1699 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Thu, 13 Mar 2014 15:25:09 -0400 Subject: [PATCH 3/5] further fix --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 3 +-- core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 36d7f6c428987..f8283fbbb980d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -310,8 +310,7 @@ abstract class RDD[T: ClassTag]( * Return a sampled subset of this RDD. */ def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = { - require(fraction >= 0 && fraction <= Double.MaxValue, - "Invalid fraction value: " + fraction) + require(fraction >= 0.0, "Invalid fraction value: " + fraction) if (withReplacement) { new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed) } else { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 7124cfc742e79..ddf2214c397dd 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -457,10 +457,10 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("takeSample") { val data = sc.parallelize(1 to 100, 2) - val emptySet = data.filter(_ => false) + val emptySet = data.mapPartitions { iter => Iterator.empty } val sample = emptySet.takeSample(false, 20, 1) - assert(sample.size === 0) + assert(sample.length === 0) for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements From 36db06b0b6f98e6809725db2a667ca4ad3971315 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Thu, 13 Mar 2014 16:45:59 -0400 Subject: [PATCH 4/5] create new test cases for takeSample from an empty red --- core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index ddf2214c397dd..9512e0e6eeb14 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -457,10 +457,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("takeSample") { val data = sc.parallelize(1 to 100, 2) - val emptySet = data.mapPartitions { iter => Iterator.empty } - val sample = emptySet.takeSample(false, 20, 1) - assert(sample.length === 0) for (seed <- 1 to 5) { val sample = data.takeSample(withReplacement=false, 20, seed) assert(sample.size === 20) // Got exactly 20 elements @@ -492,6 +489,12 @@ class RDDSuite extends FunSuite with SharedSparkContext { } } + test("takeSample from an empty rdd") { + val emptySet = sc.parallelize(Seq.empty[Int], 2) + val sample = emptySet.takeSample(false, 20, 1) + assert(sample.length === 0) + } + test("randomSplit") { val n = 600 val data = sc.parallelize(1 to n, 2) From fef57d4dcc9af47c4604a168996b4505b12b7050 Mon Sep 17 00:00:00 2001 From: CodingCat Date: Fri, 14 Mar 2014 18:14:34 -0400 Subject: [PATCH 5/5] fix the same problem in PySpark --- python/pyspark/rdd.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6d549b40e5698..f3b432ff248a9 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -268,6 +268,7 @@ def sample(self, withReplacement, fraction, seed): >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98] """ + assert fraction >= 0.0, "Invalid fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) # this is ported from scala/spark/RDD.scala @@ -288,6 +289,9 @@ def takeSample(self, withReplacement, num, seed): if (num < 0): raise ValueError + if (initialCount == 0): + return list() + if initialCount > sys.maxint - 1: maxSelected = sys.maxint - 1 else: