
Commit 149b3ee

brkyvz authored and rxin committed
[SPARK-7242][SQL][MLLIB] Frequent items for DataFrames
Finding frequent items with possibly false positives, using the algorithm described in `http://www.cs.umd.edu/~samir/498/karp.pdf`. Public API under:

```
df.stat.freqItems(cols: Array[String], support: Double = 0.001): DataFrame
```

The output is a local DataFrame having the input column names with `-freqItems` appended to it. This is a single-pass algorithm that may return false positives, but no false negatives.

cc mengxr rxin Let's get the implementation in; I can add the Python API in a follow-up PR.

Author: Burak Yavuz <[email protected]>

Closes apache#5799 from brkyvz/freq-items and squashes the following commits:

a6ec82c [Burak Yavuz] addressed comments v?
39b1bba [Burak Yavuz] removed toSeq
0915e23 [Burak Yavuz] addressed comments v2.1
3a5c177 [Burak Yavuz] addressed comments v2.0
482e741 [Burak Yavuz] removed old import
38e784d [Burak Yavuz] addressed comments v1.0
8279d4d [Burak Yavuz] added default value for support
3d82168 [Burak Yavuz] made base implementation
1 parent 1c3e402 commit 149b3ee
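As a quick orientation before the diff: the commit message above describes the new `df.stat.freqItems` API. Here is a hypothetical spark-shell session exercising it (the data, column names, and 0.4 support value are invented for illustration; `sc` and `sqlContext` are the shell's built-ins of that era):

```scala
import sqlContext.implicits._

// Toy data: the value 1 and the letter "a" dominate their columns.
val df = sc.parallelize(Seq((1, "a"), (1, "b"), (1, "a"), (2, "c"))).toDF("num", "char")

// One pass over the data; items occurring in at least ~40% of rows are kept.
// The result may contain false positives, but no false negatives.
val freq = df.stat.freqItems(Array("num", "char"), 0.4)
freq.show() // output columns are named num_freqItems and char_freqItems
```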

File tree: 5 files changed, +256 −5 lines

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 11 additions & 0 deletions

```diff
@@ -330,6 +330,17 @@ class DataFrame private[sql](
    */
   def na: DataFrameNaFunctions = new DataFrameNaFunctions(this)
 
+  /**
+   * Returns a [[DataFrameStatFunctions]] for working with statistic functions.
+   * {{{
+   *   // Finding frequent items in column with name 'a'.
+   *   df.stat.freqItems(Seq("a"))
+   * }}}
+   *
+   * @group dfops
+   */
+  def stat: DataFrameStatFunctions = new DataFrameStatFunctions(this)
+
   /**
    * Cartesian join with another [[DataFrame]].
    *
```
sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 68 additions & 0 deletions (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.stat.FrequentItems

/**
 * :: Experimental ::
 * Statistic functions for [[DataFrame]]s.
 */
@Experimental
final class DataFrameStatFunctions private[sql](df: DataFrame) {

  /**
   * Finding frequent items for columns, possibly with false positives, using the
   * frequent element count algorithm described in
   * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
   * The `support` should be greater than 1e-4.
   *
   * @param cols the names of the columns to search frequent items in.
   * @param support The minimum frequency for an item to be considered `frequent`. Should be greater
   *                than 1e-4.
   * @return A local DataFrame with the Array of frequent items for each column.
   */
  def freqItems(cols: Array[String], support: Double): DataFrame = {
    FrequentItems.singlePassFreqItems(df, cols, support)
  }

  /**
   * Runs `freqItems` with a default `support` of 1%.
   *
   * @param cols the names of the columns to search frequent items in.
   * @return A local DataFrame with the Array of frequent items for each column.
   */
  def freqItems(cols: Array[String]): DataFrame = {
    FrequentItems.singlePassFreqItems(df, cols, 0.01)
  }

  /**
   * Python friendly implementation for `freqItems`.
   */
  def freqItems(cols: List[String], support: Double): DataFrame = {
    FrequentItems.singlePassFreqItems(df, cols, support)
  }

  /**
   * Python friendly implementation for `freqItems` with a default `support` of 1%.
   */
  def freqItems(cols: List[String]): DataFrame = {
    FrequentItems.singlePassFreqItems(df, cols, 0.01)
  }
}
```
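The four overloads differ only in collection type and in whether `support` is supplied; a quick sketch of how they line up from caller code (reusing the hypothetical `df` from the session above):

```scala
df.stat.freqItems(Array("num", "char"))       // Array overload, default support of 1%
df.stat.freqItems(Array("num", "char"), 0.4)  // Array overload, explicit support
df.stat.freqItems(List("num", "char"))        // List overload, the "Python friendly" variant
df.stat.freqItems(List("num", "char"), 0.4)
```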
sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala

Lines changed: 121 additions & 0 deletions (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.stat

import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.Logging
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{ArrayType, StructField, StructType}

private[sql] object FrequentItems extends Logging {

  /** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */
  private class FreqItemCounter(size: Int) extends Serializable {
    val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long]

    /**
     * Add `count` to the tally for `key` if it is already tracked or there is room
     * for it; otherwise deduct `count` from all existing items and drop those that
     * fall to zero or below.
     */
    def add(key: Any, count: Long): this.type = {
      if (baseMap.contains(key)) {
        baseMap(key) += count
      } else {
        if (baseMap.size < size) {
          baseMap += key -> count
        } else {
          // TODO: Make this more efficient... A flatMap?
          baseMap.retain((k, v) => v > count)
          baseMap.transform((k, v) => v - count)
        }
      }
      this
    }

    /**
     * Merge two maps of counts.
     * @param other The map containing the counts for that partition
     */
    def merge(other: FreqItemCounter): this.type = {
      other.baseMap.foreach { case (k, v) =>
        add(k, v)
      }
      this
    }
  }

  /**
   * Finding frequent items for columns, possibly with false positives, using the
   * frequent element count algorithm described in
   * [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
   * The `support` should be greater than 1e-4.
   * For internal use only.
   *
   * @param df The input DataFrame
   * @param cols the names of the columns to search frequent items in
   * @param support The minimum frequency for an item to be considered `frequent`. Should be greater
   *                than 1e-4.
   * @return A local DataFrame with the Array of frequent items for each column.
   */
  private[sql] def singlePassFreqItems(
      df: DataFrame,
      cols: Seq[String],
      support: Double): DataFrame = {
    require(support >= 1e-4, s"support ($support) must be greater than 1e-4.")
    val numCols = cols.length
    // number of max items to keep counts for
    val sizeOfMap = (1 / support).toInt
    val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap))
    val originalSchema = df.schema
    val colInfo = cols.map { name =>
      val index = originalSchema.fieldIndex(name)
      (name, originalSchema.fields(index).dataType)
    }

    val freqItems = df.select(cols.map(Column(_)):_*).rdd.aggregate(countMaps)(
      seqOp = (counts, row) => {
        var i = 0
        while (i < numCols) {
          val thisMap = counts(i)
          val key = row.get(i)
          thisMap.add(key, 1L)
          i += 1
        }
        counts
      },
      combOp = (baseCounts, counts) => {
        var i = 0
        while (i < numCols) {
          baseCounts(i).merge(counts(i))
          i += 1
        }
        baseCounts
      }
    )
    val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
    val resultRow = Row(justItems:_*)
    // append frequent Items to the column name for easy debugging
    val outputCols = colInfo.map { v =>
      StructField(v._1 + "_freqItems", ArrayType(v._2, false))
    }
    val schema = StructType(outputCols).toAttributes
    new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow)))
  }
}
```
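To make the counting scheme concrete, here is a minimal, self-contained sketch of the same Karp/Schenker/Papadimitriou counter outside Spark (plain Scala, using the 2.12-era mutable-map API seen in the diff; the function name and example stream are invented for illustration). The map is capped at `(1 / support).toInt` entries, e.g. 100 for the default support of 0.01, so any item occurring in more than a `support` fraction of the stream must survive to the end; extra survivors are the possible false positives.

```scala
import scala.collection.mutable.{Map => MutableMap}

// Single-pass approximate frequent items over a plain Scala sequence.
def approxFreqItems(stream: Seq[Any], support: Double): Set[Any] = {
  val capacity = (1 / support).toInt
  val counts = MutableMap.empty[Any, Long]
  stream.foreach { item =>
    if (counts.contains(item)) {
      counts(item) += 1 // already tracked: increment
    } else if (counts.size < capacity) {
      counts += item -> 1L // free slot: start tracking
    } else {
      // Map is full: charge one unit against every tracked item and evict the
      // ones that hit zero -- the retain/transform step in FreqItemCounter.add.
      counts.transform((_, v) => v - 1)
      counts.retain((_, v) => v > 0)
    }
  }
  counts.keySet.toSet
}

// 1 makes up half of this stream, so it must be reported at support 0.25;
// some of the rarer items may also survive as false positives.
approxFreqItems(Seq(1, 1, 2, 1, 3, 1, 4, 2), 0.25)
```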

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 9 additions & 5 deletions

```diff
@@ -22,10 +22,7 @@
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.TestData$;
+import org.apache.spark.sql.*;
 import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.test.TestSQLContext$;
 import org.apache.spark.sql.types.*;
@@ -178,5 +175,12 @@ public void testCreateDataFrameFromJavaBeans() {
       Assert.assertEquals(bean.getD().get(i), d.apply(i));
     }
   }
-
+
+  @Test
+  public void testFrequentItems() {
+    DataFrame df = context.table("testData2");
+    String[] cols = new String[]{"a"};
+    DataFrame results = df.stat().freqItems(cols, 0.2);
+    Assert.assertTrue(results.collect()[0].getSeq(0).contains(1));
+  }
 }
```
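Note for Java callers: the Scala value `stat` surfaces as the method `df.stat()`, which is why the test above chains `df.stat().freqItems(cols, 0.2)`.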
sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 47 additions & 0 deletions (new file)

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql

import org.scalatest.FunSuite
import org.scalatest.Matchers._

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._

class DataFrameStatSuite extends FunSuite {

  val sqlCtx = TestSQLContext

  test("Frequent Items") {
    def toLetter(i: Int): String = (i + 96).toChar.toString
    val rows = Array.tabulate(1000) { i =>
      if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
    }
    val df = sqlCtx.sparkContext.parallelize(rows).toDF("numbers", "letters", "negDoubles")

    val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
    val items = results.collect().head
    items.getSeq[Int](0) should contain (1)
    items.getSeq[String](1) should contain (toLetter(1))

    val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
    val items2 = singleColResults.collect().head
    items2.getSeq[Double](0) should contain (-1.0)
  }
}
```
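Why these assertions must hold: `i % 3 == 0` is true for 334 of the 1000 generated rows, so the row `(1, "a", -1.0)` accounts for at least a third of each column's values, comfortably above the 0.1 support threshold; and since the algorithm admits no false negatives, 1, "a", and -1.0 are guaranteed to appear in their respective `_freqItems` arrays.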
