Skip to content

Commit c834cee

Browse files
committed
apply SPARK-4148 to branch-1.0
1 parent 49224fd commit c834cee

File tree

3 files changed

+20
-9
lines changed

3 files changed

+20
-9
lines changed

python/pyspark/rdd.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,9 +366,6 @@ def sample(self, withReplacement, fraction, seed=None):
366366
"""
367367
Return a sampled subset of this RDD (relies on numpy and falls back
368368
on default random generator if numpy is unavailable).
369-
370-
>>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
371-
[2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
372369
"""
373370
assert fraction >= 0.0, "Invalid fraction value: %s" % fraction
374371
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)

python/pyspark/rddsampler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,13 @@ def __init__(self, withReplacement, fraction, seed=None):
3737
def initRandomGenerator(self, split):
3838
if self._use_numpy:
3939
import numpy
40-
self._random = numpy.random.RandomState(self._seed)
40+
self._random = numpy.random.RandomState(self._seed ^ split)
4141
else:
42-
self._random = random.Random(self._seed)
42+
self._random = random.Random(self._seed ^ split)
4343

44-
for _ in range(0, split):
45-
# discard the next few values in the sequence to have a
46-
# different seed for the different splits
47-
self._random.randint(0, sys.maxint)
44+
# mixing because the initial seeds are close to each other
45+
for _ in xrange(10):
46+
self._random.randint(0, 1)
4847

4948
self._split = split
5049
self._rand_initialized = True

python/pyspark/tests.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,21 @@ def test_itemgetter(self):
204204
self.assertEqual([1], rdd.map(itemgetter(1)).collect())
205205
self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect())
206206

207+
def test_sample(self):
208+
rdd = self.sc.parallelize(range(0, 100), 4)
209+
wo = rdd.sample(False, 0.1, 2).collect()
210+
wo_dup = rdd.sample(False, 0.1, 2).collect()
211+
self.assertSetEqual(set(wo), set(wo_dup))
212+
wr = rdd.sample(True, 0.2, 5).collect()
213+
wr_dup = rdd.sample(True, 0.2, 5).collect()
214+
self.assertSetEqual(set(wr), set(wr_dup))
215+
wo_s10 = rdd.sample(False, 0.3, 10).collect()
216+
wo_s20 = rdd.sample(False, 0.3, 20).collect()
217+
self.assertNotEqual(set(wo_s10), set(wo_s20))
218+
wr_s11 = rdd.sample(True, 0.4, 11).collect()
219+
wr_s21 = rdd.sample(True, 0.4, 21).collect()
220+
self.assertNotEqual(set(wr_s11), set(wr_s21))
221+
207222

208223
class TestIO(PySparkTestCase):
209224

0 commit comments

Comments
 (0)