Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,46 @@ class MinHashLSHModel private[ml](

@Since("2.1.0")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

keyDistance was not exposed to the end users, is this since annotation needed?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose it's protected so a little more visible to callers. I wouldn't remove it just for its own sake.

override protected[ml] def keyDistance(x: Vector, y: Vector): Double = {
val xSet = x.nonZeroIterator.map(_._1).toSet
val ySet = y.nonZeroIterator.map(_._1).toSet
val intersectionSize = xSet.intersect(ySet).size.toDouble
val unionSize = xSet.size + ySet.size - intersectionSize
assert(unionSize > 0, "The union of two input sets must have at least 1 elements")
1 - intersectionSize / unionSize
val xIter = x.nonZeroIterator.map(_._1)
val yIter = y.nonZeroIterator.map(_._1)
if (xIter.isEmpty) {
require(yIter.hasNext, "The union of two input sets must have at least 1 elements")
return 1.0
} else if (yIter.isEmpty) {
return 1.0
}

var xIndex = xIter.next
var yIndex = yIter.next
var xSize = 1
var ySize = 1
var intersectionSize = 0

while (xIndex != -1 && yIndex != -1) {
if (xIndex == yIndex) {
intersectionSize += 1
xIndex = if (xIter.hasNext) { xSize += 1; xIter.next } else -1
yIndex = if (yIter.hasNext) { ySize += 1; yIter.next } else -1
} else if (xIndex > yIndex) {
yIndex = if (yIter.hasNext) { ySize += 1; yIter.next } else -1
} else {
xIndex = if (xIter.hasNext) { xSize += 1; xIter.next } else -1
}
}

xSize += xIter.size
ySize += yIter.size

val unionSize = xSize + ySize - intersectionSize
require(unionSize > 0, "The union of two input sets must have at least 1 elements")
1 - intersectionSize.toDouble / unionSize
}

@Since("2.1.0")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
// Since it's generated by hashing, it will be a pair of dense vectors.
// TODO: This hashDistance function requires more discussion in SPARK-18454
x.zip(y).map(vectorPair =>
x.iterator.zip(y.iterator).map(vectorPair =>
vectorPair._1.toArray.zip(vectorPair._2.toArray).count(pair => pair._1 != pair._2)
).min
}
Expand Down