
Commit a682d06

ready for PR
1 parent 5c1faba commit a682d06

6 files changed, +109 −20 lines changed

python/pyspark/sql/dataframe.py

Lines changed: 24 additions & 0 deletions

@@ -875,6 +875,25 @@ def fillna(self, value, subset=None):
         return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
 
+    def corr(self, col1, col2, method="pearson"):
+        """
+        Calculate the correlation of two columns of a DataFrame as a double value. Currently only
+        supports the Pearson Correlation Coefficient.
+        :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases.
+
+        :param col1: The name of the first column
+        :param col2: The name of the second column
+        :param method: The correlation method. Currently only supports "pearson"
+        """
+        if not isinstance(col1, str):
+            raise ValueError("col1 should be a string.")
+        if not isinstance(col2, str):
+            raise ValueError("col2 should be a string.")
+        if not method == "pearson":
+            raise ValueError("Currently only the calculation of the Pearson Correlation " +
+                             "coefficient is supported.")
+        return self._jdf.stat().corr(col1, col2, method)
+
     def cov(self, col1, col2):
         """
         Calculate the sample covariance for the given columns, specified by their names, as a

@@ -1339,6 +1358,11 @@ class DataFrameStatFunctions(object):
     def __init__(self, df):
         self.df = df
 
+    def corr(self, col1, col2, method="pearson"):
+        return self.df.corr(col1, col2, method)
+
+    corr.__doc__ = DataFrame.corr.__doc__
+
     def cov(self, col1, col2):
         return self.df.cov(col1, col2)
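
As a quick sanity check of the new Python API (not part of the diff): a minimal usage sketch, assuming a Spark 1.x-era shell where a SparkContext named `sc` is already available, as in the tests below.

    from pyspark.sql import Row

    df = sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()

    # DataFrame.corr and DataFrameStatFunctions.corr are aliases:
    print(df.corr("a", "b"))                  # ~1.0 for perfectly linear data
    print(df.stat.corr("a", "b"))             # same result
    print(df.stat.corr("a", "b", "pearson"))  # method defaults to "pearson"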

python/pyspark/sql/tests.py

Lines changed: 6 additions & 0 deletions

@@ -387,6 +387,12 @@ def test_aggregator(self):
         self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
         self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
 
+    def test_corr(self):
+        import math
+        df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF()
+        corr = df.stat.corr("a", "b")
+        self.assertTrue(abs(corr - 0.95734012) < 1e-6)
+
     def test_cov(self):
         df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
         cov = df.stat.cov("a", "b")
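
The expected value 0.95734012 in test_corr can be reproduced outside Spark (not part of the diff) by evaluating the Pearson formula directly in plain Python:

    import math

    xs = list(range(10))
    ys = [math.sqrt(i) for i in xs]
    mx, my = sum(xs) / 10, sum(ys) / 10
    ck = sum((x - mx) * (y - my) for x, y in zip(xs, ys))  # co-moment
    mkx = sum((x - mx) ** 2 for x in xs)                   # squared deviations, x
    mky = sum((y - my) ** 2 for y in ys)                   # squared deviations, y
    print(ck / math.sqrt(mkx * mky))                       # ~0.95734012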

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 26 additions & 0 deletions

@@ -27,6 +27,32 @@ import org.apache.spark.sql.execution.stat._
 @Experimental
 final class DataFrameStatFunctions private[sql](df: DataFrame) {
 
+  /**
+   * Calculate the correlation of two columns of a DataFrame. Currently only supports the Pearson
+   * Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
+   * MLlib's Statistics.
+   *
+   * @param col1 the name of the column
+   * @param col2 the name of the column to calculate the correlation against
+   * @return The Pearson Correlation Coefficient as a Double.
+   */
+  def corr(col1: String, col2: String, method: String): Double = {
+    assert(method == "pearson", "Currently only the calculation of the Pearson Correlation " +
+      "coefficient is supported.")
+    StatFunctions.pearsonCorrelation(df, Seq(col1, col2))
+  }
+
+  /**
+   * Java Friendly implementation to calculate the Pearson correlation coefficient of two columns.
+   *
+   * @param col1 the name of the column
+   * @param col2 the name of the column to calculate the correlation against
+   * @return The Pearson Correlation Coefficient as a Double.
+   */
+  def corr(col1: String, col2: String): Double = {
+    corr(col1, col2, "pearson")
+  }
+
   /**
    * Finding frequent items for columns, possibly with false positives. Using the
    * frequent element count algorithm described in

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 38 additions & 20 deletions

@@ -23,43 +23,51 @@ import org.apache.spark.sql.types.{DoubleType, NumericType}
 
 private[sql] object StatFunctions {
 
+  /** Calculate the Pearson Correlation Coefficient for the given columns */
+  private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = {
+    val counts = collectStatisticalData(df, cols)
+    counts.Ck / math.sqrt(counts.MkX * counts.MkY)
+  }
+
   /** Helper class to simplify tracking and merging counts. */
   private class CovarianceCounter extends Serializable {
-    var xAvg = 0.0
-    var yAvg = 0.0
-    var Ck = 0.0
-    var count = 0L
+    var xAvg = 0.0 // the mean of all examples seen so far in col1
+    var yAvg = 0.0 // the mean of all examples seen so far in col2
+    var Ck = 0.0 // the co-moment after k examples
+    var MkX = 0.0 // sum of squares of differences from the (current) mean for col1
+    var MkY = 0.0 // sum of squares of differences from the (current) mean for col2
+    var count = 0L // count of observed examples
     // add an example to the calculation
     def add(x: Double, y: Double): this.type = {
-      val oldX = xAvg
+      val deltaX = x - xAvg
+      val deltaY = y - yAvg
       count += 1
-      xAvg += (x - xAvg) / count
-      yAvg += (y - yAvg) / count
-      Ck += (y - yAvg) * (x - oldX)
+      xAvg += deltaX / count
+      yAvg += deltaY / count
+      Ck += deltaX * (y - yAvg)
+      MkX += deltaX * (x - xAvg)
+      MkY += deltaY * (y - yAvg)
       this
     }
     // merge counters from other partitions. Formula can be found at:
-    // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Covariance
+    // http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
     def merge(other: CovarianceCounter): this.type = {
       val totalCount = count + other.count
-      Ck += other.Ck +
-        (xAvg - other.xAvg) * (yAvg - other.yAvg) * count / totalCount * other.count
+      val deltaX = xAvg - other.xAvg
+      val deltaY = yAvg - other.yAvg
+      Ck += other.Ck + deltaX * deltaY * count / totalCount * other.count
       xAvg = (xAvg * count + other.xAvg * other.count) / totalCount
       yAvg = (yAvg * count + other.yAvg * other.count) / totalCount
+      MkX += other.MkX + deltaX * deltaX * count / totalCount * other.count
+      MkY += other.MkY + deltaY * deltaY * count / totalCount * other.count
       count = totalCount
       this
     }
     // return the sample covariance for the observed examples
     def cov: Double = Ck / (count - 1)
   }
 
-  /**
-   * Calculate the covariance of two numerical columns of a DataFrame.
-   * @param df The DataFrame
-   * @param cols the column names
-   * @return the covariance of the two columns.
-   */
-  private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
+  private def collectStatisticalData(df: DataFrame, cols: Seq[String]): CovarianceCounter = {
     require(cols.length == 2, "Currently cov supports calculating the covariance " +
       "between two columns.")
     cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) =>

@@ -68,13 +76,23 @@ private[sql] object StatFunctions {
         s"with dataType ${data.get.dataType} not supported.")
     }
     val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
-    val counts = df.select(columns:_*).rdd.aggregate(new CovarianceCounter)(
+    df.select(columns:_*).rdd.aggregate(new CovarianceCounter)(
      seqOp = (counter, row) => {
        counter.add(row.getDouble(0), row.getDouble(1))
      },
      combOp = (baseCounter, other) => {
        baseCounter.merge(other)
-      })
+      })
+  }
+
+  /**
+   * Calculate the covariance of two numerical columns of a DataFrame.
+   * @param df The DataFrame
+   * @param cols the column names
+   * @return the covariance of the two columns.
+   */
+  private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = {
+    val counts = collectStatisticalData(df, cols)
     counts.cov
   }
 }
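
The CovarianceCounter above computes everything in a single pass and merges partial results across partitions. A standalone Python sketch of the same logic (not part of the diff; names are mine) shows how the co-moment Ck and the squared-deviation sums MkX/MkY combine into both the sample covariance and the Pearson correlation:

    import math

    class Counter:
        def __init__(self):
            self.x_avg = self.y_avg = self.ck = self.mk_x = self.mk_y = 0.0
            self.count = 0

        def add(self, x, y):
            # Welford-style update: deltas use the pre-update means.
            dx, dy = x - self.x_avg, y - self.y_avg
            self.count += 1
            self.x_avg += dx / self.count
            self.y_avg += dy / self.count
            self.ck += dx * (y - self.y_avg)    # co-moment
            self.mk_x += dx * (x - self.x_avg)  # squared deviations, x
            self.mk_y += dy * (y - self.y_avg)  # squared deviations, y
            return self

        def merge(self, other):
            # Pairwise merge formula from the Wikipedia page cited above.
            total = self.count + other.count
            dx, dy = self.x_avg - other.x_avg, self.y_avg - other.y_avg
            self.ck += other.ck + dx * dy * self.count / total * other.count
            self.x_avg = (self.x_avg * self.count + other.x_avg * other.count) / total
            self.y_avg = (self.y_avg * self.count + other.y_avg * other.count) / total
            self.mk_x += other.mk_x + dx * dx * self.count / total * other.count
            self.mk_y += other.mk_y + dy * dy * self.count / total * other.count
            self.count = total
            return self

        def cov(self):
            return self.ck / (self.count - 1)

        def corr(self):
            return self.ck / math.sqrt(self.mk_x * self.mk_y)

    # Split the data across two "partitions" and merge, as aggregate() would.
    left, right = Counter(), Counter()
    for i in range(5):
        left.add(i, 2.0 * i)
    for i in range(5, 10):
        right.add(i, 2.0 * i)
    print(left.merge(right).corr())  # ~1.0 for perfectly linear data

Note that dividing Ck, MkX, and MkY by (count - 1) would give the sample covariance and variances; the factor cancels in the correlation, which is why corr skips it.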

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 7 additions & 0 deletions

@@ -187,6 +187,13 @@ public void testFrequentItems() {
     Assert.assertTrue(results.collect()[0].getSeq(0).contains(1));
   }
 
+  @Test
+  public void testCorrelation() {
+    DataFrame df = context.table("testData2");
+    Double pearsonCorr = df.stat().corr("a", "b", "pearson");
+    Assert.assertTrue(Math.abs(pearsonCorr) < 1e-6);
+  }
+
   @Test
   public void testCovariance() {
     DataFrame df = context.table("testData2");

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 8 additions & 0 deletions

@@ -43,7 +43,15 @@ class DataFrameStatSuite extends FunSuite {
     val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
     val items2 = singleColResults.collect().head
     items2.getSeq[Double](0) should contain (-1.0)
+  }
 
+  test("pearson correlation") {
+    val df = sqlCtx.sparkContext.parallelize(
+      Array.tabulate(10)(i => (i, 2 * i, i * -1.0))).toDF("a", "b", "c")
+    val corr1 = df.stat.corr("a", "b", "pearson")
+    assert(math.abs(corr1 - 1.0) < 1e-6)
+    val corr2 = df.stat.corr("a", "c", "pearson")
+    assert(math.abs(corr2 + 1.0) < 1e-6)
   }
 
   test("covariance") {
