Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ private[ml] class WeightedLeastSquares(
if (fitIntercept) {
// shift centers
// A^T A - aBar aBar^T
RowMatrix.dspr(-1.0, aBar, aaValues)
BLAS.spr(-1.0, aBar, aaValues)
// A^T b - bBar aBar
BLAS.axpy(-bBar, aBar, abBar)
}
Expand Down Expand Up @@ -203,7 +203,7 @@ private[ml] object WeightedLeastSquares {
bbSum += w * b * b
BLAS.axpy(w, a, aSum)
BLAS.axpy(w * b, a, abSum)
RowMatrix.dspr(w, a, aaSum.values)
BLAS.spr(w, a, aaSum)
this
}

Expand Down
44 changes: 44 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,50 @@ private[spark] object BLAS extends Serializable with Logging {
_nativeBLAS
}

/**
* Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR.
*
* @param U the upper triangular part of the matrix in a [[DenseVector]](column major)
*/
def spr(alpha: Double, v: Vector, U: DenseVector): Unit = {
spr(alpha, v, U.values)
}

/**
* Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's ?SPR.
*
* @param U the upper triangular part of the matrix packed in an array (column major)
*/
def spr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
val n = v.size
v match {
case DenseVector(values) =>
NativeBLAS.dspr("U", n, alpha, values, 1, U)
case SparseVector(size, indices, values) =>
val nnz = indices.length
var colStartIdx = 0
var prevCol = 0
var col = 0
var j = 0
var i = 0
var av = 0.0
while (j < nnz) {
col = indices(j)
// Skip empty columns.
colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2
col = indices(j)
av = alpha * values(j)
i = 0
while (i <= j) {
U(colStartIdx + indices(i)) += av * values(i)
i += 1
}
j += 1
prevCol = col
}
}
}

/**
* A := alpha * x * x^T^ + A
* @param alpha a real scalar that will be multiplied to x * x^T^.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import scala.collection.mutable.ListBuffer
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy,
svd => brzSvd, MatrixSingularException, inv}
import breeze.numerics.{sqrt => brzSqrt}
import com.github.fommil.netlib.BLAS.{getInstance => blas}

import org.apache.spark.Logging
import org.apache.spark.SparkContext._
Expand Down Expand Up @@ -123,7 +122,7 @@ class RowMatrix @Since("1.0.0") (
// Compute the upper triangular part of the gram matrix.
val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))(
seqOp = (U, v) => {
RowMatrix.dspr(1.0, v, U.data)
BLAS.spr(1.0, v, U.data)
U
}, combOp = (U1, U2) => U1 += U2)

Expand Down Expand Up @@ -673,43 +672,6 @@ class RowMatrix @Since("1.0.0") (
@Experimental
object RowMatrix {

/**
* Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR.
*
* @param U the upper triangular part of the matrix packed in an array (column major)
*/
// TODO: SPARK-10491 - move this method to linalg.BLAS
private[spark] def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = {
// TODO: Find a better home (breeze?) for this method.
val n = v.size
v match {
case DenseVector(values) =>
blas.dspr("U", n, alpha, values, 1, U)
case SparseVector(size, indices, values) =>
val nnz = indices.length
var colStartIdx = 0
var prevCol = 0
var col = 0
var j = 0
var i = 0
var av = 0.0
while (j < nnz) {
col = indices(j)
// Skip empty columns.
colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2
col = indices(j)
av = alpha * values(j)
i = 0
while (i <= j) {
U(colStartIdx + indices(i)) += av * values(i)
i += 1
}
j += 1
prevCol = col
}
}
}

/**
* Fills a full square matrix from its upper triangular part.
*/
Expand Down
25 changes: 25 additions & 0 deletions mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,31 @@ class BLASSuite extends SparkFunSuite {
}
}

test("spr") {
// test dense vector
val alpha = 0.1
val x = new DenseVector(Array(1.0, 2, 2.1, 4))
val U = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4))
val expected = new DenseVector(Array(1.1, 2.2, 2.4, 3.21, 3.42, 3.441, 4.4, 4.8, 4.84, 5.6))

spr(alpha, x, U)
assert(U ~== expected absTol 1e-9)

val matrix33 = new DenseVector(Array(1.0, 2, 3, 4, 5))
withClue("Size of vector must match the rank of matrix") {
intercept[Exception] {
spr(alpha, x, matrix33)
}
}

// test sparse vector
val sv = new SparseVector(4, Array(0, 3), Array(1.0, 2))
val U2 = new DenseVector(Array(1.0, 2, 2, 3, 3, 3, 4, 4, 4, 4))
spr(0.1, sv, U2)
val expectedSparse = new DenseVector(Array(1.1, 2.0, 2.0, 3.0, 3.0, 3.0, 4.2, 4.0, 4.0, 4.4))
assert(U2 ~== expectedSparse absTol 1e-15)
}

test("syr") {
val dA = new DenseMatrix(4, 4,
Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8))
Expand Down