Skip to content

Commit 5ef006f

Browse files
committed
[SPARK-6756] [MLLIB] add toSparse, toDense, numActives, numNonzeros, and compressed to Vector
Add `compressed` to `Vector` with some other methods: `numActives`, `numNonzeros`, `toSparse`, and `toDense`. jkbradley Author: Xiangrui Meng <[email protected]> Closes apache#5756 from mengxr/SPARK-6756 and squashes the following commits: 8d4ecbd [Xiangrui Meng] address comment and add mima excludes da54179 [Xiangrui Meng] add toSparse, toDense, numActives, numNonzeros, and compressed to Vector
1 parent a8aeadb commit 5ef006f

File tree

3 files changed

+149
-0
lines changed

3 files changed

+149
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,40 @@ sealed trait Vector extends Serializable {
116116
* with type `Double`.
117117
*/
118118
private[spark] def foreachActive(f: (Int, Double) => Unit)
119+
120+
/**
 * Number of active entries. An "active entry" is an element which is explicitly stored,
 * regardless of its value, so an explicitly stored zero still counts as active.
 * Note that inactive entries have value 0.
 *
 * @return the number of explicitly stored entries
 */
124+
def numActives: Int
125+
126+
/**
 * Number of nonzero elements. This scans all active values and counts the nonzeros.
 *
 * @return the number of stored values that are not equal to zero
 */
128+
def numNonzeros: Int
130+
131+
/**
 * Converts this vector to a sparse vector with all explicit zeros removed.
 *
 * @return a sparse vector that stores only the nonzero entries of this vector
 */
133+
def toSparse: SparseVector
135+
136+
/**
 * Builds a dense representation of this vector from the contents of `toArray`.
 *
 * @return a [[DenseVector]] wrapping the array returned by `toArray`
 */
def toDense: DenseVector = {
  val denseValues = this.toArray
  new DenseVector(denseValues)
}
140+
141+
/**
 * Returns this vector in whichever of the dense or sparse formats uses less storage.
 *
 * @return the result of `toSparse` when the sparse estimate is strictly smaller,
 *         otherwise the result of `toDense`
 */
def compressed: Vector = {
  // Storage estimates: dense = 8 * size + 8 bytes, sparse = 12 * nnz + 20 bytes.
  // sparse < dense  <=>  12 * nnz + 12 < 8 * size  <=>  1.5 * (nnz + 1.0) < size
  val nonzeroCount = numNonzeros
  val sparseIsSmaller = 1.5 * (nonzeroCount + 1.0) < size
  if (sparseIsSmaller) toSparse else toDense
}
119153
}
120154

121155
/**
@@ -525,6 +559,34 @@ class DenseVector(val values: Array[Double]) extends Vector {
525559
}
526560
result
527561
}
562+
563+
// A dense vector reports every one of its `size` elements as active.
override def numActives: Int = size
564+
565+
/** Counts the entries of `values` that are not equal to 0.0. */
override def numNonzeros: Int = {
  // Explicit loop kept for speed; the result equals values.count(_ != 0.0).
  val total = values.length
  var count = 0
  var idx = 0
  while (idx < total) {
    if (values(idx) != 0.0) {
      count += 1
    }
    idx += 1
  }
  count
}
575+
576+
/**
 * Converts this dense vector to a sparse vector that stores only its nonzero entries.
 *
 * @return a sparse vector of the same size whose indices and values cover exactly the
 *         entries of this vector that are not equal to 0.0
 */
override def toSparse: SparseVector = {
  // The output arrays are sized by numNonzeros, so the filter below must use the very
  // same predicate (`!= 0.0`) to fill them exactly.
  val nnz = numNonzeros
  val ii = new Array[Int](nnz)
  val vv = new Array[Double](nnz)
  var k = 0
  foreachActive { (i, v) =>
    // Compare against the Double literal, as numNonzeros does, rather than an Int 0.
    if (v != 0.0) {
      ii(k) = i
      vv(k) = v
      k += 1
    }
  }
  new SparseVector(size, ii, vv)
}
528590
}
529591

530592
object DenseVector {
@@ -602,6 +664,37 @@ class SparseVector(
602664
}
603665
result
604666
}
667+
668+
// Every entry held in `values` is active, explicitly stored zeros included.
override def numActives: Int = values.length
669+
670+
/** Counts the stored entries of `values` that are not equal to 0.0. */
override def numNonzeros: Int = {
  val total = values.length
  var count = 0
  var idx = 0
  while (idx < total) {
    if (values(idx) != 0.0) {
      count += 1
    }
    idx += 1
  }
  count
}
679+
680+
/**
 * Returns a sparse vector without explicitly stored zeros.
 *
 * @return `this` when every stored value is already nonzero, otherwise a new
 *         [[SparseVector]] of the same size holding only the nonzero entries
 */
