Skip to content

Commit a09904f

Browse files
committed
Helper function for wrapping Array[Double]'s with DoubleMatrix's.
1 parent 7863ecc commit a09904f

File tree

1 file changed

+9
-2
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/recommendation

1 file changed

+9
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ class ALS private (
267267
private def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = {
268268
val n = rank * (rank + 1) / 2
269269
val LYtY = factors.values.aggregate(new DoubleMatrix(n))( seqOp = (L, Y) => {
270-
Y.foreach(y => dspr(1.0, new DoubleMatrix(y), L))
270+
Y.foreach(y => dspr(1.0, wrapDoubleArray(y), L))
271271
L
272272
}, combOp = (L1, L2) => {
273273
L1.addi(L2)
@@ -302,6 +302,13 @@ class ALS private (
302302
}
303303
}
304304

305+
/**
306+
* Wrap a double array in a DoubleMatrix without creating garbage.
307+
*/
308+
private def wrapDoubleArray(v: Array[Double]): DoubleMatrix = {
309+
new DoubleMatrix(v.length, 1, v:_*)
310+
}
311+
305312
/**
306313
* Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs
307314
*/
@@ -455,7 +462,7 @@ class ALS private (
455462
// block
456463
for (productBlock <- 0 until numBlocks) {
457464
for (p <- 0 until blockFactors(productBlock).length) {
458-
val x = new DoubleMatrix(blockFactors(productBlock)(p))
465+
val x = wrapDoubleArray(blockFactors(productBlock)(p))
459466
tempXtX.fill(0.0)
460467
dspr(1.0, x, tempXtX)
461468
val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p)

0 commit comments

Comments
 (0)