
Commit 3d82168

made base implementation
implemented frequent items
1 parent f8cbb0a commit 3d82168

File tree

3 files changed: +191 −0 lines changed

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
sql/core/src/main/scala/org/apache/spark/sql/ml/FrequentItems.scala (new)
sql/core/src/test/scala/org/apache/spark/sql/ml/FrequentItemsSuite.scala (new)


sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 22 additions & 0 deletions
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
 import org.apache.spark.sql.jdbc.JDBCWriteDetails
 import org.apache.spark.sql.json.JsonRDD
+import org.apache.spark.sql.ml.FrequentItems
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
 import org.apache.spark.util.Utils
@@ -1414,4 +1415,25 @@ class DataFrame private[sql](
     val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
     SerDeUtil.javaToPython(jrdd)
   }
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Statistic functions
+  /////////////////////////////////////////////////////////////////////////////
+
+  // scalastyle:off
+  object stat {
+  // scalastyle:on
+
+    /**
+     * Finds frequent items for columns, possibly with false positives, using the algorithm
+     * described in `http://www.cs.umd.edu/~samir/498/karp.pdf`.
+     *
+     * @param cols the names of the columns to search frequent items in
+     * @param support the minimum frequency for an item to be considered `frequent`
+     * @return a local DataFrame with the array of frequent items for each column
+     */
+    def freqItems(cols: Array[String], support: Double): DataFrame = {
+      FrequentItems.singlePassFreqItems(toDF(), cols, support)
+    }
+  }
 }
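For context, this is how the new API reads from user code -- a hedged sketch, assuming a DataFrame `df` with a column named "payment_type" (the column name is illustrative, not part of this commit):

    // Items occurring in more than 10% of rows are guaranteed to be returned;
    // the single-pass algorithm may also return some false positives.
    val freq = df.stat.freqItems(Array("payment_type"), support = 0.1)
    // `freq` is a local, single-row DataFrame with one array column,
    // "payment_type-freqItems", as built in FrequentItems.singlePassFreqItems.
    freq.show()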
sql/core/src/main/scala/org/apache/spark/sql/ml/FrequentItems.scala (new file)

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.ml

import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.Logging
import org.apache.spark.sql.{Row, DataFrame, functions}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{StructType, ArrayType, StructField}

private[sql] object FrequentItems extends Logging {

  /**
   * Merges two maps of counts. Subtracts the sum of `otherMap` from `baseMap`, and fills in
   * any emptied slots with the most frequent items of `otherMap`.
   * @param baseMap the map containing the global counts
   * @param otherMap the map containing the counts for one partition
   * @param maxSize the maximum number of counts to keep in memory
   */
  private def mergeCounts[A](
      baseMap: MutableMap[A, Long],
      otherMap: MutableMap[A, Long],
      maxSize: Int): Unit = {
    // Total count of the keys in `otherMap` that `baseMap` has never seen.
    val otherSum = otherMap.foldLeft(0L) { case (sum, (k, v)) =>
      if (!baseMap.contains(k)) sum + v else sum
    }
    baseMap.retain((k, v) => v > otherSum)
    // Sort in decreasing order, so that the most frequent items are added first.
    val sorted = otherMap.toSeq.sortBy(-_._2)
    var i = 0
    val otherSize = sorted.length
    while (i < otherSize && baseMap.size < maxSize) {
      val keyVal = sorted(i)
      baseMap += keyVal._1 -> keyVal._2
      i += 1
    }
  }

  /**
   * Finds frequent items for columns, possibly with false positives, using the algorithm
   * described in `http://www.cs.umd.edu/~samir/498/karp.pdf`.
   * For internal use only.
   *
   * @param df the input DataFrame
   * @param cols the names of the columns to search frequent items in
   * @param support the minimum frequency for an item to be considered `frequent`
   * @return a local DataFrame with the array of frequent items for each column
   */
  private[sql] def singlePassFreqItems(
      df: DataFrame,
      cols: Array[String],
      support: Double): DataFrame = {
    val numCols = cols.length
    // The maximum number of items to keep counts for.
    val sizeOfMap = math.floor(1 / support).toInt
    val countMaps = Array.tabulate(numCols)(i => MutableMap.empty[Any, Long])
    val originalSchema = df.schema
    val colInfo = cols.map { name =>
      val index = originalSchema.fieldIndex(name)
      val dataType = originalSchema.fields(index)
      (index, dataType.dataType)
    }
    val colIndices = colInfo.map(_._1)

    val freqItems: Array[MutableMap[Any, Long]] = df.rdd.aggregate(countMaps)(
      seqOp = (counts, row) => {
        var i = 0
        colIndices.foreach { index =>
          val thisMap = counts(i)
          val key = row.get(index)
          if (thisMap.contains(key)) {
            thisMap(key) += 1
          } else {
            if (thisMap.size < sizeOfMap) {
              thisMap += key -> 1
            } else {
              // The map is full: drop counters about to hit zero, decrement the rest.
              // TODO: Make this more efficient... A flatMap?
              thisMap.retain((k, v) => v > 1)
              thisMap.transform((k, v) => v - 1)
            }
          }
          i += 1
        }
        counts
      },
      combOp = (baseCounts, counts) => {
        var i = 0
        while (i < numCols) {
          mergeCounts(baseCounts(i), counts(i), sizeOfMap)
          i += 1
        }
        baseCounts
      }
    )
    val justItems = freqItems.map(m => m.keys.toSeq)
    val resultRow = Row(justItems: _*)
    // Append "-freqItems" to the column names for easy debugging.
    val outputCols = cols.zip(colInfo).map { v =>
      StructField(v._1 + "-freqItems", ArrayType(v._2._2, false))
    }
    val schema = StructType(outputCols).toAttributes
    new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow)))
  }
}
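The per-partition seqOp above is the decrement-counting scheme from the linked Karp et al. paper. A minimal single-machine sketch of the same idea, using a hypothetical helper `freqItemsLocal` (not part of the commit), may make the invariant easier to see:

    import scala.collection.mutable.{Map => MutableMap}

    // Keep at most floor(1/support) counters. When an unseen item arrives and
    // no slot is free, drop counters that would reach zero and decrement the
    // rest. Any item with true frequency above `support` always survives; a
    // few infrequent items may survive as well (the false positives).
    def freqItemsLocal[A](items: Seq[A], support: Double): Set[A] = {
      val maxSize = math.floor(1 / support).toInt
      val counts = MutableMap.empty[A, Long]
      items.foreach { item =>
        if (counts.contains(item)) {
          counts(item) += 1
        } else if (counts.size < maxSize) {
          counts += item -> 1L
        } else {
          counts.retain((_, v) => v > 1)
          counts.transform((_, v) => v - 1)
        }
      }
      counts.keySet.toSet
    }

    // freqItemsLocal(Seq(1, 1, 2, 1, 3, 1, 4), 0.5) keeps floor(1/0.5) = 2
    // counters and returns Set(1, 4): 1 is truly frequent, 4 is a false positive.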
sql/core/src/test/scala/org/apache/spark/sql/ml/FrequentItemsSuite.scala (new file)

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.ml

import org.scalatest.FunSuite

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._

class FrequentItemsSuite extends FunSuite {

  val sqlCtx = TestSQLContext

  test("Frequent Items") {
    def toLetter(i: Int): String = (i + 96).toChar.toString
    // Every index divisible by 3 maps to the pair (1, "a"), so both 1 and "a"
    // appear in at least 334 of the 1000 rows -- a frequency above 1/3, well
    // over the 0.1 support threshold used below.
    val rows = Array.tabulate(1000)(i => if (i % 3 == 0) (1, toLetter(1)) else (i, toLetter(i)))
    val rowRdd = sqlCtx.sparkContext.parallelize(rows.map(v => Row(v._1, v._2)))
    val schema = StructType(StructField("numbers", IntegerType, false) ::
      StructField("letters", StringType, false) :: Nil)
    val df = sqlCtx.createDataFrame(rowRdd, schema)

    val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
    val items = results.collect().head
    assert(items.getSeq(0).contains(1),
      "1 should be a frequent item in column 'numbers'")
    assert(items.getSeq(1).contains(toLetter(1)),
      s"${toLetter(1)} should be a frequent item in column 'letters'")
  }
}
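A quick check of the arithmetic behind the assertions: with support = 0.1, singlePassFreqItems keeps floor(1/0.1) = 10 counters per column, and every value occurring in more than 1000 * 0.1 = 100 rows is guaranteed to be reported. Since 1 and "a" each occur at least 334 times, both assertions hold regardless of how the RDD is partitioned. The collected result is a single Row of two arrays, so an equivalent way to inspect it (a sketch using Row's generic getSeq accessor) is:

    val numberItems = items.getSeq[Int](0)      // frequent values of "numbers"
    val letterItems = items.getSeq[String](1)   // frequent values of "letters"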
