Skip to content

Commit c66a976

Browse files
coderxiangmengxr
authored and committed
[SPARK-5116][MLlib] Add extractor for SparseVector and DenseVector
Add extractor for SparseVector and DenseVector in MLlib to save some code while performing pattern matching on Vectors. For example, previously we may use: vec match { case dv: DenseVector => val values = dv.values ... case sv: SparseVector => val indices = sv.indices val values = sv.values val size = sv.size ... } with extractor it is: vec match { case DenseVector(values) => ... case SparseVector(size, indices, values) => ... } Author: Shuo Xiang <[email protected]> Closes apache#3919 from coderxiang/extractor and squashes the following commits: 359e8d5 [Shuo Xiang] merge master ca5fc3e [Shuo Xiang] merge master 0b1e190 [Shuo Xiang] use extractor for vectors in RowMatrix.scala e961805 [Shuo Xiang] use extractor for vectors in StandardScaler.scala c2bbdaf [Shuo Xiang] use extractor for vectors in IDFscala 8433922 [Shuo Xiang] use extractor for vectors in NaiveBayes.scala and Normalizer.scala d83c7ca [Shuo Xiang] use extractor for vectors in Vectors.scala 5523dad [Shuo Xiang] Add extractor for SparseVector and DenseVector
1 parent 2b729d2 commit c66a976

File tree

6 files changed

+57
-51
lines changed

6 files changed

