
Commit 408cb77

initial commit

make progress; make progress2; finished cov implementation, waiting for freqItems to be merged; trying to debug; added cov
1 parent 149b3ee commit 408cb77

File tree: 6 files changed (+163, -4 lines)

python/pyspark/sql/dataframe.py

Lines changed: 41 additions & 1 deletion
@@ -34,7 +34,8 @@
 from pyspark.sql.types import _create_cls, _parse_datatype_json_string


-__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions"]
+__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions",
+           "DataFrameStatFunctions"]


 class DataFrame(object):
@@ -93,6 +94,12 @@ def na(self):
         """
         return DataFrameNaFunctions(self)

+    @property
+    def stat(self):
+        """Returns a :class:`DataFrameStatFunctions` for statistic functions.
+        """
+        return DataFrameStatFunctions(self)
+
     @ignore_unicode_prefix
     def toJSON(self, use_unicode=True):
         """Converts a :class:`DataFrame` into a :class:`RDD` of string.
@@ -868,6 +875,17 @@ def fillna(self, value, subset=None):

         return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)

+    def cov(self, col1, col2):
+        """
+        Calculate the covariance for the given columns, specified by their names.
+        alias for ``stat.cov()``.
+
+        :param col1: The name of the first column
+        :param col2: The name of the second column
+        :return: the covariance of the columns
+        """
+        return self.stat.cov(col1, col2)
+
     @ignore_unicode_prefix
     def withColumn(self, colName, col):
         """Returns a new :class:`DataFrame` by adding a column.
@@ -1311,6 +1329,28 @@ def fill(self, value, subset=None):
     fill.__doc__ = DataFrame.fillna.__doc__


+class DataFrameStatFunctions(object):
+    """Functionality for statistic functions with :class:`DataFrame`.
+    """
+
+    def __init__(self, df):
+        self.df = df
+
+    def cov(self, col1, col2):
+        """
+        Calculate the covariance for the given columns, specified by their names.
+
+        :param col1: The name of the first column
+        :param col2: The name of the second column
+        :return: the covariance of the columns
+        """
+        if not isinstance(col1, str):
+            raise ValueError("col1 should be a string.")
+        if not isinstance(col2, str):
+            raise ValueError("col2 should be a string.")
+        return self.df._jdf.stat().cov(col1, col2)
+
+
 def _test():
     import doctest
     from pyspark.context import SparkContext
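
For orientation, a minimal usage sketch of the new Python API (not part of the diff; it assumes an existing SQLContext bound to the name sqlContext):

    from pyspark.sql import Row

    df = sqlContext.createDataFrame([Row(a=i, b=2 * i) for i in range(10)])
    print(df.stat.cov("a", "b"))  # 16.5, the population covariance of a and b
    print(df.cov("a", "b"))       # same value via the DataFrame-level alias

Both calls dispatch to the JVM-side DataFrameStatFunctions.cov through the Py4J-backed _jdf handle.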

python/pyspark/sql/tests.py

Lines changed: 5 additions & 0 deletions
@@ -387,6 +387,11 @@ def test_aggregator(self):
         self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
         self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])

+    def test_cov(self):
+        df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
+        cov = df.stat.cov("a", "b")
+        self.assertTrue(abs(cov - 16.5) < 1e-6)
+
     def test_math_functions(self):
         df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
         from pyspark.sql import mathfunctions as functions
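
The expected value follows from the population covariance the implementation computes (Ck / count): for a = 0..9 and b = 2a, cov(a, b) = 2 * Var(a) = 2 * 8.25 = 16.5.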

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 11 additions & 1 deletion
@@ -18,7 +18,7 @@
 package org.apache.spark.sql

 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.execution.stat.FrequentItems
+import org.apache.spark.sql.execution.stat._

 /**
  * :: Experimental ::
@@ -65,4 +65,14 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
   def freqItems(cols: List[String]): DataFrame = {
     FrequentItems.singlePassFreqItems(df, cols, 0.01)
   }
+
+  /**
+   * Calculate the covariance of two numerical columns of a DataFrame.
+   * @param col1 the name of the first column
+   * @param col2 the name of the second column
+   * @return the covariance of the two columns.
+   */
+  def cov(col1: String, col2: String): Double = {
+    StatFunctions.calculateCov(df, Seq(col1, col2))
+  }
 }
sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 81 additions & 0 deletions

@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.stat
+
+import org.apache.spark.sql.types.NumericType
+import org.apache.spark.sql.{Column, DataFrame}
+
+private[sql] object StatFunctions {
+
+  /** Helper class to simplify tracking and merging counts. */
+  private class CovarianceCounter extends Serializable {
+    var xAvg = 0.0
+    var yAvg = 0.0
+    var Ck = 0.0
+    var count = 0
+    // add an example to the calculation
+    def add(x: Number, y: Number): this.type = {
+      val oldX = xAvg
+      val otherX = x.doubleValue()
+      val otherY = y.doubleValue()
+      count += 1
+      xAvg += (otherX - xAvg) / count
+      yAvg += (otherY - yAvg) / count
+      println(oldX)
+      Ck += (otherY - yAvg) * (otherX - oldX)
+      this
+    }
+    // merge counters from other partitions
+    def merge(other: CovarianceCounter): this.type = {
+      val totalCount = count + other.count
+      Ck += other.Ck +
+        (xAvg - other.xAvg) * (yAvg - other.yAvg) * (count * other.count) / totalCount
+      xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
+      yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
+      count = totalCount
+      this
+    }
+    // return the covariance for the observed examples
+    def cov: Double = Ck / count
+  }
+
+  /**
+   * Calculate the covariance of two numerical columns of a DataFrame.
+   * @param df The DataFrame
+   * @param cols the column names
+   * @return the covariance of the two columns.
+   */
+  private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
+    require(cols.length == 2, "Currently cov supports calculating the covariance " +
+      "between two columns.")
+    cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) =>
+      require(data.nonEmpty, s"Couldn't find column with name $name")
+      require(data.get.dataType.isInstanceOf[NumericType], "Covariance calculation for columns " +
+        s"with dataType ${data.get.dataType} not supported.")
+    }
+    val counts = df.select(cols.map(Column(_)):_*).rdd.aggregate(new CovarianceCounter)(
+      seqOp = (counter, row) => {
+        counter.add(row.getAs[Number](0), row.getAs[Number](1))
+      },
+      combOp = (baseCounter, other) => {
+        baseCounter.merge(other)
+      })
+    counts.cov
+  }
+
+}
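
CovarianceCounter is a streaming co-moment accumulator: add folds one observation into the running means and updates the co-moment via Ck += (y - yAvgNew) * (x - xAvgOld), while merge combines two partial counters with the pairwise identity Ck = Ck1 + Ck2 + (xAvg1 - xAvg2) * (yAvg1 - yAvg2) * n1 * n2 / (n1 + n2). That merge identity is what lets cov run as a single rdd.aggregate over partitions without collecting rows. A self-contained Python sketch of the same update and merge rules (illustrative only, not part of the commit):

    # Re-implementation of CovarianceCounter's update and merge rules.
    class CovCounter(object):
        def __init__(self):
            self.x_avg, self.y_avg, self.ck, self.count = 0.0, 0.0, 0.0, 0

        def add(self, x, y):
            old_x = self.x_avg
            self.count += 1
            self.x_avg += (x - self.x_avg) / self.count
            self.y_avg += (y - self.y_avg) / self.count
            self.ck += (y - self.y_avg) * (x - old_x)  # co-moment update
            return self

        def merge(self, other):
            total = self.count + other.count
            self.ck += other.ck + ((self.x_avg - other.x_avg) *
                                   (self.y_avg - other.y_avg) *
                                   self.count * other.count / total)
            self.x_avg = (self.x_avg * self.count + other.x_avg * other.count) / total
            self.y_avg = (self.y_avg * self.count + other.y_avg * other.count) / total
            self.count = total
            return self

        def cov(self):
            return self.ck / self.count  # population covariance, as in Ck / count

    # Splitting the data across two "partitions" matches a single full pass.
    xs, ys = range(10), [2.0 * i for i in range(10)]
    left, right = CovCounter(), CovCounter()
    for x, y in zip(xs[:5], ys[:5]):
        left.add(x, y)
    for x, y in zip(xs[5:], ys[5:]):
        right.add(x, y)
    print(left.merge(right).cov())  # 16.5, matching the test expectations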

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 7 additions & 0 deletions
@@ -183,4 +183,11 @@ public void testFrequentItems() {
     DataFrame results = df.stat().freqItems(cols, 0.2);
     Assert.assertTrue(results.collect()[0].getSeq(0).contains(1));
   }
+
+  @Test
+  public void testCovariance() {
+    DataFrame df = context.table("testData2");
+    Double result = df.stat().cov("a", "b");
+    Assert.assertTrue(Math.abs(result) < 1e-6);
+  }
 }
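
The assertion checks that the covariance is approximately zero: in the shared test fixtures, testData2 pairs each a in {1, 2, 3} with both b = 1 and b = 2, so the co-moment terms cancel exactly.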

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 18 additions & 2 deletions
@@ -25,10 +25,11 @@ import org.apache.spark.sql.test.TestSQLContext.implicits._

 class DataFrameStatSuite extends FunSuite {

+  import TestData._
   val sqlCtx = TestSQLContext
-
+  def toLetter(i: Int): String = (i + 97).toChar.toString
+
   test("Frequent Items") {
-    def toLetter(i: Int): String = (i + 96).toChar.toString
     val rows = Array.tabulate(1000) { i =>
       if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
     }
@@ -44,4 +45,19 @@
     items2.getSeq[Double](0) should contain (-1.0)

   }
+
+  test("covariance") {
+    val rows = Array.tabulate(10)(i => (i, 2.0 * i, toLetter(i)))
+    val df = sqlCtx.sparkContext.parallelize(rows).toDF("singles", "doubles", "letters")
+    df.show()
+
+    val results = df.stat.cov("singles", "doubles")
+    println(results)
+    assert(math.abs(results - 16.5) < 1e-6)
+    intercept[IllegalArgumentException] {
+      df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
+    }
+    val decimalRes = decimalData.stat.cov("a", "b")
+    assert(math.abs(decimalRes) < 1e-6)
+  }
 }
