Skip to content

Commit a458efc

Browse files
committed
Revert "[SPARK-7157][SQL] add sampleBy to DataFrame"
This reverts commit 0401cba. The new test case on Jenkins is failing.
1 parent 0401cba commit a458efc

File tree

3 files changed

+2
-74
lines changed

3 files changed

+2
-74
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 0 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -448,41 +448,6 @@ def sample(self, withReplacement, fraction, seed=None):
448448
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
449449
return DataFrame(rdd, self.sql_ctx)
450450

451-
@since(1.5)
452-
def sampleBy(self, col, fractions, seed=None):
453-
"""
454-
Returns a stratified sample without replacement based on the
455-
fraction given on each stratum.
456-
457-
:param col: column that defines strata
458-
:param fractions:
459-
sampling fraction for each stratum. If a stratum is not
460-
specified, we treat its fraction as zero.
461-
:param seed: random seed
462-
:return: a new DataFrame that represents the stratified sample
463-
464-
>>> from pyspark.sql.functions import col
465-
>>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("key"))
466-
>>> sampled = dataset.sampleBy("key", fractions={0: 0.1, 1: 0.2}, seed=0)
467-
>>> sampled.groupBy("key").count().orderBy("key").show()
468-
+---+-----+
469-
|key|count|
470-
+---+-----+
471-
| 0| 5|
472-
| 1| 8|
473-
+---+-----+
474-
"""
475-
if not isinstance(col, str):
476-
raise ValueError("col must be a string, but got %r" % type(col))
477-
if not isinstance(fractions, dict):
478-
raise ValueError("fractions must be a dict but got %r" % type(fractions))
479-
for k, v in fractions.items():
480-
if not isinstance(k, (float, int, long, basestring)):
481-
raise ValueError("key must be float, int, long, or string, but got %r" % type(k))
482-
fractions[k] = float(v)
483-
seed = seed if seed is not None else random.randint(0, sys.maxsize)
484-
return DataFrame(self._jdf.stat().sampleBy(col, self._jmap(fractions), seed), self.sql_ctx)
485-
486451
@since(1.4)
487452
def randomSplit(self, weights, seed=None):
488453
"""Randomly splits this :class:`DataFrame` with the provided weights.
@@ -1357,11 +1322,6 @@ def freqItems(self, cols, support=None):
13571322

13581323
freqItems.__doc__ = DataFrame.freqItems.__doc__
13591324

1360-
def sampleBy(self, col, fractions, seed=None):
1361-
return self.df.sampleBy(col, fractions, seed)
1362-
1363-
sampleBy.__doc__ = DataFrame.sampleBy.__doc__
1364-
13651325

13661326
def _test():
13671327
import doctest

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 0 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,6 @@
1717

1818
package org.apache.spark.sql
1919

20-
import java.util.UUID
21-
2220
import org.apache.spark.annotation.Experimental
2321
import org.apache.spark.sql.execution.stat._
2422

@@ -165,26 +163,4 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
165163
def freqItems(cols: Seq[String]): DataFrame = {
166164
FrequentItems.singlePassFreqItems(df, cols, 0.01)
167165
}
168-
169-
/**
170-
* Returns a stratified sample without replacement based on the fraction given on each stratum.
171-
* @param col column that defines strata
172-
* @param fractions sampling fraction for each stratum. If a stratum is not specified, we treat
173-
* its fraction as zero.
174-
* @param seed random seed
175-
* @return a new [[DataFrame]] that represents the stratified sample
176-
*
177-
* @since 1.5.0
178-
*/
179-
def sampleBy(col: String, fractions: Map[Any, Double], seed: Long): DataFrame = {
180-
require(fractions.values.forall(p => p >= 0.0 && p <= 1.0),
181-
s"Fractions must be in [0, 1], but got $fractions.")
182-
import org.apache.spark.sql.functions.rand
183-
val c = Column(col)
184-
val r = rand(seed).as("rand_" + UUID.randomUUID().toString.take(8))
185-
val expr = fractions.toSeq.map { case (k, v) =>
186-
(c === k) && (r < v)
187-
}.reduce(_ || _) || false
188-
df.filter(expr)
189-
}
190166
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -19,9 +19,9 @@ package org.apache.spark.sql
1919

2020
import org.scalatest.Matchers._
2121

22-
import org.apache.spark.sql.functions.col
22+
import org.apache.spark.SparkFunSuite
2323

24-
class DataFrameStatSuite extends QueryTest {
24+
class DataFrameStatSuite extends SparkFunSuite {
2525

2626
private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
2727
import sqlCtx.implicits._
@@ -98,12 +98,4 @@ class DataFrameStatSuite extends QueryTest {
9898
val items2 = singleColResults.collect().head
9999
items2.getSeq[Double](0) should contain (-1.0)
100100
}
101-
102-
test("sampleBy") {
103-
val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
104-
val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
105-
checkAnswer(
106-
sampled.groupBy("key").count().orderBy("key"),
107-
Seq(Row(0, 4), Row(1, 9)))
108-
}
109101
}

0 commit comments

Comments
 (0)