Skip to content

Commit b98bb18

Browse files
committed
add comments
1 parent 0846e07 commit b98bb18

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

mllib/src/main/scala/org/apache/spark/mllib/stat/correlation/SpearmanCorrelation.scala

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@
1717

1818
package org.apache.spark.mllib.stat.correlation
1919

20-
import org.apache.spark.storage.StorageLevel
21-
2220
import scala.collection.mutable.ArrayBuffer
2321

24-
import org.apache.spark.{Logging, HashPartitioner}
22+
import org.apache.spark.Logging
2523
import org.apache.spark.SparkContext._
26-
import org.apache.spark.mllib.linalg.{Vectors, DenseVector, Matrix, Vector}
27-
import org.apache.spark.rdd.{CoGroupedRDD, RDD}
24+
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors}
25+
import org.apache.spark.rdd.RDD
26+
import org.apache.spark.storage.StorageLevel
2827

2928
/**
3029
* Compute Spearman's correlation for two RDDs of the type RDD[Double] or the correlation matrix
@@ -45,18 +44,18 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging {
4544
/**
4645
* Compute Spearman's correlation matrix S, for the input matrix, where S(i, j) is the
4746
* correlation between column i and j.
48-
*
49-
* Input RDD[Vector] should be cached or checkpointed if possible since it would be split into
50-
* numCol RDD[Double]s, each of which sorted, and the joined back into a single RDD[Vector].
5147
*/
5248
override def computeCorrelationMatrix(X: RDD[Vector]): Matrix = {
53-
val transposed = X.zipWithUniqueId().flatMap { case (vec, uid) =>
49+
// ((columnIndex, value), rowId)
50+
val colBased = X.zipWithUniqueId().flatMap { case (vec, uid) =>
5451
vec.toArray.view.zipWithIndex.map { case (v, j) =>
5552
((j, v), uid)
5653
}
57-
}.persist(StorageLevel.MEMORY_AND_DISK)
58-
val sorted = transposed.sortByKey().persist(StorageLevel.MEMORY_AND_DISK)
59-
val ranked = sorted.zipWithIndex().mapPartitions { iter =>
54+
}.persist(StorageLevel.MEMORY_AND_DISK) // used by sortByKey
55+
// global sort by (columnIndex, value)
56+
val sorted = colBased.sortByKey().persist(StorageLevel.MEMORY_AND_DISK) // used by zipWithIndex
57+
// Assign global ranks (using average ranks for tied values)
58+
val globalRanks = sorted.zipWithIndex().mapPartitions { iter =>
6059
var preCol = -1
6160
var preVal = Double.NaN
6261
var startRank = -1.0
@@ -85,14 +84,15 @@ private[stat] object SpearmanCorrelation extends Correlation with Logging {
8584
flush()
8685
}
8786
}
88-
val ranks = tied.groupByKey().map { case (uid, iter) =>
89-
val values = iter.toSeq.sortBy(_._1).map(_._2).toArray
90-
println(values.toSeq)
91-
Vectors.dense(values)
87+
// Replace values in the input matrix by their ranks compared with values in the same column.
88+
// Note that shifting all ranks in a column by a constant value doesn't affect result.
89+
val groupedRanks = globalRanks.groupByKey().map { case (uid, iter) =>
90+
// sort by column index and then convert values to a vector
91+
Vectors.dense(iter.toSeq.sortBy(_._1).map(_._2).toArray)
9292
}
93-
val corrMatrix = PearsonCorrelation.computeCorrelationMatrix(ranks)
93+
val corrMatrix = PearsonCorrelation.computeCorrelationMatrix(groupedRanks)
9494

95-
transposed.unpersist(blocking = false)
95+
colBased.unpersist(blocking = false)
9696
sorted.unpersist(blocking = false)
9797

9898
corrMatrix

0 commit comments

Comments
 (0)