
Commit e3b0b85

committed: addressed comments v0.1
1 parent a7115f1

File tree: 4 files changed, +16 -25 lines


python/pyspark/sql/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -53,8 +53,9 @@
 
 from pyspark.sql.types import Row
 from pyspark.sql.context import SQLContext, HiveContext
-from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions
+from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions
 
 __all__ = [
-    'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row', 'DataFrameNaFunctions'
+    'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
+    'DataFrameNaFunctions', 'DataFrameStatFunctions'
 ]
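
With DataFrameStatFunctions added to __all__, the class becomes part of the package's public surface. A hypothetical one-line smoke test (assuming this build of PySpark is on the path):

# The new name is now importable from the package root alongside the others.
from pyspark.sql import DataFrame, DataFrameNaFunctions, DataFrameStatFunctions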

python/pyspark/sql/dataframe.py

Lines changed: 4 additions & 11 deletions
@@ -877,12 +877,11 @@ def fillna(self, value, subset=None):
 
     def cov(self, col1, col2):
         """
-        Calculate the covariance for the given columns, specified by their names.
-        alias for ``stat.cov()``.
+        Calculate the covariance for the given columns, specified by their names as a double value.
+        :func:`DataFrame.cov` and :func:`DataFrameStatFunctions.cov` are aliases.
 
         :param col1: The name of the first column
         :param col2: The name of the second column
-        :return: the covariance of the columns
         """
         return self.stat.cov(col1, col2)
 
@@ -1337,19 +1336,13 @@ def __init__(self, df):
         self.df = df
 
     def cov(self, col1, col2):
-        """
-        Calculate the covariance for the given columns, specified by their names.
-
-        :param col1: The name of the first column
-        :param col2: The name of the second column
-        :return: the covariance of the columns
-        """
         if not isinstance(col1, str):
             raise ValueError("col1 should be a string.")
         if not isinstance(col2, str):
             raise ValueError("col2 should be a string.")
         return self.df._jdf.stat().cov(col1, col2)
-
+
+    cov.__doc__ = DataFrame.cov.__doc__
 
 def _test():
     import doctest
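
The cov.__doc__ = DataFrame.cov.__doc__ assignment keeps one source of truth for the documentation: DataFrame.cov stays a thin delegate to stat.cov(), and help() shows identical text for both entry points. A standalone sketch of the pattern (class names are illustrative, not from PySpark):

class Frame(object):
    def cov(self, col1, col2):
        """Calculate the covariance for the given columns.

        Frame.cov and Stats.cov are aliases.
        """
        return self.stat.cov(col1, col2)

class Stats(object):
    def cov(self, col1, col2):
        return 0.0  # the real computation would live here

    # Borrow the public docstring so both entry points document identically.
    cov.__doc__ = Frame.cov.__doc__

print(Stats.cov.__doc__ == Frame.cov.__doc__)  # True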

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 9 additions & 10 deletions
@@ -17,8 +17,9 @@
 
 package org.apache.spark.sql.execution.stat
 
+import org.apache.spark.sql.catalyst.expressions.Cast
 import org.apache.spark.sql.{Column, DataFrame}
-import org.apache.spark.sql.types.NumericType
+import org.apache.spark.sql.types.{DoubleType, NumericType}
 
 private[sql] object StatFunctions {
 
@@ -29,15 +30,12 @@ private[sql] object StatFunctions {
     var Ck = 0.0
     var count = 0
     // add an example to the calculation
-    def add(x: Number, y: Number): this.type = {
+    def add(x: Double, y: Double): this.type = {
       val oldX = xAvg
-      val otherX = x.doubleValue()
-      val otherY = y.doubleValue()
       count += 1
-      xAvg += (otherX - xAvg) / count
-      yAvg += (otherY - yAvg) / count
-      println(oldX)
-      Ck += (otherY - yAvg) * (otherX - oldX)
+      xAvg += (x - xAvg) / count
+      yAvg += (y - yAvg) / count
+      Ck += (y - yAvg) * (x - oldX)
       this
     }
     // merge counters from other partitions
@@ -68,9 +66,10 @@ private[sql] object StatFunctions {
       require(data.get.dataType.isInstanceOf[NumericType], "Covariance calculation for columns " +
        s"with dataType ${data.get.dataType} not supported.")
     }
-    val counts = df.select(cols.map(Column(_)):_*).rdd.aggregate(new CovarianceCounter)(
+    val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType)))
+    val counts = df.select(columns:_*).rdd.aggregate(new CovarianceCounter)(
       seqOp = (counter, row) => {
-        counter.add(row.getAs[Number](0), row.getAs[Number](1))
+        counter.add(row.getDouble(0), row.getDouble(1))
       },
       combOp = (baseCounter, other) => {
         baseCounter.merge(other)
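
The rewritten add() is the textbook single-pass co-moment update: stash the previous x mean, advance both running means, then accumulate Ck from the updated y mean and the old x mean. Casting the input columns to DoubleType up front is what lets the seqOp call row.getDouble directly instead of boxing through Number. A minimal standalone Python sketch of the same arithmetic (an illustration, not the Spark class; the merge body is the standard pairwise formula and is an assumption here, since the diff only shows the "// merge counters from other partitions" comment):

# Standalone sketch of the CovarianceCounter update (illustration only).
class CovarianceCounter(object):
    def __init__(self):
        self.xAvg = 0.0  # running mean of the first column
        self.yAvg = 0.0  # running mean of the second column
        self.Ck = 0.0    # running co-moment: sum of (x - xAvg) * (y - yAvg)
        self.count = 0   # observations folded in so far

    def add(self, x, y):
        # Same order of operations as the Scala add(): remember the old x
        # mean, bump the count, update both means, then update Ck from the
        # *updated* yAvg and the *old* xAvg.
        old_x = self.xAvg
        self.count += 1
        self.xAvg += (x - self.xAvg) / self.count
        self.yAvg += (y - self.yAvg) / self.count
        self.Ck += (y - self.yAvg) * (x - old_x)
        return self

    def merge(self, other):
        # Pairwise merge for per-partition counters -- assumed standard
        # formula; the body is not shown in this hunk.
        total = float(self.count + other.count)
        self.Ck += other.Ck + ((self.xAvg - other.xAvg) *
                               (self.yAvg - other.yAvg) *
                               self.count * other.count / total)
        self.xAvg = (self.xAvg * self.count + other.xAvg * other.count) / total
        self.yAvg = (self.yAvg * self.count + other.yAvg * other.count) / total
        self.count = int(total)
        return self

c = CovarianceCounter()
for i in range(10):
    c.add(float(i), 2.0 * i)  # the same pairs the test suite tabulates below
print(c.Ck / c.count)         # ~16.5, within the 1e-6 tolerance the test uses

Dividing Ck by count yields the population covariance, which is what the updated test below expects (16.5 for these pairs).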

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 0 additions & 2 deletions
@@ -49,10 +49,8 @@ class DataFrameStatSuite extends FunSuite {
   test("covariance") {
     val rows = Array.tabulate(10)(i => (i, 2.0 * i, toLetter(i)))
     val df = sqlCtx.sparkContext.parallelize(rows).toDF("singles", "doubles", "letters")
-    df.show()
 
     val results = df.stat.cov("singles", "doubles")
-    println(results)
     assert(math.abs(results - 16.5) < 1e-6)
     intercept[IllegalArgumentException] {
       df.stat.cov("singles", "letters") // doesn't accept non-numerical dataTypes
