Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 52 additions & 3 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,35 @@ sealed trait Vector extends Serializable {

override def equals(other: Any): Boolean = {
other match {
case v: Vector =>
util.Arrays.equals(this.toArray, v.toArray)
case v2: Vector => {
if (this.size != v2.size) return false
(this, v2) match {
case (s1: SparseVector, s2: SparseVector) =>
Vectors.equals(s1.indices, s1.values, s2.indices, s2.values)
case (s1: SparseVector, d1: DenseVector) =>
Vectors.equals(s1.indices, s1.values, 0 until d1.size, d1.values)
case (d1: DenseVector, s1: SparseVector) =>
Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values)
case (_, _) => util.Arrays.equals(this.toArray, v2.toArray)
}
}
case _ => false
}
}

override def hashCode(): Int = util.Arrays.hashCode(this.toArray)
override def hashCode(): Int = {
var result: Int = size + 31
this.foreachActive { case (index, value) =>
// ignore explict 0 for comparison between sparse and dense
if (value != 0) {
result = 31 * result + index
// refer to {@link java.util.Arrays.equals} for hash algorithm
val bits = java.lang.Double.doubleToLongBits(value)
result = 31 * result + (bits ^ (bits >>> 32)).toInt
}
}
return result
}

/**
* Converts the instance to a breeze vector.
Expand Down Expand Up @@ -312,6 +334,33 @@ object Vectors {
math.pow(sum, 1.0 / p)
}
}

/**
* Check equality between sparse/dense vectors
*/
private[mllib] def equals(
v1Indices: IndexedSeq[Int],
v1Values: Array[Double],
v2Indices: IndexedSeq[Int],
v2Values: Array[Double]): Boolean = {
val v1Size = v1Values.size
val v2Size = v2Values.size
var k1 = 0
var k2 = 0
var allEqual = true
while (allEqual) {
while (k1 < v1Size && v1Values(k1) == 0) k1 += 1
while (k2 < v2Size && v2Values(k2) == 0) k2 += 1

if (k1 >= v1Size || k2 >= v2Size) {
return k1 >= v1Size && k2 >= v2Size // check end alignment
}
allEqual = v1Indices(k1) == v2Indices(k2) && v1Values(k1) == v2Values(k2)
k1 += 1
k2 += 1
}
allEqual
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,24 @@ class VectorsSuite extends FunSuite {
}
}

test("vectors equals with explicit 0") {
val dv1 = Vectors.dense(Array(0, 0.9, 0, 0.8, 0))
val sv1 = Vectors.sparse(5, Array(1, 3), Array(0.9, 0.8))
val sv2 = Vectors.sparse(5, Array(0, 1, 2, 3, 4), Array(0, 0.9, 0, 0.8, 0))

val vectors = Seq(dv1, sv1, sv2)
for (v <- vectors; u <- vectors) {
assert(v === u)
assert(v.## === u.##)
}

val another = Vectors.sparse(5, Array(0, 1, 3), Array(0, 0.9, 0.2))
for (v <- vectors) {
assert(v != another)
assert(v.## != another.##)
}
}

test("indexing dense vectors") {
val vec = Vectors.dense(1.0, 2.0, 3.0, 4.0)
assert(vec(0) === 1.0)
Expand Down