Skip to content

Commit 4dc8d74

Browse files
brkyvz authored and rxin committed
[SPARK-7240][SQL] Single pass covariance calculation for dataframes
Added the calculation of covariance between two columns to DataFrames. cc mengxr rxin Author: Burak Yavuz <[email protected]> Closes apache#5825 from brkyvz/df-cov and squashes the following commits: cb18046 [Burak Yavuz] changed to sample covariance f2e862b [Burak Yavuz] fixed failed test 51e39b8 [Burak Yavuz] moved implementation 0c6a759 [Burak Yavuz] addressed math comments 8456eca [Burak Yavuz] fix pyStyle3 aa2ad29 [Burak Yavuz] fix pyStyle2 4e97a50 [Burak Yavuz] Merge branch 'master' of github.com:apache/spark into df-cov e3b0b85 [Burak Yavuz] addressed comments v0.1 a7115f1 [Burak Yavuz] fix python style 7dc6dbc [Burak Yavuz] reorder imports 408cb77 [Burak Yavuz] initial commit
1 parent 7b5dd3e commit 4dc8d74

File tree

7 files changed

+157
-5
lines changed

7 files changed

+157
-5
lines changed

python/pyspark/sql/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,9 @@
5454
from pyspark.sql.types import Row
5555
from pyspark.sql.context import SQLContext, HiveContext
5656
from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions
57+
from pyspark.sql.dataframe import DataFrameStatFunctions
5758

5859
__all__ = [
59-
'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', 'DataFrameNaFunctions'
60+
'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
61+
'DataFrameNaFunctions', 'DataFrameStatFunctions'
6062
]

python/pyspark/sql/dataframe.py

Lines changed: 35 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,8 @@
3434
from pyspark.sql.types import _create_cls, _parse_datatype_json_string
3535

3636

37-
__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions"]
37+
__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD", "DataFrameNaFunctions",
38+
"DataFrameStatFunctions"]
3839

3940

4041
class DataFrame(object):
@@ -93,6 +94,12 @@ def na(self):
9394
"""
9495
return DataFrameNaFunctions(self)
9596

97+
@property
98+
def stat(self):
99+
"""Returns a :class:`DataFrameStatFunctions` for statistic functions.
100+
"""
101+
return DataFrameStatFunctions(self)
102+
96103
@ignore_unicode_prefix
97104
def toJSON(self, use_unicode=True):
98105
"""Converts a :class:`DataFrame` into a :class:`RDD` of string.
@@ -868,6 +875,20 @@ def fillna(self, value, subset=None):
868875

869876
return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
870877

