
Commit 3a5c177

addressed comments v2.0
1 parent 482e741 commit 3a5c177

4 files changed, +60 -22 lines changed


sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 31 additions & 3 deletions
@@ -33,7 +33,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
    *
    * @param cols the names of the columns to search frequent items in
-   * @param support The minimum frequency for an item to be considered `frequent`
+   * @param support The minimum frequency for an item to be considered `frequent`. Should be at
+   *                least 1e-4.
    * @return A Local DataFrame with the Array of frequent items for each column.
    */
   def freqItems(cols: Seq[String], support: Double): DataFrame = {
@@ -44,12 +45,39 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * Finding frequent items for columns, possibly with false positives. Using the
    * frequent element count algorithm described in
    * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
-   * Returns items more frequent than 1/1000'th of the time.
+   * Returns items more frequent than 1 percent of the time.
    *
    * @param cols the names of the columns to search frequent items in
    * @return A Local DataFrame with the Array of frequent items for each column.
    */
   def freqItems(cols: Seq[String]): DataFrame = {
-    FrequentItems.singlePassFreqItems(df, cols, 0.001)
+    FrequentItems.singlePassFreqItems(df, cols, 0.01)
+  }
+
+  /**
+   * Finding frequent items for columns, possibly with false positives. Using the
+   * frequent element count algorithm described in
+   * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
+   *
+   * @param cols the names of the columns to search frequent items in
+   * @param support The minimum frequency for an item to be considered `frequent`. Should be at
+   *                least 1e-4.
+   * @return A Local DataFrame with the Array of frequent items for each column.
+   */
+  def freqItems(cols: List[String], support: Double): DataFrame = {
+    FrequentItems.singlePassFreqItems(df, cols, support)
+  }
+
+  /**
+   * Finding frequent items for columns, possibly with false positives. Using the
+   * frequent element count algorithm described in
+   * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
+   * Returns items more frequent than 1 percent of the time.
+   *
+   * @param cols the names of the columns to search frequent items in
+   * @return A Local DataFrame with the Array of frequent items for each column.
+   */
+  def freqItems(cols: List[String]): DataFrame = {
+    FrequentItems.singlePassFreqItems(df, cols, 0.01)
   }
 }
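
For context, a minimal usage sketch of the overloads above (assuming the TestSQLContext used elsewhere in this patch; the toy data and the column name "a" are invented for illustration):

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._

// Invented toy data: the value 1 dominates column "a".
val df = TestSQLContext.sparkContext.parallelize(Seq(1, 1, 1, 2, 3)).toDF("a")

// Explicit support; after this change it must be at least 1e-4.
val frequent = df.stat.freqItems(Seq("a"), 0.4)

// Default support, now 0.01 rather than the old 0.001.
val frequentDefault = df.stat.freqItems(Seq("a"))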

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala

Lines changed: 3 additions & 4 deletions
@@ -75,16 +75,15 @@ private[sql] object FrequentItems extends Logging {
    *
    * @param df The input DataFrame
    * @param cols the names of the columns to search frequent items in
-   * @param support The minimum frequency for an item to be considered `frequent`
+   * @param support The minimum frequency for an item to be considered `frequent`. Should be at
+   *                least 1e-4.
    * @return A Local DataFrame with the Array of frequent items for each column.
    */
   private[sql] def singlePassFreqItems(
       df: DataFrame,
       cols: Seq[String],
       support: Double): DataFrame = {
-    if (support < 1e-6) {
-      logWarning(s"The selected support ($support) is too small, and might cause memory problems.")
-    }
+    require(support >= 1e-4, s"support ($support) must be at least 1e-4.")
     val numCols = cols.length
     // number of max items to keep counts for
     val sizeOfMap = (1 / support).toInt
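
The switch from a warning to a hard requirement bounds the algorithm's memory use, since the counting map holds at most (1 / support) entries (the sizeOfMap above). A quick illustrative sketch, not part of the patch:

// At the smallest support the new require() accepts:
val sizeOfMap = (1 / 1e-4).toInt   // about 10,000 counters
// The old code merely warned below 1e-6, which still permitted maps of
// roughly 1,000,000 counters, the memory risk the require() now rules out.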

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 11 additions & 4 deletions
@@ -22,10 +22,7 @@
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.TestData$;
+import org.apache.spark.sql.*;
 import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.test.TestSQLContext$;
 import org.apache.spark.sql.types.*;
@@ -36,6 +33,7 @@
 import scala.collection.mutable.Buffer;
 
 import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
@@ -178,4 +176,13 @@ public void testCreateDataFrameFromJavaBeans() {
       Assert.assertEquals(bean.getD().get(i), d.apply(i));
     }
   }
+
+  @Test
+  public void testFrequentItems() {
+    DataFrame df = context.table("testData2");
+    List<String> cols = Arrays.asList("a");
+    DataFrame results = df.stat().freqItems(JavaConversions.asScalaIterable(cols).toList(), 0.2);
+    System.out.println(results.collect()[0].getSeq(0));
+    Assert.assertTrue(results.collect()[0].getSeq(0).contains(1));
+  }
 }
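
The Scala List[String] overload added above is what this Java test reaches through scala.collection.JavaConversions. The same conversion, sketched in Scala for clarity (JavaConversions is the pre-Scala-2.12 API the test uses):

import scala.collection.JavaConversions

val javaCols: java.util.List[String] = java.util.Arrays.asList("a")
// asScalaIterable wraps the Java list; toList then materializes the
// immutable scala.List[String] that the new freqItems overload expects.
val scalaCols: List[String] = JavaConversions.asScalaIterable(javaCols).toList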

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 15 additions & 11 deletions
@@ -17,27 +17,31 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.types._
 import org.scalatest.FunSuite
+import org.scalatest.Matchers._
+
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext.implicits._
 
 class DataFrameStatSuite extends FunSuite {
 
   val sqlCtx = TestSQLContext
 
   test("Frequent Items") {
     def toLetter(i: Int): String = (i + 96).toChar.toString
-    val rows = Array.tabulate(1000)(i => if (i % 3 == 0) (1, toLetter(1)) else (i, toLetter(i)))
-    val rowRdd = sqlCtx.sparkContext.parallelize(rows.map(v => Row(v._1, v._2)))
-    val schema = StructType(StructField("numbers", IntegerType, false) ::
-      StructField("letters", StringType, false) :: Nil)
-    val df = sqlCtx.createDataFrame(rowRdd, schema)
+    val rows = Array.tabulate(1000) { i =>
+      if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
+    }
+    val df = sqlCtx.sparkContext.parallelize(rows).toDF("numbers", "letters", "negDoubles")
 
     val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
     val items = results.collect().head
-    assert(items.getSeq(0).contains(1),
-      "1 should be the frequent item for column 'numbers")
-    assert(items.getSeq(1).contains(toLetter(1)),
-      s"${toLetter(1)} should be the frequent item for column 'letters'")
+    items.getSeq[Int](0) should contain (1)
+    items.getSeq[String](1) should contain (toLetter(1))
+
+    val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
+    val items2 = singleColResults.collect().head
+    items2.getSeq[Double](0) should contain (-1.0)
+
   }
 }