+57
-51
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
9393
def run(data: RDD[LabeledPoint]) = {
9494
val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
9595
val values = v match {
96-
case sv: SparseVector =>
97-
sv.values
98-
case dv: DenseVector =>
99-
dv.values
96+
case SparseVector(size, indices, values) =>
97+
values
98+
case DenseVector(values) =>
99+
values
100100
}
101101
if (!values.forall(_ >= 0.0)) {
102102
throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")

mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -86,20 +86,20 @@ private object IDF {
8686
df = BDV.zeros(doc.size)
8787
}
8888
doc match {
89-
case sv: SparseVector =>
90-
val nnz = sv.indices.size
89+
case SparseVector(size, indices, values) =>
90+
val nnz = indices.size
9191
var k = 0
9292
while (k < nnz) {
93-
if (sv.values(k) > 0) {
94-
df(sv.indices(k)) += 1L
93+
if (values(k) > 0) {
94+
df(indices(k)) += 1L
9595
}
9696
k += 1
9797
}
98-
case dv: DenseVector =>
99-
val n = dv.size
98+
case DenseVector(values) =>
99+
val n = values.size
100100
var j = 0
101101
while (j < n) {
102-
if (dv.values(j) > 0.0) {
102+
if (values(j) > 0.0) {
103103
df(j) += 1L
104104
}
105105
j += 1
@@ -207,20 +207,20 @@ private object IDFModel {
207207
def transform(idf: Vector, v: Vector): Vector = {
208208
val n = v.size
209209
v match {
210-
case sv: SparseVector =>
211-
val nnz = sv.indices.size
210+
case SparseVector(size, indices, values) =>
211+
val nnz = indices.size
212212
val newValues = new Array[Double](nnz)
213213
var k = 0
214214
while (k < nnz) {
215-
newValues(k) = sv.values(k) * idf(sv.indices(k))
215+
newValues(k) = values(k) * idf(indices(k))
216216
k += 1
217217
}
218-
Vectors.sparse(n, sv.indices, newValues)
219-
case dv: DenseVector =>
218+
Vectors.sparse(n, indices, newValues)
219+
case DenseVector(values) =>
220220
val newValues = new Array[Double](n)
221221
var j = 0
222222
while (j < n) {
223-
newValues(j) = dv.values(j) * idf(j)
223+
newValues(j) = values(j) * idf(j)
224224
j += 1
225225
}
226226
Vectors.dense(newValues)

mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,24 +52,24 @@ class Normalizer(p: Double) extends VectorTransformer {
5252
// However, for sparse vector, the `index` array will not be changed,
5353
// so we can re-use it to save memory.
5454
vector match {
55-
case dv: DenseVector =>
56-
val values = dv.values.clone()
55+
case DenseVector(vs) =>
56+
val values = vs.clone()
5757
val size = values.size
5858
var i = 0
5959
while (i < size) {
6060
values(i) /= norm
6161
i += 1
6262
}
6363
Vectors.dense(values)
64-
case sv: SparseVector =>
65-
val values = sv.values.clone()
64+
case SparseVector(size, ids, vs) =>
65+
val values = vs.clone()
6666
val nnz = values.size
6767
var i = 0
6868
while (i < nnz) {
6969
values(i) /= norm
7070
i += 1
7171
}
72-
Vectors.sparse(sv.size, sv.indices, values)
72+
Vectors.sparse(size, ids, values)
7373
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
7474
}
7575
} else {

mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,8 @@ class StandardScalerModel private[mllib] (
105105
// This can be avoid by having a local reference of `shift`.
106106
val localShift = shift
107107
vector match {
108-
case dv: DenseVector =>
109-
val values = dv.values.clone()
108+
case DenseVector(vs) =>
109+
val values = vs.clone()
110110
val size = values.size
111111
if (withStd) {
112112
// Having a local reference of `factor` to avoid overhead as the comment before.
@@ -130,27 +130,26 @@ class StandardScalerModel private[mllib] (
130130
// Having a local reference of `factor` to avoid overhead as the comment before.
131131
val localFactor = factor
132132
vector match {
133-
case dv: DenseVector =>
134-
val values = dv.values.clone()
133+
case DenseVector(vs) =>
134+
val values = vs.clone()
135135
val size = values.size
136136
var i = 0
137137
while(i < size) {
138138
values(i) *= localFactor(i)
139139
i += 1
140140
}
141141
Vectors.dense(values)
142-
case sv: SparseVector =>
142+
case SparseVector(size, indices, vs) =>
143143
// For sparse vector, the `index` array inside sparse vector object will not be changed,
144144
// so we can re-use it to save memory.
145-
val indices = sv.indices
146-
val values = sv.values.clone()
145+
val values = vs.clone()
147146
val nnz = values.size
148147
var i = 0
149148
while (i < nnz) {
150149
values(i) *= localFactor(indices(i))
151150
i += 1
152151
}
153-
Vectors.sparse(sv.size, indices, values)
152+
Vectors.sparse(size, indices, values)
154153
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
155154
}
156155
} else {

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,16 +108,16 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
108108
override def serialize(obj: Any): Row = {
109109
val row = new GenericMutableRow(4)
110110
obj match {
111-
case sv: SparseVector =>
111+
case SparseVector(size, indices, values) =>
112112
row.setByte(0, 0)
113-
row.setInt(1, sv.size)
114-
row.update(2, sv.indices.toSeq)
115-
row.update(3, sv.values.toSeq)
116-
case dv: DenseVector =>
113+
row.setInt(1, size)
114+
row.update(2, indices.toSeq)
115+
row.update(3, values.toSeq)
116+
case DenseVector(values) =>
117117
row.setByte(0, 1)
118118
row.setNullAt(1)
119119
row.setNullAt(2)
120-
row.update(3, dv.values.toSeq)
120+
row.update(3, values.toSeq)
121121
}
122122
row
123123
}
@@ -271,8 +271,8 @@ object Vectors {
271271
def norm(vector: Vector, p: Double): Double = {
272272
require(p >= 1.0)
273273
val values = vector match {
274-
case dv: DenseVector => dv.values
275-
case sv: SparseVector => sv.values
274+
case DenseVector(vs) => vs
275+
case SparseVector(n, ids, vs) => vs
276276
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
277277
}
278278
val size = values.size
@@ -427,6 +427,10 @@ class DenseVector(val values: Array[Double]) extends Vector {
427427
}
428428
}
429429

430+
object DenseVector {
431+
def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values)
432+
}
433+
430434
/**
431435
* A sparse vector represented by an index array and a value array.
432436
*
@@ -474,3 +478,8 @@ class SparseVector(
474478
}
475479
}
476480
}
481+
482+
object SparseVector {
483+
def unapply(sv: SparseVector): Option[(Int, Array[Int], Array[Double])] =
484+
Some((sv.size, sv.indices, sv.values))
485+
}

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -528,21 +528,21 @@ class RowMatrix(
528528
iter.flatMap { row =>
529529
val buf = new ListBuffer[((Int, Int), Double)]()
530530
row match {
531-
case sv: SparseVector =>
532-
val nnz = sv.indices.size
531+
case SparseVector(size, indices, values) =>
532+
val nnz = indices.size
533533
var k = 0
534534
while (k < nnz) {
535-
scaled(k) = sv.values(k) / q(sv.indices(k))
535+
scaled(k) = values(k) / q(indices(k))
536536
k += 1
537537
}
538538
k = 0
539539
while (k < nnz) {
540-
val i = sv.indices(k)
540+
val i = indices(k)
541541
val iVal = scaled(k)
542542
if (iVal != 0 && rand.nextDouble() < p(i)) {
543543
var l = k + 1
544544
while (l < nnz) {
545-
val j = sv.indices(l)
545+
val j = indices(l)
546546
val jVal = scaled(l)
547547
if (jVal != 0 && rand.nextDouble() < p(j)) {
548548
buf += (((i, j), iVal * jVal))
@@ -552,11 +552,11 @@ class RowMatrix(
552552
}
553553
k += 1
554554
}
555-
case dv: DenseVector =>
556-
val n = dv.values.size
555+
case DenseVector(values) =>
556+
val n = values.size
557557
var i = 0
558558
while (i < n) {
559-
scaled(i) = dv.values(i) / q(i)
559+
scaled(i) = values(i) / q(i)
560560
i += 1
561561
}
562562
i = 0
@@ -620,11 +620,9 @@ object RowMatrix {
620620
// TODO: Find a better home (breeze?) for this method.
621621
val n = v.size
622622
v match {
623-
case dv: DenseVector =>
624-
blas.dspr("U", n, alpha, dv.values, 1, U)
625-
case sv: SparseVector =>
626-
val indices = sv.indices
627-
val values = sv.values
623+
case DenseVector(values) =>
624+
blas.dspr("U", n, alpha, values, 1, U)
625+
case SparseVector(size, indices, values) =>
628626
val nnz = indices.length
629627
var colStartIdx = 0
630628
var prevCol = 0

0 commit comments

Comments
 (0)