diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala index 6e43d60bd03a3..f437d66cddb54 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Vectors.scala @@ -178,6 +178,14 @@ sealed trait Vector extends Serializable { */ @Since("2.0.0") def argmax: Int + + /** + * Calculate the dot product of this vector with another. + * + * If the two vectors differ in `size`, an [[IllegalArgumentException]] is thrown. + */ + @Since("3.0.0") + def dot(v: Vector): Double = BLAS.dot(this, v) } /** diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala index 0a316f57f811b..c97dc2c3c06f8 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/VectorsSuite.scala @@ -380,4 +380,27 @@ class VectorsSuite extends SparkMLFunSuite { Vectors.sparse(-1, Array((1, 2.0))) } } + + test("dot product only supports vectors of same size") { + val vSize4 = Vectors.dense(arr) + val vSize1 = Vectors.zeros(1) + intercept[IllegalArgumentException]{ vSize1.dot(vSize4) } + } + + test("dense vector dot product") { + val dv = Vectors.dense(arr) + assert(dv.dot(dv) === 0.26) + } + + test("sparse vector dot product") { + val sv = Vectors.sparse(n, indices, values) + assert(sv.dot(sv) === 0.26) + } + + test("mixed sparse and dense vector dot product") { + val sv = Vectors.sparse(n, indices, values) + val dv = Vectors.dense(arr) + assert(sv.dot(dv) === 0.26) + assert(dv.dot(sv) === 0.26) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index b754fad0c1796..83a519326df75 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala 
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -204,6 +204,14 @@ sealed trait Vector extends Serializable { */ @Since("2.0.0") def asML: newlinalg.Vector + + /** + * Calculate the dot product of this vector with another. + * + * If the two vectors differ in `size`, an [[IllegalArgumentException]] is thrown. + */ + @Since("3.0.0") + def dot(v: Vector): Double = BLAS.dot(this, v) } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index fee0b02bf8ed8..b2163b518dbd1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -510,4 +510,27 @@ class VectorsSuite extends SparkFunSuite with Logging { Vectors.sparse(-1, Array((1, 2.0))) } } + + test("dot product only supports vectors of same size") { + val vSize4 = Vectors.dense(arr) + val vSize1 = Vectors.zeros(1) + intercept[IllegalArgumentException]{ vSize1.dot(vSize4) } + } + + test("dense vector dot product") { + val dv = Vectors.dense(arr) + assert(dv.dot(dv) === 0.26) + } + + test("sparse vector dot product") { + val sv = Vectors.sparse(n, indices, values) + assert(sv.dot(sv) === 0.26) + } + + test("mixed sparse and dense vector dot product") { + val sv = Vectors.sparse(n, indices, values) + val dv = Vectors.dense(arr) + assert(sv.dot(dv) === 0.26) + assert(dv.dot(sv) === 0.26) + } }