
Commit d10babb

addressed comments v0.2
1 parent 4b74b24 commit d10babb

3 files changed, 10 insertions(+), 10 deletions(-)

python/pyspark/sql/dataframe.py (5 additions, 3 deletions)

@@ -875,9 +875,9 @@ def fillna(self, value, subset=None):

         return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)

-    def corr(self, col1, col2, method="pearson"):
+    def corr(self, col1, col2, method=None):
         """
-        Calculate the correlation of two columns of a DataFrame as a double value. Currently only
+        Calculates the correlation of two columns of a DataFrame as a double value. Currently only
         supports the Pearson Correlation Coefficient.
         :func:`DataFrame.corr` and :func:`DataFrameStatFunctions.corr` are aliases.

@@ -889,6 +889,8 @@ def corr(self, col1, col2, method="pearson"):
             raise ValueError("col1 should be a string.")
         if not isinstance(col2, str):
             raise ValueError("col2 should be a string.")
+        if not method:
+            method = "pearson"
         if not method == "pearson":
             raise ValueError("Currently only the calculation of the Pearson Correlation " +
                              "coefficient is supported.")
@@ -1378,7 +1380,7 @@ class DataFrameStatFunctions(object):
     def __init__(self, df):
         self.df = df

-    def corr(self, col1, col2, method="pearson"):
+    def corr(self, col1, col2, method=None):
         return self.df.corr(col1, col2, method)

     corr.__doc__ = DataFrame.corr.__doc__
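A note on the Python change above (my reading, not part of the commit): moving the default from the signature to a None sentinel that is normalized inside the body means the DataFrameStatFunctions.corr wrapper can simply forward whatever it received, and the actual default value lives in exactly one place. A minimal sketch of that pattern with a hypothetical helper, outside of Spark:

# Hypothetical helper (not PySpark code) illustrating the None-default pattern.
def resolve_corr_method(method=None):
    if not method:
        method = "pearson"          # the single place the default is chosen
    if not method == "pearson":
        raise ValueError("Currently only the calculation of the Pearson "
                         "Correlation coefficient is supported.")
    return method

print(resolve_corr_method())            # 'pearson'
print(resolve_corr_method("pearson"))   # 'pearson'
# resolve_corr_method("spearman")       # would raise ValueError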

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala (1 addition, 1 deletion)

@@ -37,7 +37,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
    * @return The Pearson Correlation Coefficient as a Double.
    */
   def corr(col1: String, col2: String, method: String): Double = {
-    assert(method == "pearson", "Currently only the calculation of the Pearson Correlation " +
+    require(method == "pearson", "Currently only the calculation of the Pearson Correlation " +
       "coefficient is supported.")
     StatFunctions.pearsonCorrelation(df, Seq(col1, col2))
   }
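The assert-to-require swap above is a one-word change with a real behavioral difference in Scala: require throws IllegalArgumentException and is meant for validating caller-supplied arguments, while assert throws AssertionError, documents an internal invariant, and can be elided by the compiler (for example with -Xdisable-assertions). Presumably that is the motivation here, since an unsupported method string is a caller error. The Python side of the patch follows the same principle by raising ValueError rather than using assert statements, which CPython strips under python -O. A small illustration of that principle in Python, with made-up names (not Spark code):

# Illustration only: why explicit exceptions beat assert for argument checks.
def check_with_assert(method):
    assert method == "pearson", "only pearson is supported"    # vanishes under `python -O`

def check_with_raise(method):
    if method != "pearson":
        raise ValueError("only pearson is supported")          # always enforced

check_with_raise("pearson")      # fine
# check_with_raise("kendall")    # raises ValueError, even under `python -O`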

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala (4 additions, 6 deletions)

@@ -30,10 +30,10 @@ class DataFrameStatSuite extends FunSuite {
   def toLetter(i: Int): String = (i + 97).toChar.toString

   test("Frequent Items") {
-    val rows = Array.tabulate(1000) { i =>
+    val rows = Seq.tabulate(1000) { i =>
       if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
     }
-    val df = sqlCtx.sparkContext.parallelize(rows).toDF("numbers", "letters", "negDoubles")
+    val df = rows.toDF("numbers", "letters", "negDoubles")

     val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
     val items = results.collect().head
@@ -46,17 +46,15 @@ class DataFrameStatSuite extends FunSuite {
   }

   test("pearson correlation") {
-    val df = sqlCtx.sparkContext.parallelize(
-      Array.tabulate(10)(i => (i, 2 * i, i * -1.0))).toDF("a", "b", "c")
+    val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
     val corr1 = df.stat.corr("a", "b", "pearson")
     assert(math.abs(corr1 - 1.0) < 1e-6)
     val corr2 = df.stat.corr("a", "c", "pearson")
     assert(math.abs(corr2 + 1.0) < 1e-6)
   }

   test("covariance") {
-    val rows = Array.tabulate(10)(i => (i, 2.0 * i, toLetter(i)))
-    val df = sqlCtx.sparkContext.parallelize(rows).toDF("singles", "doubles", "letters")
+    val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")

     val results = df.stat.cov("singles", "doubles")
     assert(math.abs(results - 55.0 / 3) < 1e-6)
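The test rewrites above lean on the toDF conversion that the SQLContext implicits provide for local Seqs of tuples, so the explicit sqlCtx.sparkContext.parallelize(...) step is no longer needed. The same kind of check can be written from Python without any implicits, since SQLContext.createDataFrame accepts a local list of tuples plus column names directly. A sketch assuming the Spark 1.x-era pyspark API that this patch targets (app name and column names are made up):

from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext("local[2]", "corr-sketch")
sqlContext = SQLContext(sc)

# Build a small DataFrame straight from local data, mirroring the Scala suite.
df = sqlContext.createDataFrame([(i, 2 * i, i * -1.0) for i in range(10)], ["a", "b", "c"])

print(df.corr("a", "b"))        # method defaults to "pearson" -> approximately 1.0
print(df.stat.corr("a", "c"))   # same call via the stat namespace -> approximately -1.0

sc.stop()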
