@@ -20,6 +20,8 @@ package org.apache.spark.ml.recommendation
2020import java .{util => ju }
2121
2222import scala .collection .mutable
23+ import scala .reflect .ClassTag
24+ import scala .util .Sorting
2325
2426import com .github .fommil .netlib .BLAS .{getInstance => blas }
2527import com .github .fommil .netlib .LAPACK .{getInstance => lapack }
@@ -29,7 +31,7 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner}
2931import org .apache .spark .ml .{Estimator , Model }
3032import org .apache .spark .ml .param ._
3133import org .apache .spark .rdd .RDD
32- import org .apache .spark .sql .{ Column , DataFrame }
34+ import org .apache .spark .sql .DataFrame
3335import org .apache .spark .sql .api .scala .dsl ._
3436import org .apache .spark .sql .types .{DoubleType , FloatType , IntegerType , StructField , StructType }
3537import org .apache .spark .util .Utils
@@ -221,7 +223,8 @@ private[recommendation] object ALS extends Logging {
221223 private [recommendation]
222224 case class Rating [@ specialized(Int , Long ) User , @ specialized(Int , Long ) Item ](
223225 user : User ,
224- item : Item , rating : Float )
226+ item : Item ,
227+ rating : Float )
225228
226229 /** Cholesky solver for least square problems. */
227230 private [recommendation] class CholeskySolver {
@@ -333,15 +336,19 @@ private[recommendation] object ALS extends Logging {
333336 /**
334337 * Implementation of the ALS algorithm.
335338 */
336- private def train [@ specialized(Int , Long ) User , @ specialized(Int , Long ) Item ](
339+ private def train [
340+ @ specialized(Int , Long ) User : ClassTag ,
341+ @ specialized(Int , Long ) Item : ClassTag ](
337342 ratings : RDD [Rating [User , Item ]],
338343 rank : Int = 10 ,
339344 numUserBlocks : Int = 10 ,
340345 numItemBlocks : Int = 10 ,
341346 maxIter : Int = 10 ,
342347 regParam : Double = 1.0 ,
343348 implicitPrefs : Boolean = false ,
344- alpha : Double = 1.0 ): (RDD [(User , Array [Float ])], RDD [(Item , Array [Float ])]) = {
349+ alpha : Double = 1.0 )(
350+ implicit userOrd : Ordering [User ],
351+ itemOrd : Ordering [Item ]): (RDD [(User , Array [Float ])], RDD [(Item , Array [Float ])]) = {
345352 val userPart = new HashPartitioner (numUserBlocks)
346353 val itemPart = new HashPartitioner (numItemBlocks)
347354 val userLocalIndexEncoder = new LocalIndexEncoder (userPart.numPartitions)
@@ -444,8 +451,8 @@ private[recommendation] object ALS extends Logging {
444451 *
445452 * @see [[LocalIndexEncoder ]]
446453 */
447- private [recommendation] case class InBlock [@ specialized(Int , Long ) SrcType ](
448- srcIds : Array [SrcType ],
454+ private [recommendation] case class InBlock [@ specialized(Int , Long ) Src : ClassTag ](
455+ srcIds : Array [Src ],
449456 dstPtrs : Array [Int ],
450457 dstEncodedIndices : Array [Int ],
451458 ratings : Array [Float ]) {
@@ -463,8 +470,8 @@ private[recommendation] object ALS extends Logging {
463470 * @param rank rank
464471 * @return initialized factor blocks
465472 */
466- private def initialize [@ specialized(Int , Long ) SrcType ](
467- inBlocks : RDD [(Int , InBlock [SrcType ])],
473+ private def initialize [@ specialized(Int , Long ) Src ](
474+ inBlocks : RDD [(Int , InBlock [Src ])],
468475 rank : Int ): RDD [(Int , FactorBlock )] = {
469476 // Choose a unit vector uniformly at random from the unit sphere, but from the
470477 // "first quadrant" where all elements are nonnegative. This can be done by choosing
@@ -487,9 +494,9 @@ private[recommendation] object ALS extends Logging {
487494 * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays.
488495 */
489496 private [recommendation]
490- case class RatingBlock [@ specialized(Int , Long ) SrcType , @ specialized(Int , Long ) DstType ](
491- srcIds : Array [SrcType ],
492- dstIds : Array [DstType ],
497+ case class RatingBlock [@ specialized(Int , Long ) Src , @ specialized(Int , Long ) Dst ](
498+ srcIds : Array [Src ],
499+ dstIds : Array [Dst ],
493500 ratings : Array [Float ]) {
494501 /** Size of the block. */
495502 val size : Int = srcIds.size
@@ -501,16 +508,16 @@ private[recommendation] object ALS extends Logging {
501508 * Builder for [[RatingBlock ]]. [[mutable.ArrayBuilder ]] is used to avoid boxing/unboxing.
502509 */
503510 private [recommendation] class RatingBlockBuilder [
504- @ specialized(Int , Long ) SrcType ,
505- @ specialized(Int , Long ) DstType ] extends Serializable {
511+ @ specialized(Int , Long ) Src : ClassTag ,
512+ @ specialized(Int , Long ) Dst : ClassTag ] extends Serializable {
506513
507- private val srcIds = mutable.ArrayBuilder .make[SrcType ]
508- private val dstIds = mutable.ArrayBuilder .make[DstType ]
514+ private val srcIds = mutable.ArrayBuilder .make[Src ]
515+ private val dstIds = mutable.ArrayBuilder .make[Dst ]
509516 private val ratings = mutable.ArrayBuilder .make[Float ]
510517 var size = 0
511518
512519 /** Adds a rating. */
513- def add (r : Rating [SrcType , DstType ]): this .type = {
520+ def add (r : Rating [Src , Dst ]): this .type = {
514521 size += 1
515522 srcIds += r.user
516523 dstIds += r.item
@@ -519,7 +526,7 @@ private[recommendation] object ALS extends Logging {
519526 }
520527
521528 /** Merges another [[RatingBlockBuilder ]]. */
522- def merge (other : RatingBlock [SrcType , DstType ]): this .type = {
529+ def merge (other : RatingBlock [Src , Dst ]): this .type = {
523530 size += other.srcIds.size
524531 srcIds ++= other.srcIds
525532 dstIds ++= other.dstIds
@@ -528,8 +535,8 @@ private[recommendation] object ALS extends Logging {
528535 }
529536
530537 /** Builds a [[RatingBlock ]]. */
531- def build (): RatingBlock [SrcType , DstType ] = {
532- RatingBlock [SrcType , DstType ](srcIds.result(), dstIds.result(), ratings.result())
538+ def build (): RatingBlock [Src , Dst ] = {
539+ RatingBlock [Src , Dst ](srcIds.result(), dstIds.result(), ratings.result())
533540 }
534541 }
535542
@@ -542,7 +549,9 @@ private[recommendation] object ALS extends Logging {
542549 *
543550 * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock)
544551 */
545- private def partitionRatings [@ specialized(Int , Long ) User , @ specialized(Int , Long ) Item ](
552+ private def partitionRatings [
553+ @ specialized(Int , Long ) User : ClassTag ,
554+ @ specialized(Int , Long ) Item : ClassTag ](
546555 ratings : RDD [Rating [User , Item ]],
547556 srcPart : Partitioner ,
548557 dstPart : Partitioner ): RDD [((Int , Int ), RatingBlock [User , Item ])] = {
@@ -590,10 +599,11 @@ private[recommendation] object ALS extends Logging {
590599 * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples.
591600 * @param encoder encoder for dst indices
592601 */
593- private [recommendation] class UncompressedInBlockBuilder [@ specialized(Int , Long ) SrcType ](
594- encoder : LocalIndexEncoder ) {
602+ private [recommendation] class UncompressedInBlockBuilder [@ specialized(Int , Long ) Src : ClassTag ](
603+ encoder : LocalIndexEncoder )(
604+ implicit ord : Ordering [Src ]) {
595605
596- private val srcIds = mutable.ArrayBuilder .make[SrcType ]
606+ private val srcIds = mutable.ArrayBuilder .make[Src ]
597607 private val dstEncodedIndices = mutable.ArrayBuilder .make[Int ]
598608 private val ratings = mutable.ArrayBuilder .make[Float ]
599609
@@ -607,7 +617,7 @@ private[recommendation] object ALS extends Logging {
607617 */
608618 def add (
609619 dstBlockId : Int ,
610- srcIds : Array [SrcType ],
620+ srcIds : Array [Src ],
611621 dstLocalIndices : Array [Int ],
612622 ratings : Array [Float ]): this .type = {
613623 val sz = srcIds.size
@@ -624,18 +634,19 @@ private[recommendation] object ALS extends Logging {
624634 }
625635
626636 /** Builds a [[UncompressedInBlock ]]. */
627- def build (): UncompressedInBlock [SrcType ] = {
637+ def build (): UncompressedInBlock [Src ] = {
628638 new UncompressedInBlock (srcIds.result(), dstEncodedIndices.result(), ratings.result())
629639 }
630640 }
631641
632642 /**
633643 * A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays.
634644 */
635- private [recommendation] class UncompressedInBlock [@ specialized(Int , Long ) SrcType ](
636- val srcIds : Array [SrcType ],
645+ private [recommendation] class UncompressedInBlock [@ specialized(Int , Long ) Src : ClassTag ](
646+ val srcIds : Array [Src ],
637647 val dstEncodedIndices : Array [Int ],
638- val ratings : Array [Float ]) {
648+ val ratings : Array [Float ])(
649+ implicit ord : Ordering [Src ]) {
639650
640651 /** Size the of block. */
641652 def size : Int = srcIds.size
@@ -645,11 +656,11 @@ private[recommendation] object ALS extends Logging {
645656 * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format.
646657 * Sorting is done using Spark's built-in Timsort to avoid generating too many objects.
647658 */
648- def compress (): InBlock [SrcType ] = {
659+ def compress (): InBlock [Src ] = {
649660 val sz = size
650661 assert(sz > 0 , " Empty in-link block should not exist." )
651662 sort()
652- val uniqueSrcIdsBuilder = mutable.ArrayBuilder .make[SrcType ]
663+ val uniqueSrcIdsBuilder = mutable.ArrayBuilder .make[Src ]
653664 val dstCountsBuilder = mutable.ArrayBuilder .make[Int ]
654665 var preSrcId = srcIds(0 )
655666 uniqueSrcIdsBuilder += preSrcId
@@ -689,8 +700,8 @@ private[recommendation] object ALS extends Logging {
689700 val sortId = Utils .random.nextInt()
690701 logDebug(s " Start sorting an uncompressed in-block of size $sz. (sortId = $sortId) " )
691702 val start = System .nanoTime()
692- val sorter = new Sorter (new UncompressedInBlockSort [SrcType ])
693- sorter.sort(this , 0 , size, Ordering [KeyWrapper [SrcType ]])
703+ val sorter = new Sorter (new UncompressedInBlockSort [Src ])
704+ sorter.sort(this , 0 , size, Ordering [KeyWrapper [Src ]])
694705 val duration = (System .nanoTime() - start) / 1e9
695706 logDebug(s " Sorting took $duration seconds. (sortId = $sortId) " )
696707 }
@@ -701,13 +712,13 @@ private[recommendation] object ALS extends Logging {
701712 *
702713 * @see [[UncompressedInBlockSort ]]
703714 */
704- private class KeyWrapper [@ specialized(Int , Long ) KeyType <: Ordered [ KeyType ]]
705- extends Ordered [KeyWrapper [KeyType ]] {
715+ private class KeyWrapper [@ specialized(Int , Long ) KeyType : ClassTag ](
716+ implicit ord : Ordering [ KeyType ]) extends Ordered [KeyWrapper [KeyType ]] {
706717
707- private var key : KeyType = _
718+ var key : KeyType = _
708719
709720 override def compare (that : KeyWrapper [KeyType ]): Int = {
710- key .compare(that.key)
721+ ord .compare(key, that.key)
711722 }
712723
713724 def setKey (key : KeyType ): this .type = {
@@ -719,15 +730,16 @@ private[recommendation] object ALS extends Logging {
719730 /**
720731 * [[SortDataFormat ]] of [[UncompressedInBlock ]] used by [[Sorter ]].
721732 */
722- private class UncompressedInBlockSort [@ specialized(Int , Long ) SrcType ]
723- extends SortDataFormat [KeyWrapper [SrcType ], UncompressedInBlock [SrcType ]] {
733+ private class UncompressedInBlockSort [@ specialized(Int , Long ) Src : ClassTag ](
734+ implicit ord : Ordering [Src ])
735+ extends SortDataFormat [KeyWrapper [Src ], UncompressedInBlock [Src ]] {
724736
725- override def newKey (): KeyWrapper [SrcType ] = new KeyWrapper ()
737+ override def newKey (): KeyWrapper [Src ] = new KeyWrapper ()
726738
727739 override def getKey (
728- data : UncompressedInBlock [SrcType ],
740+ data : UncompressedInBlock [Src ],
729741 pos : Int ,
730- reuse : KeyWrapper [SrcType ]): KeyWrapper [SrcType ] = {
742+ reuse : KeyWrapper [Src ]): KeyWrapper [Src ] = {
731743 if (reuse == null ) {
732744 new KeyWrapper ().setKey(data.srcIds(pos))
733745 } else {
@@ -736,8 +748,8 @@ private[recommendation] object ALS extends Logging {
736748 }
737749
738750 override def getKey (
739- data : UncompressedInBlock [SrcType ],
740- pos : Int ): KeyWrapper [SrcType ] = {
751+ data : UncompressedInBlock [Src ],
752+ pos : Int ): KeyWrapper [Src ] = {
741753 getKey(data, pos, null )
742754 }
743755
@@ -750,32 +762,32 @@ private[recommendation] object ALS extends Logging {
750762 data(pos1) = tmp
751763 }
752764
753- override def swap (data : UncompressedInBlock [SrcType ], pos0 : Int , pos1 : Int ): Unit = {
765+ override def swap (data : UncompressedInBlock [Src ], pos0 : Int , pos1 : Int ): Unit = {
754766 swapElements(data.srcIds, pos0, pos1)
755767 swapElements(data.dstEncodedIndices, pos0, pos1)
756768 swapElements(data.ratings, pos0, pos1)
757769 }
758770
759771 override def copyRange (
760- src : UncompressedInBlock [SrcType ],
772+ src : UncompressedInBlock [Src ],
761773 srcPos : Int ,
762- dst : UncompressedInBlock [SrcType ],
774+ dst : UncompressedInBlock [Src ],
763775 dstPos : Int ,
764776 length : Int ): Unit = {
765777 System .arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length)
766778 System .arraycopy(src.dstEncodedIndices, srcPos, dst.dstEncodedIndices, dstPos, length)
767779 System .arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length)
768780 }
769781
770- override def allocate (length : Int ): UncompressedInBlock [SrcType ] = {
782+ override def allocate (length : Int ): UncompressedInBlock [Src ] = {
771783 new UncompressedInBlock (
772- new Array [SrcType ](length), new Array [Int ](length), new Array [Float ](length))
784+ new Array [Src ](length), new Array [Int ](length), new Array [Float ](length))
773785 }
774786
775787 override def copyElement (
776- src : UncompressedInBlock [SrcType ],
788+ src : UncompressedInBlock [Src ],
777789 srcPos : Int ,
778- dst : UncompressedInBlock [SrcType ],
790+ dst : UncompressedInBlock [Src ],
779791 dstPos : Int ): Unit = {
780792 dst.srcIds(dstPos) = src.srcIds(srcPos)
781793 dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos)
@@ -791,20 +803,23 @@ private[recommendation] object ALS extends Logging {
791803 * @param dstPart partitioner for dst IDs
792804 * @return (in-blocks, out-blocks)
793805 */
794- private def makeBlocks [@ specialized(Int , Long ) SrcType , @ specialized(Int , Long ) DstType ](
806+ private def makeBlocks [
807+ @ specialized(Int , Long ) Src : ClassTag ,
808+ @ specialized(Int , Long ) Dst : ClassTag ](
795809 prefix : String ,
796- ratingBlocks : RDD [((Int , Int ), RatingBlock [SrcType , DstType ])],
810+ ratingBlocks : RDD [((Int , Int ), RatingBlock [Src , Dst ])],
797811 srcPart : Partitioner ,
798812 dstPart : Partitioner )(
799- implicit ord : Ordering [DstType ]): (RDD [(Int , InBlock [SrcType ])], RDD [(Int , OutBlock )]) = {
813+ implicit srcOrd : Ordering [Src ],
814+ dstOrd : Ordering [Dst ]): (RDD [(Int , InBlock [Src ])], RDD [(Int , OutBlock )]) = {
800815 val inBlocks = ratingBlocks.map {
801816 case ((srcBlockId, dstBlockId), RatingBlock (srcIds, dstIds, ratings)) =>
802817 // The implementation is a faster version of
803818 // val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap
804819 val start = System .nanoTime()
805- val dstIdSet = new OpenHashSet [DstType ](1 << 20 )
820+ val dstIdSet = new OpenHashSet [Dst ](1 << 20 )
806821 dstIds.foreach(dstIdSet.add)
807- val sortedDstIds = new Array [DstType ](dstIdSet.size)
822+ val sortedDstIds = new Array [Dst ](dstIdSet.size)
808823 var i = 0
809824 var pos = dstIdSet.nextPos(0 )
810825 while (pos != - 1 ) {
@@ -813,8 +828,8 @@ private[recommendation] object ALS extends Logging {
813828 i += 1
814829 }
815830 assert(i == dstIdSet.size)
816- ju. Arrays .sort (sortedDstIds, ord )
817- val dstIdToLocalIndex = new OpenHashMap [DstType , Int ](sortedDstIds.size)
831+ Sorting .quickSort (sortedDstIds)
832+ val dstIdToLocalIndex = new OpenHashMap [Dst , Int ](sortedDstIds.size)
818833 i = 0
819834 while (i < sortedDstIds.size) {
820835 dstIdToLocalIndex.update(sortedDstIds(i), i)
@@ -827,7 +842,7 @@ private[recommendation] object ALS extends Logging {
827842 }.groupByKey(new HashPartitioner (srcPart.numPartitions))
828843 .mapValues { iter =>
829844 val builder =
830- new UncompressedInBlockBuilder [SrcType ](new LocalIndexEncoder (dstPart.numPartitions))
845+ new UncompressedInBlockBuilder [Src ](new LocalIndexEncoder (dstPart.numPartitions))
831846 iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) =>
832847 builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
833848 }
@@ -872,10 +887,10 @@ private[recommendation] object ALS extends Logging {
872887 *
873888 * @return dst factors
874889 */
875- private def computeFactors [@ specialized(Int , Long ) SrcType ](
890+ private def computeFactors [@ specialized(Int , Long ) Src ](
876891 srcFactorBlocks : RDD [(Int , FactorBlock )],
877892 srcOutBlocks : RDD [(Int , OutBlock )],
878- dstInBlocks : RDD [(Int , InBlock [SrcType ])],
893+ dstInBlocks : RDD [(Int , InBlock [Src ])],
879894 rank : Int ,
880895 regParam : Double ,
881896 srcEncoder : LocalIndexEncoder ,
0 commit comments