
Commit 8877cf1

Yun Ni authored and cmonkey committed
[SPARK-18080][ML][PYTHON] Python API & Examples for Locality Sensitive Hashing
## What changes were proposed in this pull request?

This pull request adds the Python API and examples for Locality Sensitive Hashing (LSH). The API changes were based on yanboliang's PR apache#15768, with conflicts resolved and the Python API updated to follow the changes on the Scala side. The examples are consistent with the Scala examples for MinHashLSH and BucketedRandomProjectionLSH.

## How was this patch tested?

API and examples were tested using spark-submit:
`bin/spark-submit examples/src/main/python/ml/min_hash_lsh.py`
`bin/spark-submit examples/src/main/python/ml/bucketed_random_projection_lsh.py`

User guide changes were generated and manually inspected:
`SKIP_API=1 jekyll build`

Author: Yun Ni <[email protected]>
Author: Yanbo Liang <[email protected]>
Author: Yunni <[email protected]>

Closes apache#16715 from Yunni/spark-18080.
1 parent 66bb9bb commit 8877cf1
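
A condensed sketch of the Python API this commit adds, distilled from the two example files further down; the toy data, threshold, and app name here are illustrative only:

```python
from pyspark.ml.feature import MinHashLSH
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("LSHSketch").getOrCreate()

# Binary feature vectors, as in the MinHash example added by this commit.
df = spark.createDataFrame([
    (0, Vectors.sparse(6, [0, 1, 2], [1.0, 1.0, 1.0])),
    (1, Vectors.sparse(6, [2, 3, 4], [1.0, 1.0, 1.0])),
    (2, Vectors.sparse(6, [0, 2, 4], [1.0, 1.0, 1.0]))], ["id", "features"])

# Fit the LSH model and add a 'hashes' column holding the hash values.
model = MinHashLSH(inputCol="features", outputCol="hashes", numHashTables=5).fit(df)
model.transform(df).show()

# Approximate self-join on Jaccard distance, and approximate nearest-neighbor search.
model.approxSimilarityJoin(df, df, 0.8, distCol="JaccardDistance").show()
model.approxNearestNeighbors(df, Vectors.sparse(6, [1, 3], [1.0, 1.0]), 2).show()

spark.stop()
```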


9 files changed (+601, -53 lines)


docs/ml-features.md

Lines changed: 17 additions & 0 deletions
@@ -1558,6 +1558,15 @@ for more details on the API.
 
 {% include_example java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java %}
 </div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [BucketedRandomProjectionLSH Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.BucketedRandomProjectionLSH)
+for more details on the API.
+
+{% include_example python/ml/bucketed_random_projection_lsh_example.py %}
+</div>
+
 </div>
 
 ### MinHash for Jaccard Distance
@@ -1590,4 +1599,12 @@ for more details on the API.
 
 {% include_example java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java %}
 </div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [MinHashLSH Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.MinHashLSH)
+for more details on the API.
+
+{% include_example python/ml/min_hash_lsh_example.py %}
+</div>
 </div>

examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java

Lines changed: 25 additions & 13 deletions
@@ -35,8 +35,15 @@
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+
+import static org.apache.spark.sql.functions.col;
 // $example off$
 
+/**
+ * An example demonstrating BucketedRandomProjectionLSH.
+ * Run with:
+ *   bin/run-example org.apache.spark.examples.ml.JavaBucketedRandomProjectionLSHExample
+ */
 public class JavaBucketedRandomProjectionLSHExample {
   public static void main(String[] args) {
     SparkSession spark = SparkSession
@@ -61,7 +68,7 @@ public static void main(String[] args) {
 
     StructType schema = new StructType(new StructField[]{
       new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
-      new StructField("keys", new VectorUDT(), false, Metadata.empty())
+      new StructField("features", new VectorUDT(), false, Metadata.empty())
     });
     Dataset<Row> dfA = spark.createDataFrame(dataA, schema);
     Dataset<Row> dfB = spark.createDataFrame(dataB, schema);
@@ -71,26 +78,31 @@ public static void main(String[] args) {
     BucketedRandomProjectionLSH mh = new BucketedRandomProjectionLSH()
       .setBucketLength(2.0)
       .setNumHashTables(3)
-      .setInputCol("keys")
-      .setOutputCol("values");
+      .setInputCol("features")
+      .setOutputCol("hashes");
 
     BucketedRandomProjectionLSHModel model = mh.fit(dfA);
 
     // Feature Transformation
+    System.out.println("The hashed dataset where hashed values are stored in the column 'hashes':");
     model.transform(dfA).show();
-    // Cache the transformed columns
-    Dataset<Row> transformedA = model.transform(dfA).cache();
-    Dataset<Row> transformedB = model.transform(dfB).cache();
 
-    // Approximate similarity join
-    model.approxSimilarityJoin(dfA, dfB, 1.5).show();
-    model.approxSimilarityJoin(transformedA, transformedB, 1.5).show();
-    // Self Join
-    model.approxSimilarityJoin(dfA, dfA, 2.5).filter("datasetA.id < datasetB.id").show();
+    // Compute the locality sensitive hashes for the input rows, then perform approximate
+    // similarity join.
+    // We could avoid computing hashes by passing in the already-transformed dataset, e.g.
+    // `model.approxSimilarityJoin(transformedA, transformedB, 1.5)`
+    System.out.println("Approximately joining dfA and dfB on distance smaller than 1.5:");
+    model.approxSimilarityJoin(dfA, dfB, 1.5, "EuclideanDistance")
+      .select(col("datasetA.id").alias("idA"),
+        col("datasetB.id").alias("idB"),
+        col("EuclideanDistance")).show();
 
-    // Approximate nearest neighbor search
+    // Compute the locality sensitive hashes for the input rows, then perform approximate nearest
+    // neighbor search.
+    // We could avoid computing hashes by passing in the already-transformed dataset, e.g.
+    // `model.approxNearestNeighbors(transformedA, key, 2)`
+    System.out.println("Approximately searching dfA for 2 nearest neighbors of the key:");
     model.approxNearestNeighbors(dfA, key, 2).show();
-    model.approxNearestNeighbors(transformedA, key, 2).show();
     // $example off$
 
     spark.stop();
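
As the new comments note, the hashes can be precomputed once and reused, which is what the removed `cache()` lines in the old example did. A minimal Python sketch of that pattern, assuming the `model`, `dfA`, `dfB`, and `key` defined in the examples in this commit (the names `transformedA`/`transformedB` are illustrative):

```python
# Hash both datasets once and cache them so repeated queries skip re-hashing.
transformedA = model.transform(dfA).cache()
transformedB = model.transform(dfB).cache()

# Because the 'hashes' output column is already present, these calls reuse it
# instead of recomputing the locality sensitive hashes.
model.approxSimilarityJoin(transformedA, transformedB, 1.5,
                           distCol="EuclideanDistance").show()
model.approxNearestNeighbors(transformedA, key, 2).show()
```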

examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java

Lines changed: 49 additions & 8 deletions
@@ -25,6 +25,7 @@
 
 import org.apache.spark.ml.feature.MinHashLSH;
 import org.apache.spark.ml.feature.MinHashLSHModel;
+import org.apache.spark.ml.linalg.Vector;
 import org.apache.spark.ml.linalg.VectorUDT;
 import org.apache.spark.ml.linalg.Vectors;
 import org.apache.spark.sql.Dataset;
@@ -34,8 +35,15 @@
 import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+
+import static org.apache.spark.sql.functions.col;
 // $example off$
 
+/**
+ * An example demonstrating MinHashLSH.
+ * Run with:
+ *   bin/run-example org.apache.spark.examples.ml.JavaMinHashLSHExample
+ */
 public class JavaMinHashLSHExample {
   public static void main(String[] args) {
     SparkSession spark = SparkSession
@@ -44,25 +52,58 @@ public static void main(String[] args) {
       .getOrCreate();
 
     // $example on$
-    List<Row> data = Arrays.asList(
+    List<Row> dataA = Arrays.asList(
       RowFactory.create(0, Vectors.sparse(6, new int[]{0, 1, 2}, new double[]{1.0, 1.0, 1.0})),
       RowFactory.create(1, Vectors.sparse(6, new int[]{2, 3, 4}, new double[]{1.0, 1.0, 1.0})),
       RowFactory.create(2, Vectors.sparse(6, new int[]{0, 2, 4}, new double[]{1.0, 1.0, 1.0}))
     );
 
+    List<Row> dataB = Arrays.asList(
+      RowFactory.create(0, Vectors.sparse(6, new int[]{1, 3, 5}, new double[]{1.0, 1.0, 1.0})),
+      RowFactory.create(1, Vectors.sparse(6, new int[]{2, 3, 5}, new double[]{1.0, 1.0, 1.0})),
+      RowFactory.create(2, Vectors.sparse(6, new int[]{1, 2, 4}, new double[]{1.0, 1.0, 1.0}))
+    );
+
     StructType schema = new StructType(new StructField[]{
       new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
-      new StructField("keys", new VectorUDT(), false, Metadata.empty())
+      new StructField("features", new VectorUDT(), false, Metadata.empty())
     });
-    Dataset<Row> dataFrame = spark.createDataFrame(data, schema);
+    Dataset<Row> dfA = spark.createDataFrame(dataA, schema);
+    Dataset<Row> dfB = spark.createDataFrame(dataB, schema);
+
+    int[] indices = {1, 3};
+    double[] values = {1.0, 1.0};
+    Vector key = Vectors.sparse(6, indices, values);
 
     MinHashLSH mh = new MinHashLSH()
-      .setNumHashTables(1)
-      .setInputCol("keys")
-      .setOutputCol("values");
+      .setNumHashTables(5)
+      .setInputCol("features")
+      .setOutputCol("hashes");
+
+    MinHashLSHModel model = mh.fit(dfA);
+
+    // Feature Transformation
+    System.out.println("The hashed dataset where hashed values are stored in the column 'hashes':");
+    model.transform(dfA).show();
+
+    // Compute the locality sensitive hashes for the input rows, then perform approximate
+    // similarity join.
+    // We could avoid computing hashes by passing in the already-transformed dataset, e.g.
+    // `model.approxSimilarityJoin(transformedA, transformedB, 0.6)`
+    System.out.println("Approximately joining dfA and dfB on Jaccard distance smaller than 0.6:");
+    model.approxSimilarityJoin(dfA, dfB, 0.6, "JaccardDistance")
+      .select(col("datasetA.id").alias("idA"),
+        col("datasetB.id").alias("idB"),
+        col("JaccardDistance")).show();
 
-    MinHashLSHModel model = mh.fit(dataFrame);
-    model.transform(dataFrame).show();
+    // Compute the locality sensitive hashes for the input rows, then perform approximate nearest
+    // neighbor search.
+    // We could avoid computing hashes by passing in the already-transformed dataset, e.g.
+    // `model.approxNearestNeighbors(transformedA, key, 2)`
+    // It may return less than 2 rows when not enough approximate near-neighbor candidates are
+    // found.
+    System.out.println("Approximately searching dfA for 2 nearest neighbors of the key:");
+    model.approxNearestNeighbors(dfA, key, 2).show();
     // $example off$
 
     spark.stop();
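
The last comment above is worth a concrete illustration: `approxNearestNeighbors` can return fewer rows than requested when too few hash buckets collide with the key. A hedged Python sketch of one possible mitigation; refitting with more hash tables is one option, not something this commit prescribes:

```python
from pyspark.ml.feature import MinHashLSH

# Assumes model, dfA, and key as defined in the examples in this commit.
neighbors = model.approxNearestNeighbors(dfA, key, 2)
if neighbors.count() < 2:
    # More hash tables raise the chance that candidate rows collide with the key,
    # at the cost of more computation and storage.
    model = MinHashLSH(inputCol="features", outputCol="hashes", numHashTables=10).fit(dfA)
    neighbors = model.approxNearestNeighbors(dfA, key, 2)
neighbors.show()
```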

examples/src/main/python/ml/bucketed_random_projection_lsh_example.py

Lines changed: 81 additions & 0 deletions (new file)

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import print_function

# $example on$
from pyspark.ml.feature import BucketedRandomProjectionLSH
from pyspark.ml.linalg import Vectors
from pyspark.sql.functions import col
# $example off$
from pyspark.sql import SparkSession

"""
An example demonstrating BucketedRandomProjectionLSH.
Run with:
  bin/spark-submit examples/src/main/python/ml/bucketed_random_projection_lsh_example.py
"""

if __name__ == "__main__":
    spark = SparkSession \
        .builder \
        .appName("BucketedRandomProjectionLSHExample") \
        .getOrCreate()

    # $example on$
    dataA = [(0, Vectors.dense([1.0, 1.0]),),
             (1, Vectors.dense([1.0, -1.0]),),
             (2, Vectors.dense([-1.0, -1.0]),),
             (3, Vectors.dense([-1.0, 1.0]),)]
    dfA = spark.createDataFrame(dataA, ["id", "features"])

    dataB = [(4, Vectors.dense([1.0, 0.0]),),
             (5, Vectors.dense([-1.0, 0.0]),),
             (6, Vectors.dense([0.0, 1.0]),),
             (7, Vectors.dense([0.0, -1.0]),)]
    dfB = spark.createDataFrame(dataB, ["id", "features"])

    key = Vectors.dense([1.0, 0.0])

    brp = BucketedRandomProjectionLSH(inputCol="features", outputCol="hashes", bucketLength=2.0,
                                      numHashTables=3)
    model = brp.fit(dfA)

    # Feature Transformation
    print("The hashed dataset where hashed values are stored in the column 'hashes':")
    model.transform(dfA).show()

    # Compute the locality sensitive hashes for the input rows, then perform approximate
    # similarity join.
    # We could avoid computing hashes by passing in the already-transformed dataset, e.g.
    # `model.approxSimilarityJoin(transformedA, transformedB, 1.5)`
    print("Approximately joining dfA and dfB on Euclidean distance smaller than 1.5:")
    model.approxSimilarityJoin(dfA, dfB, 1.5, distCol="EuclideanDistance")\
        .select(col("datasetA.id").alias("idA"),
                col("datasetB.id").alias("idB"),
                col("EuclideanDistance")).show()

    # Compute the locality sensitive hashes for the input rows, then perform approximate nearest
    # neighbor search.
    # We could avoid computing hashes by passing in the already-transformed dataset, e.g.
    # `model.approxNearestNeighbors(transformedA, key, 2)`
    print("Approximately searching dfA for 2 nearest neighbors of the key:")
    model.approxNearestNeighbors(dfA, key, 2).show()
    # $example off$

    spark.stop()
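
The previous Java example also demonstrated an approximate self-join, which this commit removes in favor of the A-to-B join. For reference, a short sketch of that pattern in Python, assuming the `model`, `dfA`, and `col` import from the example above (it would run before the `spark.stop()` call):

```python
# Approximate self-join of dfA with itself on Euclidean distance smaller than 2.5.
# The id filter drops self-pairs and keeps each remaining pair only once.
model.approxSimilarityJoin(dfA, dfA, 2.5, distCol="EuclideanDistance") \
    .filter("datasetA.id < datasetB.id") \
    .select(col("datasetA.id").alias("idA"),
            col("datasetB.id").alias("idB"),
            col("EuclideanDistance")).show()
```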

examples/src/main/python/ml/min_hash_lsh_example.py

Lines changed: 81 additions & 0 deletions (new file)

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import print_function

# $example on$
from pyspark.ml.feature import MinHashLSH
from pyspark.ml.linalg import Vectors
from pyspark.sql.functions import col
# $example off$
from pyspark.sql import SparkSession

"""
An example demonstrating MinHashLSH.
Run with:
  bin/spark-submit examples/src/main/python/ml/min_hash_lsh_example.py
"""

if __name__ == "__main__":
    spark = SparkSession \
        .builder \
        .appName("MinHashLSHExample") \
        .getOrCreate()

    # $example on$
    dataA = [(0, Vectors.sparse(6, [0, 1, 2], [1.0, 1.0, 1.0]),),
             (1, Vectors.sparse(6, [2, 3, 4], [1.0, 1.0, 1.0]),),
             (2, Vectors.sparse(6, [0, 2, 4], [1.0, 1.0, 1.0]),)]
    dfA = spark.createDataFrame(dataA, ["id", "features"])

    dataB = [(3, Vectors.sparse(6, [1, 3, 5], [1.0, 1.0, 1.0]),),
             (4, Vectors.sparse(6, [2, 3, 5], [1.0, 1.0, 1.0]),),
             (5, Vectors.sparse(6, [1, 2, 4], [1.0, 1.0, 1.0]),)]
    dfB = spark.createDataFrame(dataB, ["id", "features"])

    key = Vectors.sparse(6, [1, 3], [1.0, 1.0])

    mh = MinHashLSH(inputCol="features", outputCol="hashes", numHashTables=5)
    model = mh.fit(dfA)

    # Feature Transformation
    print("The hashed dataset where hashed values are stored in the column 'hashes':")
    model.transform(dfA).show()

    # Compute the locality sensitive hashes for the input rows, then perform approximate
    # similarity join.
    # We could avoid computing hashes by passing in the already-transformed dataset, e.g.
    # `model.approxSimilarityJoin(transformedA, transformedB, 0.6)`
    print("Approximately joining dfA and dfB on distance smaller than 0.6:")
    model.approxSimilarityJoin(dfA, dfB, 0.6, distCol="JaccardDistance")\
        .select(col("datasetA.id").alias("idA"),
                col("datasetB.id").alias("idB"),
                col("JaccardDistance")).show()

    # Compute the locality sensitive hashes for the input rows, then perform approximate nearest
    # neighbor search.
    # We could avoid computing hashes by passing in the already-transformed dataset, e.g.
    # `model.approxNearestNeighbors(transformedA, key, 2)`
    # It may return less than 2 rows when not enough approximate near-neighbor candidates are
    # found.
    print("Approximately searching dfA for 2 nearest neighbors of the key:")
    model.approxNearestNeighbors(dfA, key, 2).show()

    # $example off$

    spark.stop()
