Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.apache.spark.{Logging, Partitioner, SparkException}
import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix}
import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.storage.StorageLevel

/**
Expand Down Expand Up @@ -436,3 +437,29 @@ class BlockMatrix @Since("1.3.0") (
}
}
}

@Since("1.6.0")
object BlockMatrix {
/** A Java-friendly auxiliary factory. */
@Since("1.6.0")
def from[M <: Matrix](
blocks: JavaRDD[((Integer, Integer), M)],
rowsPerBlock: Int,
colsPerBlock: Int,
nRows: Long,
nCols: Long): BlockMatrix = {
val rdd = blocks.rdd.map { case ((blockRowIndex, blockColIndex), subMatrix) =>
((blockRowIndex.toInt, blockColIndex.toInt), subMatrix.asInstanceOf[Matrix])
}
new BlockMatrix(rdd, rowsPerBlock, colsPerBlock, nRows, nCols)
}

/** A Java-friendly auxiliary factory without the input of the number of rows and columns. */
@Since("1.6.0")
def from[M <: Matrix](
blocks: JavaRDD[((Integer, Integer), M)],
rowsPerBlock: Int,
colsPerBlock: Int): BlockMatrix = {
from(blocks, rowsPerBlock, colsPerBlock, 0L, 0L)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import breeze.numerics.{sqrt => brzSqrt}

import org.apache.spark.Logging
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary}
import org.apache.spark.rdd.RDD
Expand Down Expand Up @@ -675,6 +676,25 @@ class RowMatrix @Since("1.0.0") (
@Since("1.0.0")
object RowMatrix {

/** A Java-friendly auxiliary factory. */
@Since("1.6.0")
def from[V <: Vector](
rows: JavaRDD[V],
nRows: Long,
nCols: Int): RowMatrix = {
val rdd = rows.rdd.map(_.asInstanceOf[Vector])
new RowMatrix(rdd, nRows, nCols)
}

/**
* Alternative Java-friendly auxiliary factory
* leaving matrix dimensions to be determined automatically.
*/
@Since("1.6.0")
def from[V <: Vector](rows: JavaRDD[V]): RowMatrix = {
from(rows, 0L, 0)
}

/**
* Fills a full square matrix from its upper triangular part.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package org.apache.spark.mllib.linalg.distributed;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.DenseMatrix;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import scala.Tuple2;

import java.util.Arrays;
import java.util.List;

import static org.junit.Assert.assertEquals;

public class JavaBlockMatrixSuite {
private transient JavaSparkContext sc;

@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaBlockMatrixSuite");
}

@After
public void tearDown() {
sc.stop();
sc = null;
}

@Test
public void blockMatrixConstruction() {
List<Tuple2<Tuple2<Integer, Integer>, DenseMatrix>> blocks = Arrays.asList(
new Tuple2<Tuple2<Integer, Integer>, DenseMatrix>(
new Tuple2<Integer, Integer>(0, 0), new DenseMatrix(2, 2, new double[]{1.0, 0.0, 0.0, 2.0})),
new Tuple2<Tuple2<Integer, Integer>, DenseMatrix>(
new Tuple2<Integer, Integer>(0, 1), new DenseMatrix(2, 2, new double[]{0.0, 1.0, 0.0, 0.0})),
new Tuple2<Tuple2<Integer, Integer>, DenseMatrix>(
new Tuple2<Integer, Integer>(1, 0), new DenseMatrix(2, 2, new double[]{3.0, 0.0, 1.0, 1.0})),
new Tuple2<Tuple2<Integer, Integer>, DenseMatrix>(
new Tuple2<Integer, Integer>(1, 1), new DenseMatrix(2, 2, new double[]{1.0, 2.0, 0.0, 1.0})),
new Tuple2<Tuple2<Integer, Integer>, DenseMatrix>(
new Tuple2<Integer, Integer>(2, 1), new DenseMatrix(1, 2, new double[]{1.0, 5.0})));
BlockMatrix blockMatrix = BlockMatrix.from(sc.parallelize(blocks), 2, 2);
final DenseMatrix expectedMatrix = new DenseMatrix(5, 4, new double[]{
1.0, 0.0, 0.0, 0.0,
0.0, 2.0, 1.0, 0.0,
3.0, 1.0, 1.0, 0.0,
0.0, 1.0, 2.0, 1.0,
0.0, 0.0, 1.0, 5.0}, true);
assertEquals(5L, blockMatrix.numRows());
assertEquals(4L, blockMatrix.numCols());
assertEquals(expectedMatrix, blockMatrix.toLocalMatrix());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package org.apache.spark.mllib.linalg.distributed;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.DenseMatrix;
import org.apache.spark.mllib.linalg.DenseVector;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import java.util.Arrays;
import java.util.List;

import static org.junit.Assert.assertEquals;

public class JavaRowMatrixSuite {
private transient JavaSparkContext sc;

@Before
public void setUp() {
sc = new JavaSparkContext("local", "JavaRowMatrixSuite");
}

@After
public void tearDown() {
sc.stop();
sc = null;
}

@Test
public void rowMatrixConstruction() {
List<DenseVector> rows = Arrays.asList(
new DenseVector(new double[]{1.0, 0.0, 0.0, 0.0}),
new DenseVector(new double[]{0.0, 2.0, 1.0, 0.0}),
new DenseVector(new double[]{3.0, 1.0, 1.0, 0.0}),
new DenseVector(new double[]{0.0, 1.0, 2.0, 1.0}),
new DenseVector(new double[]{0.0, 0.0, 1.0, 5.0}));
RowMatrix rowMatrix = RowMatrix.from(sc.parallelize(rows));
final DenseMatrix expectedMatrix = new DenseMatrix(5, 4, new double[]{
1.0, 0.0, 0.0, 0.0,
0.0, 2.0, 1.0, 0.0,
3.0, 1.0, 1.0, 0.0,
0.0, 1.0, 2.0, 1.0,
0.0, 0.0, 1.0, 5.0}, true);
assertEquals(5L, rowMatrix.numRows());
assertEquals(4L, rowMatrix.numCols());
assertEquals(expectedMatrix.toBreeze(), rowMatrix.toBreeze());
}
}