diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
index 09527dcf5d9e..72b04260237e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -25,6 +25,7 @@ import org.apache.spark.{Logging, Partitioner, SparkException}
 import org.apache.spark.annotation.Since
 import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -436,3 +437,29 @@ class BlockMatrix @Since("1.3.0") (
     }
   }
 }
+
+@Since("1.6.0")
+object BlockMatrix {
+  /** A Java-friendly auxiliary factory. */
+  @Since("1.6.0")
+  def from[M <: Matrix](
+      blocks: JavaRDD[((Integer, Integer), M)],
+      rowsPerBlock: Int,
+      colsPerBlock: Int,
+      nRows: Long,
+      nCols: Long): BlockMatrix = {
+    val rdd = blocks.rdd.map { case ((blockRowIndex, blockColIndex), subMatrix) =>
+      ((blockRowIndex.toInt, blockColIndex.toInt), subMatrix.asInstanceOf[Matrix])
+    }
+    new BlockMatrix(rdd, rowsPerBlock, colsPerBlock, nRows, nCols)
+  }
+
+  /** A Java-friendly auxiliary factory without the input of the number of rows and columns. */
+  @Since("1.6.0")
+  def from[M <: Matrix](
+      blocks: JavaRDD[((Integer, Integer), M)],
+      rowsPerBlock: Int,
+      colsPerBlock: Int): BlockMatrix = {
+    from(blocks, rowsPerBlock, colsPerBlock, 0L, 0L)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index 52c0f19c645d..42b3c2190152 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -27,6 +27,7 @@ import breeze.numerics.{sqrt => brzSqrt}
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Since
+import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
 import org.apache.spark.rdd.RDD
@@ -675,6 +676,25 @@ class RowMatrix @Since("1.0.0") (
 @Since("1.0.0")
 object RowMatrix {
 
+  /** A Java-friendly auxiliary factory. */
+  @Since("1.6.0")
+  def from[V <: Vector](
+      rows: JavaRDD[V],
+      nRows: Long,
+      nCols: Int): RowMatrix = {
+    val rdd = rows.rdd.map(_.asInstanceOf[Vector])
+    new RowMatrix(rdd, nRows, nCols)
+  }
+
+  /**
+   * Alternative Java-friendly auxiliary factory
+   * leaving matrix dimensions to be determined automatically.
+   */
+  @Since("1.6.0")
+  def from[V <: Vector](rows: JavaRDD[V]): RowMatrix = {
+    from(rows, 0L, 0)
+  }
+
   /**
    * Fills a full square matrix from its upper triangular part.
    */
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/distributed/JavaBlockMatrixSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/distributed/JavaBlockMatrixSuite.java
new file mode 100644
index 000000000000..53fac2a75d58
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/distributed/JavaBlockMatrixSuite.java
@@ -0,0 +1,53 @@
+package org.apache.spark.mllib.linalg.distributed;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.DenseMatrix;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import scala.Tuple2;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+public class JavaBlockMatrixSuite {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaBlockMatrixSuite");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  @Test
+  public void blockMatrixConstruction() {
+    List<Tuple2<Tuple2<Integer, Integer>, DenseMatrix>> blocks = Arrays.asList(
+      new Tuple2<Tuple2<Integer, Integer>, DenseMatrix>(
+        new Tuple2<Integer, Integer>(0, 0), new DenseMatrix(2, 2, new double[]{1.0, 0.0, 0.0, 2.0})),
+      new Tuple2<Tuple2<Integer, Integer>, DenseMatrix>(
+        new Tuple2<Integer, Integer>(0, 1), new DenseMatrix(2, 2, new double[]{0.0, 1.0, 0.0, 0.0})),
+      new Tuple2<Tuple2<Integer, Integer>, DenseMatrix>(
+        new Tuple2<Integer, Integer>(1, 0), new DenseMatrix(2, 2, new double[]{3.0, 0.0, 1.0, 1.0})),
+      new Tuple2<Tuple2<Integer, Integer>, DenseMatrix>(
+        new Tuple2<Integer, Integer>(1, 1), new DenseMatrix(2, 2, new double[]{1.0, 2.0, 0.0, 1.0})),
+      new Tuple2<Tuple2<Integer, Integer>, DenseMatrix>(
+        new Tuple2<Integer, Integer>(2, 1), new DenseMatrix(1, 2, new double[]{1.0, 5.0})));
+    BlockMatrix blockMatrix = BlockMatrix.from(sc.parallelize(blocks), 2, 2);
+    final DenseMatrix expectedMatrix = new DenseMatrix(5, 4, new double[]{
+      1.0, 0.0, 0.0, 0.0,
+      0.0, 2.0, 1.0, 0.0,
+      3.0, 1.0, 1.0, 0.0,
+      0.0, 1.0, 2.0, 1.0,
+      0.0, 0.0, 1.0, 5.0}, true);
+    assertEquals(5L, blockMatrix.numRows());
+    assertEquals(4L, blockMatrix.numCols());
+    assertEquals(expectedMatrix, blockMatrix.toLocalMatrix());
+  }
+}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/distributed/JavaRowMatrixSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/distributed/JavaRowMatrixSuite.java
new file mode 100644
index 000000000000..f25103cc39e9
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/distributed/JavaRowMatrixSuite.java
@@ -0,0 +1,48 @@
+package org.apache.spark.mllib.linalg.distributed;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.DenseMatrix;
+import org.apache.spark.mllib.linalg.DenseVector;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+public class JavaRowMatrixSuite {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaRowMatrixSuite");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  @Test
+  public void rowMatrixConstruction() {
+    List<DenseVector> rows = Arrays.asList(
+      new DenseVector(new double[]{1.0, 0.0, 0.0, 0.0}),
+      new DenseVector(new double[]{0.0, 2.0, 1.0, 0.0}),
+      new DenseVector(new double[]{3.0, 1.0, 1.0, 0.0}),
+      new DenseVector(new double[]{0.0, 1.0, 2.0, 1.0}),
+      new DenseVector(new double[]{0.0, 0.0, 1.0, 5.0}));
+    RowMatrix rowMatrix = RowMatrix.from(sc.parallelize(rows));
+    final DenseMatrix expectedMatrix = new DenseMatrix(5, 4, new double[]{
+      1.0, 0.0, 0.0, 0.0,
+      0.0, 2.0, 1.0, 0.0,
+      3.0, 1.0, 1.0, 0.0,
+      0.0, 1.0, 2.0, 1.0,
+      0.0, 0.0, 1.0, 5.0}, true);
+    assertEquals(5L, rowMatrix.numRows());
+    assertEquals(4L, rowMatrix.numCols());
+    assertEquals(expectedMatrix.toBreeze(), rowMatrix.toBreeze());
+  }
+}
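
For context, a minimal sketch of how application code (outside the test harnesses above) might call the new Java-friendly factories. Only RowMatrix.from and BlockMatrix.from come from this patch; the class name, SparkConf setup, and sample vectors below are illustrative assumptions.

import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.distributed.RowMatrix;

public class RowMatrixFromJavaExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext(
      new SparkConf().setAppName("RowMatrixFromJavaExample").setMaster("local[2]"));
    // Each vector becomes one row of the distributed matrix.
    JavaRDD<DenseVector> rows = sc.parallelize(Arrays.asList(
      new DenseVector(new double[]{1.0, 2.0}),
      new DenseVector(new double[]{3.0, 4.0})));
    // The single-argument overload leaves the dimensions to be determined automatically.
    RowMatrix matrix = RowMatrix.from(rows);
    System.out.println(matrix.numRows() + " x " + matrix.numCols());  // prints "2 x 2"
    sc.stop();
  }
}

BlockMatrix.from is used the same way with a JavaRDD of ((Integer, Integer), Matrix) block tuples plus the per-block dimensions, as JavaBlockMatrixSuite above demonstrates.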