Skip to content

Commit e36469a

Browse files
committed
add classtags and make it compile
1 parent 7a5aeb3 commit e36469a

File tree

2 files changed

+85
-70
lines changed
  • mllib/src

2 files changed

+85
-70
lines changed

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 75 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ package org.apache.spark.ml.recommendation
2020
import java.{util => ju}
2121

2222
import scala.collection.mutable
23+
import scala.reflect.ClassTag
24+
import scala.util.Sorting
2325

2426
import com.github.fommil.netlib.BLAS.{getInstance => blas}
2527
import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
@@ -29,7 +31,7 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner}
2931
import org.apache.spark.ml.{Estimator, Model}
3032
import org.apache.spark.ml.param._
3133
import org.apache.spark.rdd.RDD
32-
import org.apache.spark.sql.{Column, DataFrame}
34+
import org.apache.spark.sql.DataFrame
3335
import org.apache.spark.sql.api.scala.dsl._
3436
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
3537
import org.apache.spark.util.Utils
@@ -221,7 +223,8 @@ private[recommendation] object ALS extends Logging {
221223
private[recommendation]
222224
case class Rating[@specialized(Int, Long) User, @specialized(Int, Long) Item](
223225
user: User,
224-
item: Item, rating: Float)
226+
item: Item,
227+
rating: Float)
225228

