Commit c2db5e5 (parent 86588e1)

make test pass

2 files changed: 33 additions, 34 deletions

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 31 additions & 32 deletions
@@ -291,7 +291,7 @@ object ALS extends Logging {
 
     /** Adds an observation. */
     def add(a: Array[Float], b: Float): this.type = {
-      require(a.size == k)
+      require(a.length == k)
       copyToDouble(a)
       blas.dspr(upper, k, 1.0, da, 1, ata)
       blas.daxpy(k, b.toDouble, da, 1, atb, 1)
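Background on the recurring `.size` to `.length` change (not part of the diff): on a Scala `Array`, `length` reads the JVM array field directly, while the equivalent `size` goes through an implicit conversion to `ArrayOps`, which in Scala 2.x allocates a wrapper object per call. Hot paths such as `add` therefore prefer `length`. A minimal sketch:

    object ArrayLengthDemo extends App {
      val a = Array(1.0f, 2.0f, 3.0f)
      // Same value either way, but `size` first wraps the array in an ArrayOps.
      assert(a.size == a.length)
      println(a.length) // 3
    }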
@@ -303,7 +303,7 @@ object ALS extends Logging {
      * Adds an observation with implicit feedback. Note that this does not increment the counter.
      */
     def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = {
-      require(a.size == k)
+      require(a.length == k)
       // Extension to the original paper to handle b < 0. confidence is a function of |b| instead
       // so that it is never negative.
       val confidence = 1.0 + alpha * math.abs(b)
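To make the comment above concrete (illustrative numbers, not from the source): with alpha = 1.0 and b = -2.0f, confidence = 1.0 + 1.0 * |-2.0| = 3.0, whereas the unmodified formula from the implicit-feedback paper, c = 1 + alpha * b, would yield -1.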
@@ -319,8 +319,8 @@ object ALS extends Logging {
     /** Merges another normal equation object. */
     def merge(other: NormalEquation): this.type = {
       require(other.k == k)
-      blas.daxpy(ata.size, 1.0, other.ata, 1, ata, 1)
-      blas.daxpy(atb.size, 1.0, other.atb, 1, atb, 1)
+      blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1)
+      blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1)
       n += other.n
       this
     }
@@ -454,9 +454,9 @@ object ALS extends Logging {
       dstEncodedIndices: Array[Int],
       ratings: Array[Float]) {
     /** Size of the block. */
-    val size: Int = ratings.size
-    require(dstEncodedIndices.size == size)
-    require(dstPtrs.size == srcIds.size + 1)
+    def size: Int = ratings.length
+    require(dstEncodedIndices.length == size)
+    require(dstPtrs.length == srcIds.length + 1)
   }
 
   /**
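A plausible reading of the `val size` to `def size` change (an assumption; the commit message does not say): a `val` computed from the constructor arguments becomes a stored field, so it travels with every serialized copy of the block shipped between executors, while a `def` is recomputed on demand and keeps the serialized form down to the primitive arrays themselves. A sketch of the difference, using a hypothetical `Block` class:

    case class Block(xs: Array[Int]) {
      def size: Int = xs.length // computed on demand, no extra stored field
      // val size: Int = xs.length  // would add a field that is serialized too
    }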
@@ -476,7 +476,7 @@ object ALS extends Logging {
     // (<1%) compared picking elements uniformly at random in [0,1].
     inBlocks.map { case (srcBlockId, inBlock) =>
       val random = new XORShiftRandom(srcBlockId)
-      val factors = Array.fill(inBlock.srcIds.size) {
+      val factors = Array.fill(inBlock.srcIds.length) {
         val factor = Array.fill(rank)(random.nextGaussian().toFloat)
         val nrm = blas.snrm2(rank, factor, 1)
         blas.sscal(rank, 1.0f / nrm, factor, 1)
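For the initialization above: `snrm2` returns the Euclidean norm of the freshly sampled Gaussian vector and `sscal` multiplies the vector by 1.0f / nrm, so every initial factor starts with unit L2 norm.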
@@ -489,15 +489,14 @@ object ALS extends Logging {
   /**
    * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays.
    */
-  private[recommendation]
-  case class RatingBlock[@specialized(Int, Long) ID](
+  case class RatingBlock[@specialized(Int, Long) ID: ClassTag](
       srcIds: Array[ID],
       dstIds: Array[ID],
       ratings: Array[Float]) {
     /** Size of the block. */
-    val size: Int = srcIds.size
-    require(dstIds.size == size)
-    require(ratings.size == size)
+    def size: Int = srcIds.length
+    require(dstIds.length == srcIds.length)
+    require(ratings.length == srcIds.length)
   }
 
   /**
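On the `ID: ClassTag` context bound added above (background, not from the diff): the bound desugars into an implicit `ClassTag[ID]` parameter, which is what lets generic code instantiate an `Array[ID]` despite type erasure. A minimal sketch with a hypothetical class:

    import scala.reflect.ClassTag

    // `T: ClassTag` is sugar for an extra parameter list
    // (implicit ev: ClassTag[T]).
    case class Packed[T: ClassTag](xs: Array[T]) {
      def padded(n: Int): Array[T] = {
        val out = new Array[T](n) // constructing Array[T] needs the ClassTag
        xs.copyToArray(out)
        out
      }
    }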
@@ -522,7 +521,7 @@ object ALS extends Logging {
 
     /** Merges another [[RatingBlockBuilder]]. */
     def merge(other: RatingBlock[ID]): this.type = {
-      size += other.srcIds.size
+      size += other.srcIds.length
       srcIds ++= other.srcIds
       dstIds ++= other.dstIds
       ratings ++= other.ratings
@@ -531,7 +530,7 @@ object ALS extends Logging {
 
     /** Builds a [[RatingBlock]]. */
     def build(): RatingBlock[ID] = {
-      RatingBlock[ID](srcIds.result(), dstIds.result(), ratings.result())
+      new RatingBlock[ID](srcIds.result(), dstIds.result(), ratings.result())
     }
   }
 
@@ -568,14 +567,14 @@ object ALS extends Logging {
         val idx = srcBlockId + srcPart.numPartitions * dstBlockId
         val builder = builders(idx)
         builder.add(r)
-        if (builder.size >= 2048) { // 2048 * (3 * 4) = 24k
+        if (builder.length >= 2048) { // 2048 * (3 * 4) = 24k
           builders(idx) = new RatingBlockBuilder
           Iterator.single(((srcBlockId, dstBlockId), builder.build()))
         } else {
           Iterator.empty
         }
       } ++ {
-        builders.view.zipWithIndex.filter(_._1.size > 0).map { case (block, idx) =>
+        builders.view.zipWithIndex.filter(_._1.length > 0).map { case (block, idx) =>
           val srcBlockId = idx % srcPart.numPartitions
           val dstBlockId = idx / srcPart.numPartitions
           ((srcBlockId, dstBlockId), block.build())
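The flush-threshold comment spelled out (assuming 4-byte Int IDs): each buffered rating holds one src ID, one dst ID, and one Float, i.e. 3 * 4 = 12 bytes, so a block is emitted once it reaches roughly 2048 * 12 = 24576 bytes, about 24 KB.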
@@ -613,9 +612,9 @@ object ALS extends Logging {
         srcIds: Array[ID],
         dstLocalIndices: Array[Int],
         ratings: Array[Float]): this.type = {
-      val sz = srcIds.size
-      require(dstLocalIndices.size == sz)
-      require(ratings.size == sz)
+      val sz = srcIds.length
+      require(dstLocalIndices.length == sz)
+      require(ratings.length == sz)
       this.srcIds ++= srcIds
       this.ratings ++= ratings
       var j = 0
@@ -642,15 +641,15 @@ object ALS extends Logging {
       implicit ord: Ordering[ID]) {
 
     /** Size the of block. */
-    def size: Int = srcIds.size
+    def length: Int = srcIds.length
 
     /**
      * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a
      * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format.
      * Sorting is done using Spark's built-in Timsort to avoid generating too many objects.
      */
     def compress(): InBlock[ID] = {
-      val sz = size
+      val sz = length
       assert(sz > 0, "Empty in-link block should not exist.")
       sort()
       val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[ID]
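The COO-to-CSC conversion that compress() performs, reduced to its core (a standalone sketch, not the ALS code): sort the coordinate entries by source ID, count how many entries each unique ID has, and prefix-sum the counts into a pointer array, so that the entries for block i occupy the index range ptrs(i) until ptrs(i + 1).

    object CooToCscSketch extends App {
      val srcIds = Array(1, 1, 2, 5, 5, 5) // already sorted COO ids
      val counts = srcIds.groupBy(identity).toSeq.sortBy(_._1).map(_._2.length)
      val ptrs = counts.scanLeft(0)(_ + _).toArray // prefix sums of the counts
      println(ptrs.mkString(", ")) // 0, 2, 3, 6
    }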
@@ -674,7 +673,7 @@ object ALS extends Logging {
       }
       dstCountsBuilder += curCount
       val uniqueSrcIds = uniqueSrcIdsBuilder.result()
-      val numUniqueSrdIds = uniqueSrcIds.size
+      val numUniqueSrdIds = uniqueSrcIds.length
       val dstCounts = dstCountsBuilder.result()
       val dstPtrs = new Array[Int](numUniqueSrdIds + 1)
       var sum = 0
@@ -688,13 +687,13 @@ object ALS extends Logging {
     }
 
     private def sort(): Unit = {
-      val sz = size
+      val sz = length
       // Since there might be interleaved log messages, we insert a unique id for easy pairing.
       val sortId = Utils.random.nextInt()
       logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)")
       val start = System.nanoTime()
       val sorter = new Sorter(new UncompressedInBlockSort[ID])
-      sorter.sort(this, 0, size, Ordering[KeyWrapper[ID]])
+      sorter.sort(this, 0, length, Ordering[KeyWrapper[ID]])
       val duration = (System.nanoTime() - start) / 1e9
       logDebug(s"Sorting took $duration seconds. (sortId = $sortId)")
     }
@@ -819,9 +818,9 @@ object ALS extends Logging {
       }
       assert(i == dstIdSet.size)
       Sorting.quickSort(sortedDstIds)
-      val dstIdToLocalIndex = new OpenHashMap[ID, Int](sortedDstIds.size)
+      val dstIdToLocalIndex = new OpenHashMap[ID, Int](sortedDstIds.length)
       i = 0
-      while (i < sortedDstIds.size) {
+      while (i < sortedDstIds.length) {
         dstIdToLocalIndex.update(sortedDstIds(i), i)
         i += 1
       }
@@ -843,7 +842,7 @@ object ALS extends Logging {
       val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int])
       var i = 0
       val seen = new Array[Boolean](dstPart.numPartitions)
-      while (i < srcIds.size) {
+      while (i < srcIds.length) {
        var j = dstPtrs(i)
        ju.Arrays.fill(seen, false)
        while (j < dstPtrs(i + 1)) {
@@ -886,26 +885,26 @@ object ALS extends Logging {
       srcEncoder: LocalIndexEncoder,
       implicitPrefs: Boolean = false,
       alpha: Double = 1.0): RDD[(Int, FactorBlock)] = {
-    val numSrcBlocks = srcFactorBlocks.partitions.size
+    val numSrcBlocks = srcFactorBlocks.partitions.length
     val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
     val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
       case (srcBlockId, (srcOutBlock, srcFactors)) =>
         srcOutBlock.view.zipWithIndex.map { case (activeIndices, dstBlockId) =>
           (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx))))
         }
     }
-    val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.size))
+    val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.length))
     dstInBlocks.join(merged).mapValues {
       case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
         val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks)
         srcFactors.foreach { case (srcBlockId, factors) =>
           sortedSrcFactors(srcBlockId) = factors
         }
-        val dstFactors = new Array[Array[Float]](dstIds.size)
+        val dstFactors = new Array[Array[Float]](dstIds.length)
         var j = 0
         val ls = new NormalEquation(rank)
         val solver = new CholeskySolver // TODO: add NNLS solver
-        while (j < dstIds.size) {
+        while (j < dstIds.length) {
           ls.reset()
           if (implicitPrefs) {
             ls.merge(YtY.get)
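For context on the solve loop above: each destination factor is obtained by accumulating a k-by-k Gram matrix A^T A and a k-vector A^T b over the relevant ratings, then solving the regularized system (A^T A plus a lambda term on the diagonal) with the Cholesky solver. A plain-Scala sketch of the accumulation that dspr/daxpy perform (dense here, while NormalEquation packs the upper triangle):

    object NormalEquationSketch extends App {
      val k = 2
      val ata = Array.ofDim[Double](k, k) // A^T A
      val atb = new Array[Double](k)      // A^T b
      def add(a: Array[Double], b: Double): Unit = {
        var i = 0
        while (i < k) {
          var j = 0
          while (j < k) { ata(i)(j) += a(i) * a(j); j += 1 } // rank-1 update
          atb(i) += b * a(i)
          i += 1
        }
      }
      add(Array(1.0, 2.0), 3.0)
      add(Array(0.5, 1.0), 1.0)
      println(ata.map(_.mkString(" ")).mkString("\n"))
    }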

mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -183,8 +183,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
       .add(0, Array(1, 0, 2), Array(0, 1, 4), Array(1.0f, 2.0f, 3.0f))
       .add(1, Array(3, 0), Array(2, 5), Array(4.0f, 5.0f))
       .build()
-    assert(uncompressed.size === 5)
-    val records = Seq.tabulate(uncompressed.size) { i =>
+    assert(uncompressed.length === 5)
+    val records = Seq.tabulate(uncompressed.length) { i =>
       val dstEncodedIndex = uncompressed.dstEncodedIndices(i)
       val dstBlockId = encoder.blockId(dstEncodedIndex)
       val dstLocalIndex = encoder.localIndex(dstEncodedIndex)