Skip to content

Commit b15685f

Browse files
committed
Use scopt. Distribute evaluation.
1 parent eca3dfd commit b15685f

File tree

1 file changed

+48
-22
lines changed

1 file changed

+48
-22
lines changed

examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717

1818
package org.apache.spark.examples.mllib
1919

20+
import org.apache.spark.SparkContext._
2021
import org.apache.spark.mllib.linalg.Vectors
22+
import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix}
2123
import org.apache.spark.{SparkConf, SparkContext}
22-
import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, CoordinateMatrix, RowMatrix}
24+
import scopt.OptionParser
2325

2426
/**
2527
* Compute the similar columns of a matrix, using cosine similarity.
@@ -36,47 +38,71 @@ import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, CoordinateMatrix,
3638
* Example invocation:
3739
*
3840
* bin/run-example org.apache.spark.examples.mllib.CosineSimilarity \
39-
* data/mllib/sample_svm_data.txt 0.1
41+
* --inputFile data/mllib/sample_svm_data.txt --threshold 0.1
4042
*/
4143
object CosineSimilarity {
44+
case class Params(inputFile: String = null, threshold: Double = 0.1)
45+
4246
def main(args: Array[String]) {
43-
if (args.length != 2) {
44-
System.err.println("Usage: CosineSimilarity <input> <threshold>")
47+
val defaultParams = Params()
48+
49+
val parser = new OptionParser[Params]("CosineSimilarity") {
50+
head("CosineSimilarity: an example app.")
51+
opt[String]("inputFile")
52+
.required()
53+
.text(s"input file, one row per line, space-separated")
54+
.action((x, c) => c.copy(inputFile = x))
55+
opt[Double]("threshold")
56+
.required()
57+
.text(s"threshold similarity: to tradeoff computation vs quality estimate")
58+
.action((x, c) => c.copy(threshold = x))
59+
note(
60+
"""
61+
|For example, the following command runs this app on a dataset:
62+
|
63+
| ./bin/spark-submit --class org.apache.spark.examples.mllib.CosineSimilarity \
64+
| examplesjar.jar \
65+
| --inputFile data/mllib/sample_svm_data.txt --threshold 0.1
66+
""".stripMargin)
67+
}
68+
69+
parser.parse(args, defaultParams).map { params =>
70+
run(params)
71+
} getOrElse {
4572
System.exit(1)
4673
}
74+
}
4775

76+
def run(params: Params) {
4877
val conf = new SparkConf().setAppName("CosineSimilarity")
4978
val sc = new SparkContext(conf)
5079

5180
// Load and parse the data file.
52-
val rows = sc.textFile(args(0)).map { line =>
81+
val rows = sc.textFile(params.inputFile).map { line =>
5382
val values = line.split(' ').map(_.toDouble)
5483
Vectors.dense(values)
5584
}
5685
val mat = new RowMatrix(rows)
5786

58-
val threshold = args(1).toDouble
59-
6087
// Compute similar columns perfectly, with brute force.
61-
val simsPerfect = mat.columnSimilarities().entries.collect
88+
val exact = mat.columnSimilarities()
6289

6390
// Compute similar columns with estimation using DIMSUM
64-
val simsEstimate = mat.columnSimilarities(threshold).entries.collect
65-
66-
val n = mat.numCols().toInt
67-
val real = Array.ofDim[Double](n, n)
68-
val est = Array.ofDim[Double](n, n)
69-
for (entry <- simsPerfect) {
70-
real(entry.i.toInt)(entry.j.toInt) = entry.value
71-
}
72-
for (entry <- simsEstimate) {
73-
est(entry.i.toInt)(entry.j.toInt) = entry.value
74-
}
91+
val approx = mat.columnSimilarities(params.threshold)
7592

76-
val errors = Array.tabulate[Double](n, n)((i, j) => math.abs(real(i)(j) - est(i)(j)))
77-
val avgErr = errors.flatten.sum / (n * (n - 1) / 2)
93+
val MAE = exact.entries.map { case MatrixEntry(i, j, u) =>
94+
((i, j), u)
95+
}.leftOuterJoin(
96+
approx.entries.map { case MatrixEntry(i, j, v) =>
97+
((i, j), v)
98+
}).values.map {
99+
case (u, Some(v)) =>
100+
math.abs(u - v)
101+
case (u, None) =>
102+
math.abs(u)
103+
}.mean()
78104

79-
println(s"Average error in estimate is: $avgErr")
105+
println(s"Average error in estimate is: $MAE")
80106

81107
sc.stop()
82108
}

0 commit comments

Comments
 (0)