Skip to content

Commit f41b135

Browse files
committed
iterative equals for sparse vector
1 parent 5741144 commit f41b135

File tree

2 files changed

+42
-4
lines changed

2 files changed

+42
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -451,10 +451,25 @@ class SparseVector(
451451

452452
override def equals(other: Any): Boolean = {
453453
other match {
454-
case v: SparseVector =>
455-
this.size == v.size &&
456-
util.Arrays.equals(this.indices, v.indices) &&
457-
util.Arrays.equals(this.values, v.values)
454+
case v: SparseVector => {
455+
if (this.size != v.size) { return false }
456+
var k1 = 0
457+
var k2 = 0
458+
while (true) {
459+
while (k1 < this.values.size && this.values(k1) == 0) k1 += 1
460+
while (k2 < v.values.size && v.values(k2) == 0) k2 += 1
461+
462+
if (k1 == this.values.size || k2 == v.values.size) {
463+
return (k1 == this.values.size && k2 == v.values.size) //check end alignment
464+
}
465+
if (this.indices(k1) != v.indices(k2) || this.values(k1) != v.values(k2)) {
466+
return false
467+
}
468+
k1 += 1
469+
k2 += 1
470+
}
471+
throw new Exception("unreachable")
472+
}
458473
case _ => super.equals(other)
459474
}
460475
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,29 @@ class VectorsSuite extends FunSuite {
8989
}
9090
}
9191

92+
test("sparse equals with explicit 0") {
93+
val sv1 = Vectors.sparse(10, Array(1, 3, 4, 7), Array(0.9, 0.8, 0.7, 0.6))
94+
val sv2 = Vectors.sparse(10, Array(0, 1, 3, 4, 7), Array(0.0, 0.9, 0.8, 0.7, 0.6))
95+
val sv3 = Vectors.sparse(10, Array(1, 3, 4, 5, 6, 7), Array(0.9, 0.8, 0, 0, 0.7, 0.6))
96+
val sv4 = Vectors.sparse(10, Array(1, 3, 4, 7, 9), Array(0.9, 0.8, 0.7, 0.6, 0))
97+
val sv5 = Vectors.sparse(10, Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9),
98+
Array(0, 0.9, 0, 0.8, 0.7, 0, 0, 0.6, 0, 0))
99+
100+
val vectors = Seq(sv1, sv2, sv3, sv4, sv5)
101+
102+
for (v <- vectors; u <- vectors) {
103+
assert(v === u)
104+
assert(v.## === u.##)
105+
}
106+
107+
val another = Vectors.sparse(10, Array(1, 3, 4, 7), Array(0.9, 0.2, 0.7, 0.6))
108+
109+
for (v <- vectors) {
110+
assert(v != another)
111+
assert(v.## != another.##)
112+
}
113+
}
114+
92115
test("indexing dense vectors") {
93116
val vec = Vectors.dense(1.0, 2.0, 3.0, 4.0)
94117
assert(vec(0) === 1.0)

0 commit comments

Comments
 (0)