@zhengruifeng
Contributor

@zhengruifeng zhengruifeng commented Apr 13, 2020

What changes were proposed in this pull request?

Re-implement keyDistance:
if both vectors are dense, the new impl is 9.09x faster;
if both vectors are sparse, the new impl is 5.57x faster;
if one is dense and the other is sparse, the new impl is 7.8x faster.

Why are the changes needed?

The current implementation, based on set operations, is inefficient.

Does this PR introduce any user-facing change?

No

How was this patch tested?

Existing test suites.
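
A quick way to convince oneself that a merge-based rewrite preserves behavior is to compare it against the set-based formula on random inputs. The sketch below does this on plain sorted index arrays rather than Spark vectors. `KeyDistanceAgreement`, `setBased`, `mergeBased`, and the dimensions used are assumptions for illustration, not code from the patch.

```scala
import scala.util.Random

object KeyDistanceAgreement {
  // Set-based Jaccard distance on plain index arrays (illustrative reference).
  def setBased(x: Array[Int], y: Array[Int]): Double = {
    val xSet = x.toSet
    val ySet = y.toSet
    val intersectionSize = xSet.intersect(ySet).size.toDouble
    val unionSize = xSet.size + ySet.size - intersectionSize
    require(unionSize > 0, "The union of two input sets must have at least 1 elements")
    1 - intersectionSize / unionSize
  }

  // Two-pointer merge; assumes both arrays are sorted and duplicate-free.
  def mergeBased(x: Array[Int], y: Array[Int]): Double = {
    require(x.nonEmpty || y.nonEmpty,
      "The union of two input sets must have at least 1 elements")
    var i = 0
    var j = 0
    var intersectionSize = 0
    while (i < x.length && j < y.length) {
      if (x(i) == y(j)) { intersectionSize += 1; i += 1; j += 1 }
      else if (x(i) < y(j)) i += 1
      else j += 1
    }
    val unionSize = x.length + y.length - intersectionSize
    1 - intersectionSize.toDouble / unionSize
  }

  // Draws nnz distinct indices from [0, dim) and returns them sorted.
  def randomIndices(rng: Random, dim: Int, nnz: Int): Array[Int] =
    rng.shuffle(Seq.range(0, dim)).take(nnz).toArray.sorted

  // True if both implementations agree on every randomly generated pair.
  def agreesOnRandomInputs(trials: Int, seed: Long): Boolean = {
    val rng = new Random(seed)
    (0 until trials).forall { _ =>
      val x = randomIndices(rng, 200, 1 + rng.nextInt(50)) // never empty
      val y = randomIndices(rng, 200, rng.nextInt(50))     // may be empty
      math.abs(setBased(x, y) - mergeBased(x, y)) < 1e-12
    }
  }
}
```

Keeping `x` non-empty guarantees a non-empty union, so the `require` checks never fire during the comparison.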

@zhengruifeng
Contributor Author

zhengruifeng commented Apr 13, 2020

Test code:


import scala.util.Random
import org.apache.spark.ml.linalg._

val rng = new Random(42)

val vec1 = Vectors.dense(Array.fill(10000)(1.0).map(_.toDouble))
val vec2 = Vectors.sparse(10000, rng.shuffle(Seq.range(0, 10000)).take(100).toArray.sorted, Array.fill(100)(1.0))
val vec3 = Vectors.sparse(10000, rng.shuffle(Seq.range(0, 10000)).take(100).toArray.sorted, Array.fill(100)(1.0))



def getNonZeroIterator(vec: Vector): Iterator[(Int, Double)] = {
    vec match {
        case DenseVector(values) => Iterator.tabulate(values.length)(i => (i, values(i))).filter(_._2 != 0)
        case SparseVector(_, indices, values) => Iterator.tabulate(indices.length)(i => (indices(i), values(i))).filter(_._2 != 0)
    }
}


def keyDistance1(x: Vector, y: Vector): Double = {
    val xSet = getNonZeroIterator(x).map(_._1).toSet
    val ySet = getNonZeroIterator(y).map(_._1).toSet
    val intersectionSize = xSet.intersect(ySet).size.toDouble
    val unionSize = xSet.size + ySet.size - intersectionSize
    assert(unionSize > 0, "The union of two input sets must have at least 1 elements")
    1 - intersectionSize / unionSize
}




def keyDistance2(x: Vector, y: Vector): Double = {
    val xIter = getNonZeroIterator(x).map(_._1)
    val yIter = getNonZeroIterator(y).map(_._1)
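    // If one side has no nonzero entries, the intersection is empty, so the Jaccard
    // distance is 1.0 (keyDistance1 also returns 1.0 for such inputs).
    if (xIter.isEmpty) {
      assert(yIter.hasNext, "The union of two input sets must have at least 1 elements")
      return 1.0
    } else if (yIter.isEmpty) {
      return 1.0
    }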

    var xIndex = xIter.next
    var yIndex = yIter.next
    var xSize = 1
    var ySize = 1
    var intersectionSize = 0

    while (xIndex != -1 || yIndex != -1) {
      if (xIndex != -1 && yIndex != -1) {
        if (xIndex == yIndex) {
          intersectionSize += 1
          xIndex = if (xIter.hasNext) { xSize += 1; xIter.next } else -1
          yIndex = if (yIter.hasNext) { ySize += 1; yIter.next } else -1
        } else if (xIndex > yIndex) {
          yIndex = if (yIter.hasNext) { ySize += 1; yIter.next } else -1
        } else {
          xIndex = if (xIter.hasNext) { xSize += 1; xIter.next } else -1
        }
      } else if (xIndex != -1) {
        while (xIter.hasNext) { xIndex = xIter.next; xSize += 1 }
        xIndex = -1
      } else {
        while (yIter.hasNext) { yIndex = yIter.next; ySize += 1 }
        yIndex = -1
      }
    }

    val unionSize = xSize + ySize - intersectionSize
    assert(unionSize > 0, "The union of two input sets must have at least 1 elements")
    1 - intersectionSize.toDouble / unionSize
  }

Results:

scala> val start = System.currentTimeMillis; Seq.range(0, 10000).foreach { i => keyDistance1(vec1, vec1) }; val end = System.currentTimeMillis; val duration = end - start; 
start: Long = 1586778279745
end: Long = 1586778324648
duration: Long = 44903

scala> val start = System.currentTimeMillis; Seq.range(0, 10000).foreach { i => keyDistance2(vec1, vec1) }; val end = System.currentTimeMillis; val duration = end - start; 
start: Long = 1586778402039
end: Long = 1586778406977
duration: Long = 4938

scala> val start = System.currentTimeMillis; Seq.range(0, 10000).foreach { i => keyDistance1(vec1, vec2) }; val end = System.currentTimeMillis; val duration = end - start; 
start: Long = 1586778414223
end: Long = 1586778432697
duration: Long = 18474

scala> val start = System.currentTimeMillis; Seq.range(0, 10000).foreach { i => keyDistance2(vec1, vec2) }; val end = System.currentTimeMillis; val duration = end - start; 
start: Long = 1586778439978
end: Long = 1586778442346
duration: Long = 2368

scala> val start = System.currentTimeMillis; Seq.range(0, 10000).foreach { i => keyDistance1(vec2, vec3) }; val end = System.currentTimeMillis; val duration = end - start; 
start: Long = 1586778451556
end: Long = 1586778451851
duration: Long = 295

scala> val start = System.currentTimeMillis; Seq.range(0, 10000).foreach { i => keyDistance2(vec2, vec3) }; val end = System.currentTimeMillis; val duration = end - start; 
start: Long = 1586778458768
end: Long = 1586778458821
duration: Long = 53

if both vectors are dense, the new impl is 9.09x faster (44903 ms vs 4938 ms);
if both vectors are sparse, the new impl is 5.57x faster (295 ms vs 53 ms);
if one is dense and the other is sparse, the new impl is 7.8x faster (18474 ms vs 2368 ms).
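
The ratios follow directly from the durations printed by the REPL. The sketch below recomputes them; `SpeedupFromDurations` and `speedup` are hypothetical names, and the numbers are copied from the output above (10000 calls per measurement).

```scala
object SpeedupFromDurations {
  // (old duration, new duration) in milliseconds, taken from the REPL output.
  val denseDense: (Long, Long)   = (44903L, 4938L) // keyDistance1 vs keyDistance2 on (vec1, vec1)
  val denseSparse: (Long, Long)  = (18474L, 2368L) // on (vec1, vec2)
  val sparseSparse: (Long, Long) = (295L, 53L)     // on (vec2, vec3)

  // Speedup factor of the new implementation over the old one.
  def speedup(oldMs: Long, newMs: Long): Double = oldMs.toDouble / newMs

  def main(args: Array[String]): Unit = {
    Seq(
      "dense/dense"   -> denseDense,
      "dense/sparse"  -> denseSparse,
      "sparse/sparse" -> sparseSparse
    ).foreach { case (name, (oldMs, newMs)) =>
      println(f"$name%-14s ${speedup(oldMs, newMs)}%.2fx")
    }
  }
}
```

These are single wall-clock runs in a REPL, so the figures are best read as rough indications rather than precise benchmarks.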

@SparkQA

SparkQA commented Apr 13, 2020

Test build #121206 has finished for PR 28206 at commit a35d739.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Apr 13, 2020

Test build #121207 has finished for PR 28206 at commit a77225e.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Member

@srowen srowen left a comment

Looks OK if tests pass

xIndex = if (xIter.hasNext) { xSize += 1; xIter.next } else -1
}
} else if (xIndex != -1) {
while (xIter.hasNext) { xIndex = xIter.next; xSize += 1 }
Member

Do you need to update xIndex in this case, and yIndex below? It looks like you're just counting the remaining size of the iterator. It's possible .size() is as fast from here.

use iter.size for remaining elements
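
The suggestion can be sketched as follows: once one iterator is exhausted, the leftover elements of the other only need to be counted, which `Iterator.size` does directly. This is a minimal sketch over plain `Iterator[Int]` inputs, assumed to be strictly increasing and non-negative. `KeyDistanceSketch` and `jaccardDistance` are hypothetical names, and returning 1.0 when exactly one side is empty is an assumption chosen to match the set-based formula, not necessarily what was merged.

```scala
object KeyDistanceSketch {
  /** Jaccard distance between two index sets given as strictly increasing iterators. */
  def jaccardDistance(xIter: Iterator[Int], yIter: Iterator[Int]): Double = {
    require(xIter.hasNext || yIter.hasNext,
      "The union of two input sets must have at least 1 elements")
    // Exactly one side empty: no intersection, so the distance is 1.0 (assumption).
    if (!xIter.hasNext || !yIter.hasNext) return 1.0

    var xIndex = xIter.next()
    var yIndex = yIter.next()
    var xSize = 1
    var ySize = 1
    var intersectionSize = 0

    // Merge step: runs while both sides still hold a current index (-1 marks exhaustion).
    while (xIndex != -1 && yIndex != -1) {
      if (xIndex == yIndex) {
        intersectionSize += 1
        xIndex = if (xIter.hasNext) { xSize += 1; xIter.next() } else -1
        yIndex = if (yIter.hasNext) { ySize += 1; yIter.next() } else -1
      } else if (xIndex > yIndex) {
        yIndex = if (yIter.hasNext) { ySize += 1; yIter.next() } else -1
      } else {
        xIndex = if (xIter.hasNext) { xSize += 1; xIter.next() } else -1
      }
    }

    // Tail: count whatever is left; size is 0 for an exhausted iterator.
    xSize += xIter.size
    ySize += yIter.size

    val unionSize = xSize + ySize - intersectionSize
    1 - intersectionSize.toDouble / unionSize
  }
}
```

Because the loop ends as soon as either side runs out, the two tail branches of the original loop collapse into the two `size` calls.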
hashValues.map(Vectors.dense(_))
}

@Since("2.1.0")
Contributor Author

keyDistance was not exposed to end users; is this @Since annotation needed?

Member

I suppose it's protected so a little more visible to callers. I wouldn't remove it just for its own sake.

1 - intersectionSize.toDouble / unionSize
}

@Since("2.1.0")
Contributor Author

ditto

@SparkQA

SparkQA commented Apr 14, 2020

Test build #121248 has finished for PR 28206 at commit 7725d09.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@srowen srowen closed this in f1489e6 Apr 17, 2020
@srowen
Member

srowen commented Apr 17, 2020

Merged to master

@zhengruifeng zhengruifeng deleted the minhash_opt branch April 20, 2020 01:48
@zhengruifeng
Contributor Author

Thanks @srowen for reviewing!
