Skip to content

Commit f1489e6

Browse files
zhengruifengsrowen
authored andcommitted
[SPARK-31436][ML] MinHash keyDistance optimization
### What changes were proposed in this pull request? re-impl `keyDistance`: if both vectors are dense, new impl is 9.09x faster; if both vectors are sparse, new impl is 5.66x faster; if one is dense and the other is sparse, new impl is 7.8x faster; ### Why are the changes needed? current implementation based on set operations is inefficient ### Does this PR introduce any user-facing change? No ### How was this patch tested? existing testsuites Closes apache#28206 from zhengruifeng/minhash_opt. Authored-by: zhengruifeng <[email protected]> Signed-off-by: Sean Owen <[email protected]>
1 parent 1513673 commit f1489e6

File tree

1 file changed

+34
-7
lines changed

1 file changed

+34
-7
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,46 @@ class MinHashLSHModel private[ml](
7070

7171
@Since("2.1.0")
7272
override protected[ml] def keyDistance(x: Vector, y: Vector): Double = {
73-
val xSet = x.nonZeroIterator.map(_._1).toSet
74-
val ySet = y.nonZeroIterator.map(_._1).toSet
75-
val intersectionSize = xSet.intersect(ySet).size.toDouble
76-
val unionSize = xSet.size + ySet.size - intersectionSize
77-
assert(unionSize > 0, "The union of two input sets must have at least 1 elements")
78-
1 - intersectionSize / unionSize
73+
val xIter = x.nonZeroIterator.map(_._1)
74+
val yIter = y.nonZeroIterator.map(_._1)
75+
if (xIter.isEmpty) {
76+
require(yIter.hasNext, "The union of two input sets must have at least 1 elements")
77+
return 1.0
78+
} else if (yIter.isEmpty) {
79+
return 1.0
80+
}
81+
82+
var xIndex = xIter.next
83+
var yIndex = yIter.next
84+
var xSize = 1
85+
var ySize = 1
86+
var intersectionSize = 0
87+
88+
while (xIndex != -1 && yIndex != -1) {
89+
if (xIndex == yIndex) {
90+
intersectionSize += 1
91+
xIndex = if (xIter.hasNext) { xSize += 1; xIter.next } else -1
92+
yIndex = if (yIter.hasNext) { ySize += 1; yIter.next } else -1
93+
} else if (xIndex > yIndex) {
94+
yIndex = if (yIter.hasNext) { ySize += 1; yIter.next } else -1
95+
} else {
96+
xIndex = if (xIter.hasNext) { xSize += 1; xIter.next } else -1
97+
}
98+
}
99+
100+
xSize += xIter.size
101+
ySize += yIter.size
102+
103+
val unionSize = xSize + ySize - intersectionSize
104+
require(unionSize > 0, "The union of two input sets must have at least 1 elements")
105+
1 - intersectionSize.toDouble / unionSize
79106
}
80107

81108
@Since("2.1.0")
82109
override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
83110
// Since it's generated by hashing, it will be a pair of dense vectors.
84111
// TODO: This hashDistance function requires more discussion in SPARK-18454
85-
x.zip(y).map(vectorPair =>
112+
x.iterator.zip(y.iterator).map(vectorPair =>
86113
vectorPair._1.toArray.zip(vectorPair._2.toArray).count(pair => pair._1 != pair._2)
87114
).min
88115
}

0 commit comments

Comments
 (0)