@@ -29,10 +29,8 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner}
2929import org .apache .spark .ml .{Estimator , Model }
3030import org .apache .spark .ml .param ._
3131import org .apache .spark .rdd .RDD
32- import org .apache .spark .sql .SchemaRDD
33- import org .apache .spark .sql .catalyst .dsl ._
34- import org .apache .spark .sql .catalyst .expressions .Cast
35- import org .apache .spark .sql .catalyst .plans .LeftOuter
32+ import org .apache .spark .sql .{Column , DataFrame }
33+ import org .apache .spark .sql .dsl ._
3634import org .apache .spark .sql .types .{DoubleType , FloatType , IntegerType , StructField , StructType }
3735import org .apache .spark .util .Utils
3836import org .apache .spark .util .collection .{OpenHashMap , OpenHashSet , SortDataFormat , Sorter }
@@ -112,21 +110,21 @@ class ALSModel private[ml] (
112110
113111 def setPredictionCol (value : String ): this .type = set(predictionCol, value)
114112
115- override def transform (dataset : SchemaRDD , paramMap : ParamMap ): SchemaRDD = {
113+ override def transform (dataset : DataFrame , paramMap : ParamMap ): DataFrame = {
116114 import dataset .sqlContext ._
117115 import org .apache .spark .ml .recommendation .ALSModel .Factor
118116 val map = this .paramMap ++ paramMap
119117 // TODO: Add DSL to simplify the code here.
120118 val instanceTable = s " instance_ $uid"
121119 val userTable = s " user_ $uid"
122120 val itemTable = s " item_ $uid"
123- val instances = dataset.as(Symbol ( instanceTable) )
121+ val instances = dataset.as(instanceTable)
124122 val users = userFactors.map { case (id, features) =>
125123 Factor (id, features)
126- }.as(Symbol ( userTable) )
124+ }.as(userTable)
127125 val items = itemFactors.map { case (id, features) =>
128126 Factor (id, features)
129- }.as(Symbol ( itemTable) )
127+ }.as(itemTable)
130128 val predict : (Seq [Float ], Seq [Float ]) => Float = (userFeatures, itemFeatures) => {
131129 if (userFeatures != null && itemFeatures != null ) {
132130 blas.sdot(k, userFeatures.toArray, 1 , itemFeatures.toArray, 1 )
@@ -135,12 +133,12 @@ class ALSModel private[ml] (
135133 }
136134 }
137135 val inputColumns = dataset.schema.fieldNames
138- val prediction =
139- predict.call( s " $userTable .features " .attr, s " $itemTable .features " .attr) as map (predictionCol)
140- val outputColumns = inputColumns.map(f => s " $instanceTable. $f" .attr as f ) :+ prediction
136+ val prediction = callUDF(predict, $ " $userTable.features " , $ " $itemTable.features " )
137+ .as( map(predictionCol) )
138+ val outputColumns = inputColumns.map(f => $ " $instanceTable.$f" .as(f) ) :+ prediction
141139 instances
142- .join(users, LeftOuter , Some (map(userCol).attr === s " $userTable.id " .attr) )
143- .join(items, LeftOuter , Some (map(itemCol).attr === s " $itemTable.id " .attr) )
140+ .join(users, " left " , Column (map(userCol)) === $ " $userTable.id" )
141+ .join(items, " left " , Column (map(itemCol)) === $ " $itemTable.id" )
144142 .select(outputColumns : _* )
145143 }
146144
@@ -209,14 +207,13 @@ class ALS extends Estimator[ALSModel] with ALSParams {
209207 setMaxIter(20 )
210208 setRegParam(1.0 )
211209
212- override def fit (dataset : SchemaRDD , paramMap : ParamMap ): ALSModel = {
213- import dataset .sqlContext ._
210+ override def fit (dataset : DataFrame , paramMap : ParamMap ): ALSModel = {
214211 val map = this .paramMap ++ paramMap
215- val ratings =
216- dataset .select(map(userCol).attr, map(itemCol).attr, Cast (map(ratingCol).attr, FloatType ))
217- .map { row =>
218- new Rating (row.getInt(0 ), row.getInt(1 ), row.getFloat(2 ))
219- }
212+ val ratings = dataset
213+ .select(Column ( map(userCol)), Column ( map(itemCol)), Column (map(ratingCol)).cast( FloatType ))
214+ .map { row =>
215+ new Rating (row.getInt(0 ), row.getInt(1 ), row.getFloat(2 ))
216+ }
220217 val (userFactors, itemFactors) = ALS .train(ratings, rank = map(rank),
221218 numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
222219 maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
0 commit comments