diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
index ac3d79d07755..be467c654aaa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -70,19 +70,51 @@ class MinHashLSHModel private[ml](
 
   @Since("2.1.0")
   override protected[ml] def keyDistance(x: Vector, y: Vector): Double = {
-    val xSet = x.nonZeroIterator.map(_._1).toSet
-    val ySet = y.nonZeroIterator.map(_._1).toSet
-    val intersectionSize = xSet.intersect(ySet).size.toDouble
-    val unionSize = xSet.size + ySet.size - intersectionSize
-    assert(unionSize > 0, "The union of two input sets must have at least 1 elements")
-    1 - intersectionSize / unionSize
+    // Jaccard distance between the index sets of the non-zero entries of x and y.
+    // The two iterators are merged in a single pass instead of materializing Sets; this
+    // assumes nonZeroIterator yields indices in strictly increasing order.
+    val xIter = x.nonZeroIterator.map(_._1)
+    val yIter = y.nonZeroIterator.map(_._1)
+    if (xIter.isEmpty) {
+      require(yIter.hasNext, "The union of two input sets must have at least 1 elements")
+      1.0
+    } else if (yIter.isEmpty) {
+      1.0
+    } else {
+      // -1 marks an exhausted iterator; vector indices are assumed to be non-negative.
+      var xIndex = xIter.next()
+      var yIndex = yIter.next()
+      var xSize = 1
+      var ySize = 1
+      var intersectionSize = 0
+
+      while (xIndex != -1 && yIndex != -1) {
+        if (xIndex == yIndex) {
+          intersectionSize += 1
+          xIndex = if (xIter.hasNext) { xSize += 1; xIter.next() } else -1
+          yIndex = if (yIter.hasNext) { ySize += 1; yIter.next() } else -1
+        } else if (xIndex > yIndex) {
+          yIndex = if (yIter.hasNext) { ySize += 1; yIter.next() } else -1
+        } else {
+          xIndex = if (xIter.hasNext) { xSize += 1; xIter.next() } else -1
+        }
+      }
+
+      // At most one iterator still has elements; count them without comparing.
+      xSize += xIter.size
+      ySize += yIter.size
+
+      // Both sets are non-empty on this path, so unionSize >= 1 and the division is safe.
+      val unionSize = xSize + ySize - intersectionSize
+      1 - intersectionSize.toDouble / unionSize
+    }
   }
 
   @Since("2.1.0")
   override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
     // Since it's generated by hashing, it will be a pair of dense vectors.
     // TODO: This hashDistance function requires more discussion in SPARK-18454
-    x.zip(y).map(vectorPair =>
+    x.iterator.zip(y.iterator).map(vectorPair =>
       vectorPair._1.toArray.zip(vectorPair._2.toArray).count(pair => pair._1 != pair._2)
     ).min
   }