@@ -361,9 +361,9 @@ object SparseMatrix {
361361 * @param entries Array of (i, j, value) tuples
362362 * @return The corresponding `SparseMatrix`
363363 */
364- def fromCOO (numRows : Int , numCols : Int , entries : Array [(Int , Int , Double )]): SparseMatrix = {
365- val numEntries = entries.size
366- val sortedEntries = entries.sortBy(v => (v._2, v._1))
364+ def fromCOO (numRows : Int , numCols : Int , entries : Iterable [(Int , Int , Double )]): SparseMatrix = {
365+ val sortedEntries = entries.toSeq.sortBy(v => (v._2, v._1))
366+ val numEntries = sortedEntries.size
367367 if (sortedEntries.nonEmpty) {
368368 // Since the entries are sorted by column index, we only need to check the first and the last.
369369 for (col <- Seq (sortedEntries.head._2, sortedEntries.last._2)) {
@@ -413,54 +413,59 @@ object SparseMatrix {
413413 new SparseMatrix (n, n, (0 to n).toArray, (0 until n).toArray, Array .fill(n)(1.0 ))
414414 }
415415
416- /** Generates the skeleton of a random `SparseMatrix` with a given random number generator. */
416+ /**
417+ * Generates the skeleton of a random `SparseMatrix` with a given random number generator.
418+ * The values of the matrix returned are undefined.
419+ */
417420 private def genRandMatrix (
418421 numRows : Int ,
419422 numCols : Int ,
420423 density : Double ,
421424 rng : Random ): SparseMatrix = {
422- require(density >= 0.0 && density <= 1.0 , " density must be a double in the range " +
423- s " 0.0 <= d <= 1.0. Currently, density: $density" )
424- val length = math.ceil(numRows * numCols * density).toInt
425- var i = 0
425+ require(numRows > 0 , s " numRows must be greater than 0 but got $numRows" )
426+ require(numCols > 0 , s " numCols must be greater than 0 but got $numCols" )
427+ require(density >= 0.0 && density <= 1.0 ,
428+ s " density must be a double in the range 0.0 <= d <= 1.0. Currently, density: $density" )
429+ val size = numRows.toLong * numCols
430+ val expected = size * density
431+ assert(expected < Int .MaxValue ,
432+ " The expected number of nonzeros cannot be greater than Int.MaxValue." )
433+ val nnz = math.ceil(expected).toInt
426434 if (density == 0.0 ) {
427- return new SparseMatrix (numRows, numCols, new Array [Int ](numCols + 1 ),
428- Array [Int ](), Array [Double ]())
435+ new SparseMatrix (numRows, numCols, new Array [Int ](numCols + 1 ), Array [Int ](), Array [Double ]())
429436 } else if (density == 1.0 ) {
430- val rowIndices = Array .tabulate(numCols, numRows)((j, i) => i).flatten
431- return new SparseMatrix (numRows, numCols, ( 0 to numRows * numCols by numRows).toArray,
432- rowIndices, new Array [Double ](numRows * numCols))
433- }
434- if (density < 0.34 ) { // Expected number of iterations is less than 1.5 * length
437+ val colPtrs = Array .tabulate(numCols + 1 )(j => j * numRows)
438+ val rowIndices = Array .tabulate(size.toInt)(idx => idx % numRows)
439+ new SparseMatrix (numRows, numCols, colPtrs, rowIndices, new Array [Double ](numRows * numCols))
440+ } else if (density < 0.34 ) {
441+ // draw-by-draw, expected number of iterations is less than 1.5 * nnz
435442 val entries = MHashSet [(Int , Int )]()
436- while (entries.size < length ) {
443+ while (entries.size < nnz ) {
437444 entries += ((rng.nextInt(numRows), rng.nextInt(numCols)))
438445 }
439- val entryList = entries.toArray.map(v => (v._1, v._2, 1.0 ))
440- SparseMatrix .fromCOO(numRows, numCols, entryList)
441- } else { // selection - rejection method
446+ SparseMatrix .fromCOO(numRows, numCols, entries.map(v => (v._1, v._2, 1.0 )))
447+ } else {
448+ // selection-rejection method
449+ var idx = 0L
450+ var numSelected = 0
451+ var i = 0
442452 var j = 0
443- val pool = numRows * numCols
444- val rowIndexBuilder = new MArrayBuilder .ofInt
445453 val colPtrs = new Array [Int ](numCols + 1 )
446- while (i < length && j < numCols) {
447- var passedInPool = j * numRows
448- var r = 0
449- while (i < length && r < numRows) {
450- if (rng.nextDouble() < 1.0 * (length - i) / (pool - passedInPool)) {
451- rowIndexBuilder += r
452- i += 1
454+ val rowIndices = new Array [Int ](nnz)
455+ while (j < numCols && numSelected < nnz) {
456+ while (i < numRows && numSelected < nnz) {
457+ if (rng.nextDouble() < 1.0 * (nnz - numSelected) / (size - idx)) {
458+ rowIndices(numSelected) = i
459+ numSelected += 1
453460 }
454- r += 1
455- passedInPool += 1
461+ i += 1
462+ idx += 1
456463 }
464+ colPtrs(j + 1 ) = numSelected
457465 j += 1
458- colPtrs(j) = i
459466 }
460- val rowIndices = rowIndexBuilder.result()
461- new SparseMatrix (numRows, numCols, colPtrs, rowIndices, new Array [Double ](rowIndices.size))
467+ new SparseMatrix (numRows, numCols, colPtrs, rowIndices, new Array [Double ](nnz))
462468 }
463-
464469 }
465470
466471 /**
0 commit comments