Skip to content

Commit 3d7b36e

Browse files
rezazadehmengxr
authored andcommitted
[SPARK-3790][MLlib] CosineSimilarity Example
Provide example for `RowMatrix.columnSimilarity()` Author: Reza Zadeh <[email protected]> Closes apache#2622 from rezazadeh/dimsumexample and squashes the following commits: 8f20b82 [Reza Zadeh] update comment 379066d [Reza Zadeh] cache rows 792b81c [Reza Zadeh] Address review comments e573c7a [Reza Zadeh] Average absolute error b15685f [Reza Zadeh] Use scopt. Distribute evaluation. eca3dfd [Reza Zadeh] Documentation ac96fb2 [Reza Zadeh] Compute approximation error, add command line. 4533579 [Reza Zadeh] CosineSimilarity Example
1 parent 446063e commit 3d7b36e

File tree

1 file changed

+107
-0
lines changed

1 file changed

+107
-0
lines changed
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.mllib
19+
20+
import scopt.OptionParser
21+
22+
import org.apache.spark.SparkContext._
23+
import org.apache.spark.mllib.linalg.Vectors
24+
import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix}
25+
import org.apache.spark.{SparkConf, SparkContext}
26+
27+
/**
28+
* Compute the similar columns of a matrix, using cosine similarity.
29+
*
30+
* The input matrix must be stored in row-oriented dense format, one line per row with its entries
31+
* separated by space. For example,
32+
* {{{
33+
* 0.5 1.0
34+
* 2.0 3.0
35+
* 4.0 5.0
36+
* }}}
37+
* represents a 3-by-2 matrix, whose first row is (0.5, 1.0).
38+
*
39+
* Example invocation:
40+
*
41+
* bin/run-example mllib.CosineSimilarity \
42+
* --threshold 0.1 data/mllib/sample_svm_data.txt
43+
*/
44+
object CosineSimilarity {
45+
case class Params(inputFile: String = null, threshold: Double = 0.1)
46+
47+
def main(args: Array[String]) {
48+
val defaultParams = Params()
49+
50+
val parser = new OptionParser[Params]("CosineSimilarity") {
51+
head("CosineSimilarity: an example app.")
52+
opt[Double]("threshold")
53+
.required()
54+
.text(s"threshold similarity: to tradeoff computation vs quality estimate")
55+
.action((x, c) => c.copy(threshold = x))
56+
arg[String]("<inputFile>")
57+
.required()
58+
.text(s"input file, one row per line, space-separated")
59+
.action((x, c) => c.copy(inputFile = x))
60+
note(
61+
"""
62+
|For example, the following command runs this app on a dataset:
63+
|
64+
| ./bin/spark-submit --class org.apache.spark.examples.mllib.CosineSimilarity \
65+
| examplesjar.jar \
66+
| --threshold 0.1 data/mllib/sample_svm_data.txt
67+
""".stripMargin)
68+
}
69+
70+
parser.parse(args, defaultParams).map { params =>
71+
run(params)
72+
} getOrElse {
73+
System.exit(1)
74+
}
75+
}
76+
77+
def run(params: Params) {
78+
val conf = new SparkConf().setAppName("CosineSimilarity")
79+
val sc = new SparkContext(conf)
80+
81+
// Load and parse the data file.
82+
val rows = sc.textFile(params.inputFile).map { line =>
83+
val values = line.split(' ').map(_.toDouble)
84+
Vectors.dense(values)
85+
}.cache()
86+
val mat = new RowMatrix(rows)
87+
88+
// Compute similar columns perfectly, with brute force.
89+
val exact = mat.columnSimilarities()
90+
91+
// Compute similar columns with estimation using DIMSUM
92+
val approx = mat.columnSimilarities(params.threshold)
93+
94+
val exactEntries = exact.entries.map { case MatrixEntry(i, j, u) => ((i, j), u) }
95+
val approxEntries = approx.entries.map { case MatrixEntry(i, j, v) => ((i, j), v) }
96+
val MAE = exactEntries.leftOuterJoin(approxEntries).values.map {
97+
case (u, Some(v)) =>
98+
math.abs(u - v)
99+
case (u, None) =>
100+
math.abs(u)
101+
}.mean()
102+
103+
println(s"Average absolute error in estimate is: $MAE")
104+
105+
sc.stop()
106+
}
107+
}

0 commit comments

Comments
 (0)