878+
def cov(self, col1, col2):
879+
"""
880+
Calculate the sample covariance for the given columns, specified by their names, as a
881+
double value. :func:`DataFrame.cov` and :func:`DataFrameStatFunctions.cov` are aliases.
882+
883+
:param col1: The name of the first column
884+
:param col2: The name of the second column
885+
"""
886+
if not isinstance(col1, str):
887+
raise ValueError("col1 should be a string.")
888+
if not isinstance(col2, str):
889+
raise ValueError("col2 should be a string.")
890+
return self._jdf.stat().cov(col1, col2)
891+
871892
@ignore_unicode_prefix
872893
def withColumn(self, colName, col):
873894
"""Returns a new :class:`DataFrame` by adding a column.
@@ -1311,6 +1332,19 @@ def fill(self, value, subset=None):
13111332
fill.__doc__ = DataFrame.fillna.__doc__
13121333

13131334

1335+
class DataFrameStatFunctions(object):
1336+
"""Functionality for statistic functions with :class:`DataFrame`.
1337+
"""
1338+
1339+
def __init__(self, df):
1340+
self.df = df
1341+
1342+
def cov(self, col1, col2):
1343+
return self.df.cov(col1, col2)
1344+
1345+
cov.__doc__ = DataFrame.cov.__doc__
1346+
1347+
13141348
def _test():
13151349
import doctest
13161350
from pyspark.context import SparkContext

python/pyspark/sql/tests.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -387,6 +387,11 @@ def test_aggregator(self):
387387
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
388388
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
389389

390+
def test_cov(self):
391+
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
392+
cov = df.stat.cov("a", "b")
393+
self.assertTrue(abs(cov - 55.0 / 3) < 1e-6)
394+
390395
def test_math_functions(self):
391396
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
392397
from pyspark.sql import mathfunctions as functions

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql
1919

2020
import org.apache.spark.annotation.Experimental
21-
import org.apache.spark.sql.execution.stat.FrequentItems
21+
import org.apache.spark.sql.execution.stat._
2222

2323
/**
2424
* :: Experimental ::
@@ -65,4 +65,14 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
6565
def freqItems(cols: List[String]): DataFrame = {
6666
FrequentItems.singlePassFreqItems(df, cols, 0.01)
6767
}
68+
69+
/**
70+
* Calculate the sample covariance of two numerical columns of a DataFrame.
71+
* @param col1 the name of the first column
72+
* @param col2 the name of the second column
73+
* @return the covariance of the two columns.
74+
*/
75+
def cov(col1: String, col2: String): Double = {
76+
StatFunctions.calculateCov(df, Seq(col1, col2))
77+
}
6878
}
Lines changed: 80 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.stat
19+
20+
import org.apache.spark.sql.catalyst.expressions.Cast
21+
import org.apache.spark.sql.{Column, DataFrame}
22+
import org.apache.spark.sql.types.{DoubleType, NumericType}
23+
24+
private[sql] object StatFunctions {
25+
26+
/** Helper class to simplify tracking and merging counts. */
27+
private class CovarianceCounter extends Serializable {
28+
var xAvg = 0.0
29+
var yAvg = 0.0
30+
var Ck = 0.0
31+
var count = 0L
32+
// add an example to the calculation
33+
def add(x: Double, y: Double): this.type = {
34+
val oldX = xAvg
35+
count += 1
36+
xAvg += (x - xAvg) / count
37+
yAvg += (y - yAvg) / count
38+
Ck += (y - yAvg) * (x - oldX)
39+
this
40+
}
41+
// merge counters from other partitions. Formula can be found at:
42+
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Covariance
43+
def merge(other: CovarianceCounter): this.type = {
44+
val totalCount = count + other.count
45+
Ck += other.Ck +
46+
(xAvg - other.xAvg) * (yAvg - other.yAvg) * count / totalCount * other.count
47+
xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
48+
yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
49+
count = totalCount
50+
this
51+
}
52+
// return the sample covariance for the observed examples
53+
def cov: Double = Ck / (count - 1)
54+
}
55+
56+
/**
57+
* Calculate the covariance of two numerical columns of a DataFrame.
58+
* @param df The DataFrame
59+
* @param cols the column names
60+
* @return the covariance of the two columns.
61+
*/
62+
private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
63+
require(cols.length == 2, "Currently cov supports calculating the covariance " +
64+
"between two columns.")
65+
cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) =>
66+
require(data.nonEmpty, s"Couldn't find column with name $name")
67+
require(data.get.dataType.isInstanceOf[NumericType], "Covariance calculation for columns " +
68+
s"with dataType ${data.get.dataType} not supported.")
69+
}
70+
val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
71+
val counts = df.select(columns:_*).rdd.aggregate(new CovarianceCounter)(
72+
seqOp = (counter, row) => {
73+
counter.add(row.getDouble(0), row.getDouble(1))
74+
},
75+
combOp = (baseCounter, other) => {
76+
baseCounter.merge(other)
77+
})
78+
counts.cov
79+
}
80+
}

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -186,4 +186,11 @@ public void testFrequentItems() {
186186
DataFrame results = df.stat().freqItems(cols, 0.2);
187187
Assert.assertTrue(results.collect()[0].getSeq(0).contains(1));
188188
}
189+
190+
@Test
191+
public void testCovariance() {
192+
DataFrame df = context.table("testData2");
193+
Double result = df.stat().cov("a", "b");
194+
Assert.assertTrue(Math.abs(result) < 1e-6);
195+
}
189196
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,10 +25,11 @@ import org.apache.spark.sql.test.TestSQLContext.implicits._
2525

2626
class DataFrameStatSuite extends FunSuite {
2727

28+
import TestData._
2829
val sqlCtx = TestSQLContext
29-
30+
def toLetter(i: Int): String = (i + 97).toChar.toString
31+
3032
test("Frequent Items") {
31-
def toLetter(i: Int): String = (i + 96).toChar.toString
3233
val rows = Array.tabulate(1000) { i =>
3334
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
3435
}
@@ -44,4 +45,17 @@ class DataFrameStatSuite extends FunSuite {
4445
items2.getSeq[Double](0) should contain (-1.0)
4546

4647
}
48+
49+
test("covariance") {
50+
val rows = Array.tabulate(10)(i => (i, 2.0 * i, toLetter(i)))
51+
val df = sqlCtx.sparkContext.parallelize(rows).toDF("singles", "doubles", "letters")
52+
53+
val results = df.stat.cov("singles", "doubles")
54+
assert(math.abs(results - 55.0 / 3) < 1e-6)
55+
intercept[IllegalArgumentException] {
56+
df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
57+
}
58+
val decimalRes = decimalData.stat.cov("a", "b")
59+
assert(math.abs(decimalRes) < 1e-6)
60+
}
4761
}

0 commit comments

Comments (0)