Skip to content

Commit 285b838

Browse files
committed
addressed comments v2.0
1 parent d10babb commit 285b838

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@ private[sql] object StatFunctions {
7676
s"with dataType ${data.get.dataType} not supported.")
7777
}
7878
val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
79-
df.select(columns:_*).rdd.aggregate(new CovarianceCounter)(
79+
df.select(columns: _*).rdd.aggregate(new CovarianceCounter)(
8080
seqOp = (counter, row) => {
8181
counter.add(row.getDouble(0), row.getDouble(1))
8282
},

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -48,20 +48,35 @@ class DataFrameStatSuite extends FunSuite {
4848
test("pearson correlation") {
4949
val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
5050
val corr1 = df.stat.corr("a", "b", "pearson")
51-
assert(math.abs(corr1 - 1.0) < 1e-6)
51+
assert(math.abs(corr1 - 1.0) < 1e-12)
5252
val corr2 = df.stat.corr("a", "c", "pearson")
53-
assert(math.abs(corr2 + 1.0) < 1e-6)
53+
assert(math.abs(corr2 + 1.0) < 1e-12)
54+
// non-trivial example. To reproduce in python, use:
55+
// >>> from scipy.stats import pearsonr
56+
// >>> import numpy as np
57+
// >>> a = np.array(range(20))
58+
// >>> b = np.array([x * x - 2 * x + 3.5 for x in range(20)])
59+
// >>> pearsonr(a, b)
60+
// (0.95723391394758572, 3.8902121417802199e-11)
61+
// In R, use:
62+
// > a <- 0:19
63+
// > b <- mapply(function(x) x * x - 2 * x + 3.5, a)
64+
// > cor(a, b)
65+
// [1] 0.957233913947585835
66+
val df2 = Seq.tabulate(20)(x => (x, x * x - 2 * x + 3.5)).toDF("a", "b")
67+
val corr3 = df2.stat.corr("a", "b", "pearson")
68+
assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
5469
}
5570

5671
test("covariance") {
5772
val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")
5873

5974
val results = df.stat.cov("singles", "doubles")
60-
assert(math.abs(results - 55.0 / 3) < 1e-6)
75+
assert(math.abs(results - 55.0 / 3) < 1e-12)
6176
intercept[IllegalArgumentException] {
6277
df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
6378
}
6479
val decimalRes = decimalData.stat.cov("a", "b")
65-
assert(math.abs(decimalRes) < 1e-6)
80+
assert(math.abs(decimalRes) < 1e-12)
6681
}
6782
}

0 commit comments

Comments (0)