1717
1818package org .apache .spark .examples .mllib
1919
20+ import org .apache .spark .SparkContext ._
2021import org .apache .spark .mllib .linalg .Vectors
22+ import org .apache .spark .mllib .linalg .distributed .{MatrixEntry , RowMatrix }
2123import org .apache .spark .{SparkConf , SparkContext }
22- import org . apache . spark . mllib . linalg . distributed .{ MatrixEntry , CoordinateMatrix , RowMatrix }
24+ import scopt . OptionParser
2325
2426/**
2527 * Compute the similar columns of a matrix, using cosine similarity.
@@ -36,47 +38,71 @@ import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, CoordinateMatrix,
3638 * Example invocation:
3739 *
3840 * bin/run-example org.apache.spark.examples.mllib.CosineSimilarity \
39- * data/mllib/sample_svm_data.txt 0.1
41+ * --inputFile data/mllib/sample_svm_data.txt --threshold 0.1
4042 */
4143object CosineSimilarity {
/**
 * Command-line settings for the CosineSimilarity example.
 *
 * @param inputFile path of the file to load; `null` until a value is supplied
 * @param threshold similarity threshold handed to the approximate
 *                  column-similarity computation; defaults to 0.1
 */
case class Params(
    inputFile: String = null,
    threshold: Double = 0.1)
45+
/**
 * Entry point: parses the command line into [[Params]] and hands them to `run`.
 *
 * Both `--inputFile` and `--threshold` are declared as required options. When
 * parsing fails the JVM exits with status 1.
 *
 * @param args raw command-line arguments
 */
def main(args: Array[String]): Unit = {
  val defaultParams = Params()

  val parser = new OptionParser[Params]("CosineSimilarity") {
    head("CosineSimilarity: an example app.")
    opt[String]("inputFile")
      .required()
      .text("input file, one row per line, space-separated")
      .action((x, c) => c.copy(inputFile = x))
    opt[Double]("threshold")
      .required()
      .text("threshold similarity: to tradeoff computation vs quality estimate")
      .action((x, c) => c.copy(threshold = x))
    note(
      """
        |For example, the following command runs this app on a dataset:
        |
        | ./bin/spark-submit --class org.apache.spark.examples.mllib.CosineSimilarity \
        | examplesjar.jar \
        | --inputFile data/mllib/sample_svm_data.txt --threshold 0.1
      """.stripMargin)
  }

  // An explicit match keeps the two side-effecting outcomes visible instead of
  // hiding them inside Option.map / getOrElse.
  parser.parse(args, defaultParams) match {
    case Some(params) => run(params)
    case None => sys.exit(1)
  }
}
4775
/**
 * Runs the example: loads the input as a row matrix, computes column
 * similarities both by brute force and with the thresholded estimate, and
 * prints the mean absolute difference between the two.
 *
 * The Spark context created here is always stopped, even when loading,
 * parsing or one of the jobs fails.
 *
 * @param params parsed command-line settings (input path and threshold)
 */
def run(params: Params): Unit = {
  val conf = new SparkConf().setAppName("CosineSimilarity")
  val sc = new SparkContext(conf)

  try {
    // Load and parse the data file: one row per line, values separated by a space.
    val rows = sc.textFile(params.inputFile).map { line =>
      val values = line.split(' ').map(_.toDouble)
      Vectors.dense(values)
    }
    val mat = new RowMatrix(rows)

    // Compute similar columns perfectly, with brute force.
    val exact = mat.columnSimilarities()

    // Compute similar columns with estimation using DIMSUM.
    val approx = mat.columnSimilarities(params.threshold)

    // Key both results by (i, j) so they can be joined entry by entry.
    val exactEntries = exact.entries.map { case MatrixEntry(i, j, u) => ((i, j), u) }
    val approxEntries = approx.entries.map { case MatrixEntry(i, j, v) => ((i, j), v) }

    // The left outer join is driven by the exact result: an exact entry with
    // no estimated counterpart contributes its full magnitude to the error.
    val meanAbsoluteError = exactEntries.leftOuterJoin(approxEntries).values.map {
      case (u, Some(v)) => math.abs(u - v)
      case (u, None) => math.abs(u)
    }.mean()

    println(s"Average error in estimate is: $meanAbsoluteError")
  } finally {
    // Release the context on failure as well as on success.
    sc.stop()
  }
}
0 commit comments