226229
/** Cholesky solver for least square problems. */
227230
private[recommendation] class CholeskySolver {
@@ -333,15 +336,19 @@ private[recommendation] object ALS extends Logging {
333336
/**
334337
* Implementation of the ALS algorithm.
335338
*/
336-
private def train[@specialized(Int, Long) User, @specialized(Int, Long) Item](
339+
private def train[
340+
@specialized(Int, Long) User: ClassTag,
341+
@specialized(Int, Long) Item: ClassTag](
337342
ratings: RDD[Rating[User, Item]],
338343
rank: Int = 10,
339344
numUserBlocks: Int = 10,
340345
numItemBlocks: Int = 10,
341346
maxIter: Int = 10,
342347
regParam: Double = 1.0,
343348
implicitPrefs: Boolean = false,
344-
alpha: Double = 1.0): (RDD[(User, Array[Float])], RDD[(Item, Array[Float])]) = {
349+
alpha: Double = 1.0)(
350+
implicit userOrd: Ordering[User],
351+
itemOrd: Ordering[Item]): (RDD[(User, Array[Float])], RDD[(Item, Array[Float])]) = {
345352
val userPart = new HashPartitioner(numUserBlocks)
346353
val itemPart = new HashPartitioner(numItemBlocks)
347354
val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
@@ -444,8 +451,8 @@ private[recommendation] object ALS extends Logging {
444451
*
445452
* @see [[LocalIndexEncoder]]
446453
*/
447-
private[recommendation] case class InBlock[@specialized(Int, Long) SrcType](
448-
srcIds: Array[SrcType],
454+
private[recommendation] case class InBlock[@specialized(Int, Long) Src: ClassTag](
455+
srcIds: Array[Src],
449456
dstPtrs: Array[Int],
450457
dstEncodedIndices: Array[Int],
451458
ratings: Array[Float]) {
@@ -463,8 +470,8 @@ private[recommendation] object ALS extends Logging {
463470
* @param rank rank
464471
* @return initialized factor blocks
465472
*/
466-
private def initialize[@specialized(Int, Long) SrcType](
467-
inBlocks: RDD[(Int, InBlock[SrcType])],
473+
private def initialize[@specialized(Int, Long) Src](
474+
inBlocks: RDD[(Int, InBlock[Src])],
468475
rank: Int): RDD[(Int, FactorBlock)] = {
469476
// Choose a unit vector uniformly at random from the unit sphere, but from the
470477
// "first quadrant" where all elements are nonnegative. This can be done by choosing
@@ -487,9 +494,9 @@ private[recommendation] object ALS extends Logging {
487494
* A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays.
488495
*/
489496
private[recommendation]
490-
case class RatingBlock[@specialized(Int, Long) SrcType, @specialized(Int, Long) DstType](
491-
srcIds: Array[SrcType],
492-
dstIds: Array[DstType],
497+
case class RatingBlock[@specialized(Int, Long) Src, @specialized(Int, Long) Dst](
498+
srcIds: Array[Src],
499+
dstIds: Array[Dst],
493500
ratings: Array[Float]) {
494501
/** Size of the block. */
495502
val size: Int = srcIds.size
@@ -501,16 +508,16 @@ private[recommendation] object ALS extends Logging {
501508
* Builder for [[RatingBlock]]. [[mutable.ArrayBuilder]] is used to avoid boxing/unboxing.
502509
*/
503510
private[recommendation] class RatingBlockBuilder[
504-
@specialized(Int, Long) SrcType,
505-
@specialized(Int, Long) DstType] extends Serializable {
511+
@specialized(Int, Long) Src: ClassTag,
512+
@specialized(Int, Long) Dst: ClassTag] extends Serializable {
506513

507-
private val srcIds = mutable.ArrayBuilder.make[SrcType]
508-
private val dstIds = mutable.ArrayBuilder.make[DstType]
514+
private val srcIds = mutable.ArrayBuilder.make[Src]
515+
private val dstIds = mutable.ArrayBuilder.make[Dst]
509516
private val ratings = mutable.ArrayBuilder.make[Float]
510517
var size = 0
511518

512519
/** Adds a rating. */
513-
def add(r: Rating[SrcType, DstType]): this.type = {
520+
def add(r: Rating[Src, Dst]): this.type = {
514521
size += 1
515522
srcIds += r.user
516523
dstIds += r.item
@@ -519,7 +526,7 @@ private[recommendation] object ALS extends Logging {
519526
}
520527

521528
/** Merges another [[RatingBlockBuilder]]. */
522-
def merge(other: RatingBlock[SrcType, DstType]): this.type = {
529+
def merge(other: RatingBlock[Src, Dst]): this.type = {
523530
size += other.srcIds.size
524531
srcIds ++= other.srcIds
525532
dstIds ++= other.dstIds
@@ -528,8 +535,8 @@ private[recommendation] object ALS extends Logging {
528535
}
529536

530537
/** Builds a [[RatingBlock]]. */
531-
def build(): RatingBlock[SrcType, DstType] = {
532-
RatingBlock[SrcType, DstType](srcIds.result(), dstIds.result(), ratings.result())
538+
def build(): RatingBlock[Src, Dst] = {
539+
RatingBlock[Src, Dst](srcIds.result(), dstIds.result(), ratings.result())
533540
}
534541
}
535542

@@ -542,7 +549,9 @@ private[recommendation] object ALS extends Logging {
542549
*
543550
* @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock)
544551
*/
545-
private def partitionRatings[@specialized(Int, Long) User, @specialized(Int, Long) Item](
552+
private def partitionRatings[
553+
@specialized(Int, Long) User: ClassTag,
554+
@specialized(Int, Long) Item: ClassTag](
546555
ratings: RDD[Rating[User, Item]],
547556
srcPart: Partitioner,
548557
dstPart: Partitioner): RDD[((Int, Int), RatingBlock[User, Item])] = {
@@ -590,10 +599,11 @@ private[recommendation] object ALS extends Logging {
590599
* Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples.
591600
* @param encoder encoder for dst indices
592601
*/
593-
private[recommendation] class UncompressedInBlockBuilder[@specialized(Int, Long) SrcType](
594-
encoder: LocalIndexEncoder) {
602+
private[recommendation] class UncompressedInBlockBuilder[@specialized(Int, Long) Src: ClassTag](
603+
encoder: LocalIndexEncoder)(
604+
implicit ord: Ordering[Src]) {
595605

596-
private val srcIds = mutable.ArrayBuilder.make[SrcType]
606+
private val srcIds = mutable.ArrayBuilder.make[Src]
597607
private val dstEncodedIndices = mutable.ArrayBuilder.make[Int]
598608
private val ratings = mutable.ArrayBuilder.make[Float]
599609

@@ -607,7 +617,7 @@ private[recommendation] object ALS extends Logging {
607617
*/
608618
def add(
609619
dstBlockId: Int,
610-
srcIds: Array[SrcType],
620+
srcIds: Array[Src],
611621
dstLocalIndices: Array[Int],
612622
ratings: Array[Float]): this.type = {
613623
val sz = srcIds.size
@@ -624,18 +634,19 @@ private[recommendation] object ALS extends Logging {
624634
}
625635

626636
/** Builds a [[UncompressedInBlock]]. */
627-
def build(): UncompressedInBlock[SrcType] = {
637+
def build(): UncompressedInBlock[Src] = {
628638
new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result())
629639
}
630640
}
631641

632642
/**
633643
* A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays.
634644
*/
635-
private[recommendation] class UncompressedInBlock[@specialized(Int, Long) SrcType](
636-
val srcIds: Array[SrcType],
645+
private[recommendation] class UncompressedInBlock[@specialized(Int, Long) Src: ClassTag](
646+
val srcIds: Array[Src],
637647
val dstEncodedIndices: Array[Int],
638-
val ratings: Array[Float]) {
648+
val ratings: Array[Float])(
649+
implicit ord: Ordering[Src]) {
639650

640651
/** Size the of block. */
641652
def size: Int = srcIds.size
@@ -645,11 +656,11 @@ private[recommendation] object ALS extends Logging {
645656
* sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format.
646657
* Sorting is done using Spark's built-in Timsort to avoid generating too many objects.
647658
*/
648-
def compress(): InBlock[SrcType] = {
659+
def compress(): InBlock[Src] = {
649660
val sz = size
650661
assert(sz > 0, "Empty in-link block should not exist.")
651662
sort()
652-
val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[SrcType]
663+
val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[Src]
653664
val dstCountsBuilder = mutable.ArrayBuilder.make[Int]
654665
var preSrcId = srcIds(0)
655666
uniqueSrcIdsBuilder += preSrcId
@@ -689,8 +700,8 @@ private[recommendation] object ALS extends Logging {
689700
val sortId = Utils.random.nextInt()
690701
logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)")
691702
val start = System.nanoTime()
692-
val sorter = new Sorter(new UncompressedInBlockSort[SrcType])
693-
sorter.sort(this, 0, size, Ordering[KeyWrapper[SrcType]])
703+
val sorter = new Sorter(new UncompressedInBlockSort[Src])
704+
sorter.sort(this, 0, size, Ordering[KeyWrapper[Src]])
694705
val duration = (System.nanoTime() - start) / 1e9
695706
logDebug(s"Sorting took $duration seconds. (sortId = $sortId)")
696707
}
@@ -701,13 +712,13 @@ private[recommendation] object ALS extends Logging {
701712
*
702713
* @see [[UncompressedInBlockSort]]
703714
*/
704-
private class KeyWrapper[@specialized(Int, Long) KeyType <: Ordered[KeyType]]
705-
extends Ordered[KeyWrapper[KeyType]] {
715+
private class KeyWrapper[@specialized(Int, Long) KeyType: ClassTag](
716+
implicit ord: Ordering[KeyType]) extends Ordered[KeyWrapper[KeyType]] {
706717

707-
private var key: KeyType = _
718+
var key: KeyType = _
708719

709720
override def compare(that: KeyWrapper[KeyType]): Int = {
710-
key.compare(that.key)
721+
ord.compare(key, that.key)
711722
}
712723

713724
def setKey(key: KeyType): this.type = {
@@ -719,15 +730,16 @@ private[recommendation] object ALS extends Logging {
719730
/**
720731
* [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]].
721732
*/
722-
private class UncompressedInBlockSort[@specialized(Int, Long) SrcType]
723-
extends SortDataFormat[KeyWrapper[SrcType], UncompressedInBlock[SrcType]] {
733+
private class UncompressedInBlockSort[@specialized(Int, Long) Src: ClassTag](
734+
implicit ord: Ordering[Src])
735+
extends SortDataFormat[KeyWrapper[Src], UncompressedInBlock[Src]] {
724736

725-
override def newKey(): KeyWrapper[SrcType] = new KeyWrapper()
737+
override def newKey(): KeyWrapper[Src] = new KeyWrapper()
726738

727739
override def getKey(
728-
data: UncompressedInBlock[SrcType],
740+
data: UncompressedInBlock[Src],
729741
pos: Int,
730-
reuse: KeyWrapper[SrcType]): KeyWrapper[SrcType] = {
742+
reuse: KeyWrapper[Src]): KeyWrapper[Src] = {
731743
if (reuse == null) {
732744
new KeyWrapper().setKey(data.srcIds(pos))
733745
} else {
@@ -736,8 +748,8 @@ private[recommendation] object ALS extends Logging {
736748
}
737749

738750
override def getKey(
739-
data: UncompressedInBlock[SrcType],
740-
pos: Int): KeyWrapper[SrcType] = {
751+
data: UncompressedInBlock[Src],
752+
pos: Int): KeyWrapper[Src] = {
741753
getKey(data, pos, null)
742754
}
743755

@@ -750,32 +762,32 @@ private[recommendation] object ALS extends Logging {
750762
data(pos1) = tmp
751763
}
752764

753-
override def swap(data: UncompressedInBlock[SrcType], pos0: Int, pos1: Int): Unit = {
765+
override def swap(data: UncompressedInBlock[Src], pos0: Int, pos1: Int): Unit = {
754766
swapElements(data.srcIds, pos0, pos1)
755767
swapElements(data.dstEncodedIndices, pos0, pos1)
756768
swapElements(data.ratings, pos0, pos1)
757769
}
758770

759771
override def copyRange(
760-
src: UncompressedInBlock[SrcType],
772+
src: UncompressedInBlock[Src],
761773
srcPos: Int,
762-
dst: UncompressedInBlock[SrcType],
774+
dst: UncompressedInBlock[Src],
763775
dstPos: Int,
764776
length: Int): Unit = {
765777
System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length)
766778
System.arraycopy(src.dstEncodedIndices, srcPos, dst.dstEncodedIndices, dstPos, length)
767779
System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length)
768780
}
769781

770-
override def allocate(length: Int): UncompressedInBlock[SrcType] = {
782+
override def allocate(length: Int): UncompressedInBlock[Src] = {
771783
new UncompressedInBlock(
772-
new Array[SrcType](length), new Array[Int](length), new Array[Float](length))
784+
new Array[Src](length), new Array[Int](length), new Array[Float](length))
773785
}
774786

775787
override def copyElement(
776-
src: UncompressedInBlock[SrcType],
788+
src: UncompressedInBlock[Src],
777789
srcPos: Int,
778-
dst: UncompressedInBlock[SrcType],
790+
dst: UncompressedInBlock[Src],
779791
dstPos: Int): Unit = {
780792
dst.srcIds(dstPos) = src.srcIds(srcPos)
781793
dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos)
@@ -791,20 +803,23 @@ private[recommendation] object ALS extends Logging {
791803
* @param dstPart partitioner for dst IDs
792804
* @return (in-blocks, out-blocks)
793805
*/
794-
private def makeBlocks[@specialized(Int, Long) SrcType, @specialized(Int, Long) DstType](
806+
private def makeBlocks[
807+
@specialized(Int, Long) Src: ClassTag,
808+
@specialized(Int, Long) Dst: ClassTag](
795809
prefix: String,
796-
ratingBlocks: RDD[((Int, Int), RatingBlock[SrcType, DstType])],
810+
ratingBlocks: RDD[((Int, Int), RatingBlock[Src, Dst])],
797811
srcPart: Partitioner,
798812
dstPart: Partitioner)(
799-
implicit ord: Ordering[DstType]): (RDD[(Int, InBlock[SrcType])], RDD[(Int, OutBlock)]) = {
813+
implicit srcOrd: Ordering[Src],
814+
dstOrd: Ordering[Dst]): (RDD[(Int, InBlock[Src])], RDD[(Int, OutBlock)]) = {
800815
val inBlocks = ratingBlocks.map {
801816
case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) =>
802817
// The implementation is a faster version of
803818
// val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap
804819
val start = System.nanoTime()
805-
val dstIdSet = new OpenHashSet[DstType](1 << 20)
820+
val dstIdSet = new OpenHashSet[Dst](1 << 20)
806821
dstIds.foreach(dstIdSet.add)
807-
val sortedDstIds = new Array[DstType](dstIdSet.size)
822+
val sortedDstIds = new Array[Dst](dstIdSet.size)
808823
var i = 0
809824
var pos = dstIdSet.nextPos(0)
810825
while (pos != -1) {
@@ -813,8 +828,8 @@ private[recommendation] object ALS extends Logging {
813828
i += 1
814829
}
815830
assert(i == dstIdSet.size)
816-
ju.Arrays.sort(sortedDstIds, ord)
817-
val dstIdToLocalIndex = new OpenHashMap[DstType, Int](sortedDstIds.size)
831+
Sorting.quickSort(sortedDstIds)
832+
val dstIdToLocalIndex = new OpenHashMap[Dst, Int](sortedDstIds.size)
818833
i = 0
819834
while (i < sortedDstIds.size) {
820835
dstIdToLocalIndex.update(sortedDstIds(i), i)
@@ -827,7 +842,7 @@ private[recommendation] object ALS extends Logging {
827842
}.groupByKey(new HashPartitioner(srcPart.numPartitions))
828843
.mapValues { iter =>
829844
val builder =
830-
new UncompressedInBlockBuilder[SrcType](new LocalIndexEncoder(dstPart.numPartitions))
845+
new UncompressedInBlockBuilder[Src](new LocalIndexEncoder(dstPart.numPartitions))
831846
iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) =>
832847
builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
833848
}
@@ -872,10 +887,10 @@ private[recommendation] object ALS extends Logging {
872887
*
873888
* @return dst factors
874889
*/
875-
private def computeFactors[@specialized(Int, Long) SrcType](
890+
private def computeFactors[@specialized(Int, Long) Src](
876891
srcFactorBlocks: RDD[(Int, FactorBlock)],
877892
srcOutBlocks: RDD[(Int, OutBlock)],
878-
dstInBlocks: RDD[(Int, InBlock[SrcType])],
893+
dstInBlocks: RDD[(Int, InBlock[Src])],
879894
rank: Int,
880895
regParam: Double,
881896
srcEncoder: LocalIndexEncoder,

0 commit comments

Comments
 (0)