Skip to content

Commit 8fb7d74

Browse files
committed
update impl
1 parent 1ebad60 commit 8fb7d74

File tree

1 file changed

+64
-24
lines changed

1 file changed

+64
-24
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -67,21 +67,23 @@ sealed trait Vector extends Serializable {
6767
}
6868
}
6969

70+
/**
71+
* Returns a hash code value for the vector. The hash code is based on its size and its nonzeros
72+
* in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]].
73+
*/
7074
override def hashCode(): Int = {
71-
var result: Int = size + 31
72-
var i = 0
73-
this.foreachActive { case (index, value) =>
74-
// ignore explicit 0 for comparison between sparse and dense
75-
if (value != 0) {
76-
result = 31 * result + index
77-
// refer to {@link java.util.Arrays.equals} for hash algorithm
78-
val bits = java.lang.Double.doubleToLongBits(value)
79-
result = 31 * result + (bits ^ (bits >>> 32)).toInt
80-
i += 1
81-
// only scan the first 16 nonzeros
82-
if (i > 16) {
83-
return result
75+
// This is a reference implementation. It uses a non-local `return` inside the foreachActive
// closure, which is slow; DenseVector and SparseVector override it with plain while loops.
76+
var result: Int = 31 + size
77+
this.foreachActive { (index, value) =>
78+
if (index < 16) {
79+
// ignore explicit 0 for comparison between sparse and dense
80+
if (value != 0) {
81+
result = 31 * result + index
82+
val bits = java.lang.Double.doubleToLongBits(value)
83+
result = 31 * result + (bits ^ (bits >>> 32)).toInt
8484
}
85+
} else {
86+
return result
8587
}
8688
}
8789
result
@@ -322,7 +324,7 @@ object Vectors {
322324
case SparseVector(n, ids, vs) => vs
323325
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
324326
}
325-
val size = values.size
327+
val size = values.length
326328

327329
if (p == 1) {
328330
var sum = 0.0
@@ -376,8 +378,8 @@ object Vectors {
376378
val v1Indices = v1.indices
377379
val v2Values = v2.values
378380
val v2Indices = v2.indices
379-
val nnzv1 = v1Indices.size
380-
val nnzv2 = v2Indices.size
381+
val nnzv1 = v1Indices.length
382+
val nnzv2 = v2Indices.length
381383

382384
var kv1 = 0
383385
var kv2 = 0
@@ -406,7 +408,7 @@ object Vectors {
406408

407409
case (DenseVector(vv1), DenseVector(vv2)) =>
408410
var kv = 0
409-
val sz = vv1.size
411+
val sz = vv1.length
410412
while (kv < sz) {
411413
val score = vv1(kv) - vv2(kv)
412414
squaredDistance += score * score
@@ -427,7 +429,7 @@ object Vectors {
427429
var kv2 = 0
428430
val indices = v1.indices
429431
var squaredDistance = 0.0
430-
val nnzv1 = indices.size
432+
val nnzv1 = indices.length
431433
val nnzv2 = v2.size
432434
var iv1 = if (nnzv1 > 0) indices(kv1) else -1
433435

@@ -456,8 +458,8 @@ object Vectors {
456458
v1Values: Array[Double],
457459
v2Indices: IndexedSeq[Int],
458460
v2Values: Array[Double]): Boolean = {
459-
val v1Size = v1Values.size
460-
val v2Size = v2Values.size
461+
val v1Size = v1Values.length
462+
val v2Size = v2Values.length
461463
var k1 = 0
462464
var k2 = 0
463465
var allEqual = true
@@ -498,14 +500,30 @@ class DenseVector(val values: Array[Double]) extends Vector {
498500

499501
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
500502
var i = 0
501-
val localValuesSize = values.size
503+
val localValuesSize = values.length
502504
val localValues = values
503505

504506
while (i < localValuesSize) {
505507
f(i, localValues(i))
506508
i += 1
507509
}
508510
}
511+
512+
override def hashCode(): Int = {
513+
var result: Int = 31 + size
514+
var i = 0
515+
val end = math.min(values.length, 16)
516+
while (i < end) {
517+
val v = values(i)
518+
if (v != 0.0) {
519+
result = 31 * result + i
520+
val bits = java.lang.Double.doubleToLongBits(values(i))
521+
result = 31 * result + (bits ^ (bits >>> 32)).toInt
522+
}
523+
i += 1
524+
}
525+
result
526+
}
509527
}
510528

511529
object DenseVector {
@@ -527,8 +545,8 @@ class SparseVector(
527545
val values: Array[Double]) extends Vector {
528546

529547
require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
530-
s" indices match the dimension of the values. You provided ${indices.size} indices and " +
531-
s" ${values.size} values.")
548+
s" indices match the dimension of the values. You provided ${indices.length} indices and " +
549+
s" ${values.length} values.")
532550

533551
override def toString: String =
534552
"(%s,%s,%s)".format(size, indices.mkString("[", ",", "]"), values.mkString("[", ",", "]"))
@@ -552,7 +570,7 @@ class SparseVector(
552570

553571
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
554572
var i = 0
555-
val localValuesSize = values.size
573+
val localValuesSize = values.length
556574
val localIndices = indices
557575
val localValues = values
558576

@@ -561,6 +579,28 @@ class SparseVector(
561579
i += 1
562580
}
563581
}
582+
583+
override def hashCode(): Int = {
584+
var result: Int = 31 + size
585+
val end = values.length
586+
var continue = true
587+
var k = 0
588+
while ((k < end) & continue) {
589+
val i = indices(k)
590+
if (i < 16) {
591+
val v = values(k)
592+
if (v != 0.0) {
593+
result = 31 * result + i
594+
val bits = java.lang.Double.doubleToLongBits(v)
595+
result = 31 * result + (bits ^ (bits >>> 32)).toInt
596+
}
597+
} else {
598+
continue = false
599+
}
600+
k += 1
601+
}
602+
result
603+
}
564604
}
565605

566606
object SparseVector {

0 commit comments

Comments
 (0)