Skip to content

Commit a85a889

Browse files
committed
moving code
1 parent a2d7e2d commit a85a889

File tree

3 files changed

+64
-33
lines changed

3 files changed

+64
-33
lines changed

mllib/src/main/scala/org/apache/spark/ml/stat/Statistics.scala

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.ml.stat
1919

2020
import scala.collection.JavaConverters._
2121

22-
import org.apache.spark.annotation.Since
22+
import org.apache.spark.annotation.{Experimental, Since}
2323
import org.apache.spark.ml.linalg.{SQLDataTypes, Vector}
2424
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
2525
import org.apache.spark.mllib.stat.{Statistics => OldStatistics}
@@ -30,17 +30,18 @@ import org.apache.spark.sql.types.{StructField, StructType}
3030
* API for statistical functions in MLlib, compatible with DataFrames and Datasets.
3131
*
3232
* The functions in this package generalize the functions in [[org.apache.spark.sql.Dataset.stat]]
33-
* to MLlib's Vector types.
33+
* to spark.ml's Vector types.
3434
*/
3535
@Since("2.2.0")
36+
@Experimental
3637
object Statistics {
3738

3839
/**
3940
* Compute the correlation matrix for the input Dataset of Vectors using the specified method.
4041
* Methods currently supported: `pearson` (default), `spearman`.
4142
*
42-
* @param dataset a dataset or a dataframe
43-
* @param column the name of the column of vectors for which the correlation coefficient needs
43+
* @param dataset A dataset or a dataframe
44+
* @param column The name of the column of vectors for which the correlation coefficient needs
4445
* to be computed. This must be a column of the dataset, and it must contain
4546
* Vector objects.
4647
* @param method String specifying the method to use for computing correlation.
@@ -63,12 +64,10 @@ object Statistics {
6364
* which is fairly costly. Cache the input Dataset before calling corr with `method = "spearman"` to
6465
* avoid recomputing the common lineage.
6566
*/
66-
// TODO: how do we handle missing values?
6767
@Since("2.2.0")
6868
def corr(dataset: Dataset[_], column: String, method: String): DataFrame = {
6969
val rdd = dataset.select(column).rdd.map {
7070
case Row(v: Vector) => OldVectors.fromML(v)
71-
// case r: GenericRowWithSchema => OldVectors.fromML(r.getAs[Vector](0))
7271
}
7372
val oldM = OldStatistics.corr(rdd, method)
7473
val name = s"$method($column)"
@@ -78,8 +77,8 @@ object Statistics {
7877

7978
/**
8079
* Compute the correlation matrix for the input Dataset of Vectors.
81-
* @param dataset a dataset or dataframe
82-
* @param column a column of this dataset
80+
* @param dataset A dataset or dataframe
81+
* @param column A column of this dataset
8382
* @return
8483
*/
8584
@Since("2.2.0")

mllib/src/test/scala/org/apache/spark/ml/stat/StatisticsSuite.scala

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,20 @@
1717

1818
package org.apache.spark.ml.stat
1919

20-
import breeze.linalg.{DenseMatrix => BDM, Matrix => BM}
20+
import breeze.linalg.{DenseMatrix => BDM}
2121

2222
import org.apache.spark.SparkFunSuite
2323
import org.apache.spark.internal.Logging
2424
import org.apache.spark.ml.linalg.Matrix
2525
import org.apache.spark.ml.linalg.Vectors
26+
import org.apache.spark.ml.util.LinalgUtils
2627
import org.apache.spark.mllib.util.MLlibTestSparkContext
2728
import org.apache.spark.sql.{DataFrame, Row}
2829

2930

3031
class StatisticsSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
3132

32-
import StatisticsSuite._
33+
import LinalgUtils._
3334

3435
val xData = Array(1.0, 0.0, -2.0)
3536
val yData = Array(4.0, 5.0, 3.0)
@@ -77,26 +78,3 @@ class StatisticsSuite extends SparkFunSuite with MLlibTestSparkContext with Logg
7778
}
7879

7980
}
80-
81-
82-
object StatisticsSuite extends Logging {
83-
84-
def approxEqual(v1: Double, v2: Double, threshold: Double = 1e-6): Boolean = {
85-
if (v1.isNaN) {
86-
v2.isNaN
87-
} else {
88-
math.abs(v1 - v2) <= threshold
89-
}
90-
}
91-
92-
def matrixApproxEqual(A: BM[Double], B: BM[Double], threshold: Double = 1e-6): Boolean = {
93-
for (i <- 0 until A.rows; j <- 0 until A.cols) {
94-
if (!approxEqual(A(i, j), B(i, j), threshold)) {
95-
logInfo("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j))
96-
return false
97-
}
98-
}
99-
true
100-
}
101-
102-
}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.util
19+
20+
import breeze.linalg.{Matrix => BM}
21+
22+
import org.apache.spark.internal.Logging
23+
24+
/**
25+
* Utility test methods for linear algebra.
26+
*/
27+
object LinalgUtils extends Logging {
28+
29+
30+
/**
31+
* Returns true if two numbers are equal up to some tolerance.
32+
*/
33+
def approxEqual(v1: Double, v2: Double, threshold: Double = 1e-6): Boolean = {
34+
if (v1.isNaN) {
35+
v2.isNaN
36+
} else {
37+
math.abs(v1 - v2) <= threshold
38+
}
39+
}
40+
41+
/**
42+
* Returns true if two matrices are equal coefficient-wise up to some tolerance.
43+
*/
44+
def matrixApproxEqual(A: BM[Double], B: BM[Double], threshold: Double = 1e-6): Boolean = {
45+
for (i <- 0 until A.rows; j <- 0 until A.cols) {
46+
if (!approxEqual(A(i, j), B(i, j), threshold)) {
47+
logInfo("i, j = " + i + ", " + j + " actual: " + A(i, j) + " expected:" + B(i, j))
48+
return false
49+
}
50+
}
51+
true
52+
}
53+
54+
}

0 commit comments

Comments
 (0)