Skip to content

Commit 869ae4b

Browse files
committed
move tests tests.py
1 parent c1bacd9 commit 869ae4b

File tree

2 files changed

+15
-18
lines changed

2 files changed

+15
-18
lines changed

python/pyspark/rdd.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -316,24 +316,6 @@ def sample(self, withReplacement, fraction, seed=None):
316316
"""
317317
Return a sampled subset of this RDD (relies on numpy and falls back
318318
on default random generator if numpy is unavailable).
319-
320-
>>> rdd = sc.parallelize(range(0, 100), 4)
321-
>>> wo = rdd.sample(False, 0.1, 2).collect()
322-
>>> wo_dup = rdd.sample(False, 0.1, 2).collect()
323-
>>> set(wo) == set(wo_dup)
324-
True
325-
>>> wr = rdd.sample(True, 0.2, 5).collect()
326-
>>> wr_dup = rdd.sample(True, 0.2, 5).collect()
327-
>>> set(wr) == set(wr_dup)
328-
True
329-
>>> wo_s10 = rdd.sample(False, 0.3, 10).collect()
330-
>>> wo_s20 = rdd.sample(False, 0.3, 20).collect()
331-
>>> set(wo_s10) != set(wo_s20)
332-
True
333-
>>> wr_s11 = rdd.sample(True, 0.4, 11).collect()
334-
>>> wr_s21 = rdd.sample(True, 0.4, 21).collect()
335-
>>> set(wr_s11) != set(wr_s21)
336-
True
337319
"""
338320
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
339321
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)

python/pyspark/tests.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,21 @@ def test_distinct(self):
648648
self.assertEquals(result.getNumPartitions(), 5)
649649
self.assertEquals(result.count(), 3)
650650

651+
def test_sample(self):
652+
rdd = self.sc.parallelize(range(0, 100), 4)
653+
wo = rdd.sample(False, 0.1, 2).collect()
654+
wo_dup = rdd.sample(False, 0.1, 2).collect()
655+
self.assertSetEqual(set(wo), set(wo_dup))
656+
wr = rdd.sample(True, 0.2, 5).collect()
657+
wr_dup = rdd.sample(True, 0.2, 5).collect()
658+
self.assertSetEqual(set(wr), set(wr_dup))
659+
wo_s10 = rdd.sample(False, 0.3, 10).collect()
660+
wo_s20 = rdd.sample(False, 0.3, 20).collect()
661+
self.assertNotEqual(set(wo_s10), set(wo_s20))
662+
wr_s11 = rdd.sample(True, 0.4, 11).collect()
663+
wr_s21 = rdd.sample(True, 0.4, 21).collect()
664+
self.assertNotEqual(set(wr_s11), set(wr_s21))
665+
651666

652667
class ProfilerTests(PySparkTestCase):
653668

0 commit comments

Comments
 (0)