override def toSparse: SparseVector = {
  val nonzeroCount = numNonzeros
  if (nonzeroCount != numActives) {
    // At least one explicit zero is stored: copy the nonzero entries into new arrays.
    val keptIndices = new Array[Int](nonzeroCount)
    val keptValues = new Array[Double](nonzeroCount)
    var next = 0
    foreachActive { (index, value) =>
      if (value != 0.0) {
        keptIndices(next) = index
        keptValues(next) = value
        next += 1
      }
    }
    new SparseVector(size, keptIndices, keptValues)
  } else {
    // Nothing to strip, so no copy is needed.
    this
  }
}
605698
}
606699

607700
object SparseVector {

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,4 +270,48 @@ class VectorsSuite extends FunSuite {
270270
assert(Vectors.norm(sv, 3.7) ~== math.pow(sv.toArray.foldLeft(0.0)((a, v) =>
271271
a + math.pow(math.abs(v), 3.7)), 1.0 / 3.7) relTol 1E-8)
272272
}
273+
274+
// Title corrected to name the method actually under test (`numActives`).
test("Vector numActives and numNonzeros") {
  // Dense: every slot is active, but only the nonzero ones count towards numNonzeros.
  val dv = Vectors.dense(0.0, 2.0, 3.0, 0.0)
  assert(dv.numActives === 4)
  assert(dv.numNonzeros === 2)

  // Sparse: the explicitly stored 0.0 at index 0 is active but not nonzero.
  val sv = Vectors.sparse(4, Array(0, 1, 2), Array(0.0, 2.0, 3.0))
  assert(sv.numActives === 3)
  assert(sv.numNonzeros === 2)
}
283+
284+
test("Vector toSparse and toDense") {
  // Dense input: toDense must compare equal, toSparse must keep only the two nonzeros.
  val denseInput = Vectors.dense(0.0, 2.0, 3.0, 0.0)
  assert(denseInput.toDense === denseInput)
  val denseAsSparse = denseInput.toSparse
  assert(denseAsSparse.numActives === 2)
  assert(denseAsSparse === denseInput)

  // Sparse input with an explicit zero at index 0: toSparse must drop that entry.
  val sparseInput = Vectors.sparse(4, Array(0, 1, 2), Array(0.0, 2.0, 3.0))
  assert(sparseInput.toDense === sparseInput)
  val sparseCleaned = sparseInput.toSparse
  assert(sparseCleaned.numActives === 2)
  assert(sparseCleaned === sparseInput)
}
297+
298+
test("Vector.compressed") {
  // Each cast doubles as a check on the format chosen by `compressed`.

  // Dense input with three nonzeros out of four stays dense.
  val denseMostlyFull = Vectors.dense(1.0, 2.0, 3.0, 0.0)
  val denseMostlyFullC = denseMostlyFull.compressed.asInstanceOf[DenseVector]
  assert(denseMostlyFullC === denseMostlyFull)

  // Dense input with a single nonzero becomes sparse with one active entry.
  val denseMostlyEmpty = Vectors.dense(0.0, 2.0, 0.0, 0.0)
  val denseMostlyEmptyC = denseMostlyEmpty.compressed.asInstanceOf[SparseVector]
  assert(denseMostlyEmpty === denseMostlyEmptyC)
  assert(denseMostlyEmptyC.numActives === 1)

  // Sparse input with an explicit zero stays sparse and loses the stored zero.
  val sparseWithZero = Vectors.sparse(4, Array(1, 2), Array(2.0, 0.0))
  val sparseWithZeroC = sparseWithZero.compressed.asInstanceOf[SparseVector]
  assert(sparseWithZero === sparseWithZeroC)
  assert(sparseWithZeroC.numActives === 1)

  // Sparse input with three nonzeros out of four becomes dense.
  val sparseMostlyFull = Vectors.sparse(4, Array(0, 1, 2), Array(1.0, 2.0, 3.0))
  val sparseMostlyFullC = sparseMostlyFull.compressed.asInstanceOf[DenseVector]
  assert(sparseMostlyFull === sparseMostlyFullC)
}
273317
}

project/MimaExcludes.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,18 @@ object MimaExcludes {
7676
// SPARK-7090 Introduce LDAOptimizer to LDA to further improve extensibility
7777
ProblemFilters.exclude[MissingClassProblem](
7878
"org.apache.spark.mllib.clustering.LDA$EMOptimizer")
79+
) ++ Seq(
80+
// SPARK-6756 add toSparse, toDense, numActives, numNonzeros, and compressed to Vector
81+
ProblemFilters.exclude[MissingMethodProblem](
82+
"org.apache.spark.mllib.linalg.Vector.compressed"),
83+
ProblemFilters.exclude[MissingMethodProblem](
84+
"org.apache.spark.mllib.linalg.Vector.toDense"),
85+
ProblemFilters.exclude[MissingMethodProblem](
86+
"org.apache.spark.mllib.linalg.Vector.numNonzeros"),
87+
ProblemFilters.exclude[MissingMethodProblem](
88+
"org.apache.spark.mllib.linalg.Vector.toSparse"),
89+
ProblemFilters.exclude[MissingMethodProblem](
90+
"org.apache.spark.mllib.linalg.Vector.numActives")
7991
)
8092

8193
case v if v.startsWith("1.3") =>

0 commit comments

Comments
 (0)