Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ class BucketedRandomProjectionLSHModel private[ml](
val hashValues: Array[Double] = randUnitVectors.map({
randUnitVector => Math.floor(BLAS.dot(key, randUnitVector) / $(bucketLength))
})
// TODO: Output vectors of dimension numHashFunctions in SPARK-18450
hashValues.map(Vectors.dense(_))
hashValues.grouped($(numHashFunctions)).map(Vectors.dense).toArray
}
}

Expand Down Expand Up @@ -137,6 +136,9 @@ class BucketedRandomProjectionLSH(override val uid: String)
/** Sets the `numHashTables` param by delegating to the inherited setter. */
@Since("2.1.0")
override def setNumHashTables(value: Int): this.type = super.setNumHashTables(value)

@Since("2.2.0")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Since("2.4.0")

override def setNumHashFunctions(value: Int): this.type = super.setNumHashFunctions(value)

@Since("2.1.0")
def this() = {
this(Identifiable.randomUID("brp-lsh"))
Expand All @@ -155,7 +157,7 @@ class BucketedRandomProjectionLSH(override val uid: String)
inputDim: Int): BucketedRandomProjectionLSHModel = {
val rand = new Random($(seed))
val randUnitVectors: Array[Vector] = {
Array.fill($(numHashTables)) {
Array.fill($(numHashTables) * $(numHashFunctions)) {
val randArray = Array.fill(inputDim)(rand.nextGaussian())
Vectors.fromBreeze(normalize(breeze.linalg.Vector(randArray)))
}
Expand Down
19 changes: 18 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,24 @@ private[ml] trait LSHParams extends HasInputCol with HasOutputCol {
"tables, where increasing number of hash tables lowers the false negative rate, and " +
"decreasing it improves the running performance", ParamValidators.gt(0))

/**
 * Param for the number of hash functions used in LSH AND-amplification.
 *
 * LSH AND-amplification can be used to reduce the false positive rate. Higher values for this
 * param lead to a reduced false positive rate and lower computational complexity.
 *
 * The value must be greater than 0 (enforced by `ParamValidators.gt(0)`); the default is 1.
 * `BucketedRandomProjectionLSH` and `MinHashLSH` draw `numHashTables * numHashFunctions` hash
 * functions in total, and their models group the resulting hash values into output vectors of
 * `numHashFunctions` entries each.
 * @group param
 */
final val numHashFunctions: IntParam = new IntParam(this, "numHashFunctions", "number of hash " +
"functions, where increasing number of hash functions lowers the false positive rate, and " +
"decreasing it improves the false negative rate", ParamValidators.gt(0))

/**
 * Reads the value of the `numHashTables` param.
 * @group getParam
 */
final def getNumHashTables: Int = $(numHashTables)

setDefault(numHashTables -> 1)
/**
 * Reads the value of the `numHashFunctions` param, i.e. how many hash values are grouped into
 * each output vector by the LSH models.
 * @group getParam
 */
final def getNumHashFunctions: Int = $(numHashFunctions)

setDefault(numHashTables -> 1, numHashFunctions -> 1)

/**
* Transform the Schema for LSH
Expand Down Expand Up @@ -308,6 +322,9 @@ private[ml] abstract class LSH[T <: LSHModel[T]]
/**
 * Sets the `numHashTables` param to `value`; the `this.type` result allows call chaining.
 * @group setParam
 */
def setNumHashTables(value: Int): this.type = set(numHashTables, value)

/**
 * Sets the `numHashFunctions` param to `value`; the `this.type` result allows call chaining.
 * The param's validator requires `value` to be greater than 0.
 * @group setParam
 */
def setNumHashFunctions(value: Int): this.type = set(numHashFunctions, value)

/**
* Validate and create a new instance of concrete LSHModel. Because different LSHModel may have
* different initial setting, developer needs to define how their LSHModel is created instead of
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ class MinHashLSHModel private[ml](
((1 + elem) * a + b) % MinHashLSH.HASH_PRIME
}.min.toDouble
}
// TODO: Output vectors of dimension numHashFunctions in SPARK-18450
hashValues.map(Vectors.dense(_))
hashValues.grouped($(numHashFunctions)).map(Vectors.dense).toArray
}
}

Expand Down Expand Up @@ -119,6 +118,9 @@ class MinHashLSH(override val uid: String) extends LSH[MinHashLSHModel] with Has
/** Sets the `numHashTables` param by delegating to the inherited setter. */
@Since("2.1.0")
override def setNumHashTables(value: Int): this.type = super.setNumHashTables(value)

@Since("2.2.0")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

override def setNumHashFunctions(value: Int): this.type = super.setNumHashFunctions(value)

@Since("2.1.0")
def this() = {
this(Identifiable.randomUID("mh-lsh"))
Expand All @@ -133,7 +135,7 @@ class MinHashLSH(override val uid: String) extends LSH[MinHashLSHModel] with Has
require(inputDim <= MinHashLSH.HASH_PRIME,
s"The input vector dimension $inputDim exceeds the threshold ${MinHashLSH.HASH_PRIME}.")
val rand = new Random($(seed))
val randCoefs: Array[(Int, Int)] = Array.fill($(numHashTables)) {
val randCoefs: Array[(Int, Int)] = Array.fill($(numHashTables) * $(numHashFunctions)) {
(1 + rand.nextInt(MinHashLSH.HASH_PRIME - 1), rand.nextInt(MinHashLSH.HASH_PRIME - 1))
}
new MinHashLSHModel(uid, randCoefs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class BucketedRandomProjectionLSHSuite
test("BucketedRandomProjectionLSH: default params") {
  // A freshly constructed estimator must expose the defaults registered via setDefault.
  val brp = new BucketedRandomProjectionLSH
  // Both getters return Int, so compare against Int literals instead of relying on
  // cross-type numeric equality with the Double literal 1.0.
  assert(brp.getNumHashTables === 1)
  assert(brp.getNumHashFunctions === 1)
}

test("read/write") {
Expand Down Expand Up @@ -85,6 +86,7 @@ class BucketedRandomProjectionLSHSuite
test("BucketedRandomProjectionLSH: randUnitVectors") {
val brp = new BucketedRandomProjectionLSH()
.setNumHashTables(20)
.setNumHashFunctions(10)
.setInputCol("keys")
.setOutputCol("values")
.setBucketLength(1.0)
Expand Down Expand Up @@ -119,6 +121,7 @@ class BucketedRandomProjectionLSHSuite
// Project from 100 dimensional Euclidean Space to 10 dimensions
val brp = new BucketedRandomProjectionLSH()
.setNumHashTables(10)
.setNumHashFunctions(5)
.setInputCol("keys")
.setOutputCol("values")
.setBucketLength(2.5)
Expand All @@ -133,7 +136,8 @@ class BucketedRandomProjectionLSHSuite
val key = Vectors.dense(1.2, 3.4)

val brp = new BucketedRandomProjectionLSH()
.setNumHashTables(2)
.setNumHashTables(8)
.setNumHashFunctions(2)
.setInputCol("keys")
.setOutputCol("values")
.setBucketLength(4.0)
Expand All @@ -150,6 +154,7 @@ class BucketedRandomProjectionLSHSuite

val brp = new BucketedRandomProjectionLSH()
.setNumHashTables(20)
.setNumHashFunctions(10)
.setInputCol("keys")
.setOutputCol("values")
.setBucketLength(1.0)
Expand Down Expand Up @@ -182,6 +187,7 @@ class BucketedRandomProjectionLSHSuite
val dataset2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys")

val brp = new BucketedRandomProjectionLSH()
.setNumHashFunctions(4)
.setNumHashTables(2)
.setInputCol("keys")
.setOutputCol("values")
Expand All @@ -200,13 +206,14 @@ class BucketedRandomProjectionLSHSuite
val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")

val brp = new BucketedRandomProjectionLSH()
.setNumHashFunctions(4)
.setNumHashTables(2)
.setInputCol("keys")
.setOutputCol("values")
.setBucketLength(4.0)
.setSeed(12345)

val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(brp, df, df, 3.0)
val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(brp, df, df, 2.0)
assert(precision == 1.0)
assert(recall >= 0.7)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
}

test("MinHashLSH: default params") {
val rp = new MinHashLSH
assert(rp.getNumHashTables === 1.0)
val mh = new MinHashLSH
assert(mh.getNumHashTables === 1.0)
assert(mh.getNumHashFunctions === 1.0)
}

test("read/write") {
Expand Down Expand Up @@ -109,7 +110,8 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa

test("approxNearestNeighbors for min hash") {
val mh = new MinHashLSH()
.setNumHashTables(20)
.setNumHashTables(64)
.setNumHashFunctions(2)
.setInputCol("keys")
.setOutputCol("values")
.setSeed(12345)
Expand All @@ -119,8 +121,8 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa

val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(mh, dataset, key, 20,
singleProbe = true)
assert(precision >= 0.7)
assert(recall >= 0.7)
assert(precision >= 0.6)
assert(recall >= 0.6)
}

test("approxNearestNeighbors for numNeighbors <= 0") {
Expand Down Expand Up @@ -149,7 +151,8 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val df2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys")

val mh = new MinHashLSH()
.setNumHashTables(20)
.setNumHashTables(64)
.setNumHashFunctions(2)
.setInputCol("keys")
.setOutputCol("values")
.setSeed(12345)
Expand Down