@@ -19,11 +19,15 @@ package org.apache.spark.mllib.recommendation
 
 import java.lang.{Integer => JavaInteger}
 
+import org.apache.hadoop.fs.Path
 import org.jblas.DoubleMatrix
 
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.mllib.recommendation.MatrixFactorizationModel.SaveLoadV1_0
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -41,7 +45,8 @@ import org.apache.spark.storage.StorageLevel
 class MatrixFactorizationModel(
     val rank: Int,
     val userFeatures: RDD[(Int, Array[Double])],
-    val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {
+    val productFeatures: RDD[(Int, Array[Double])])
+  extends Saveable with Serializable with Logging {
 
   require(rank > 0)
   validateFeatures("User", userFeatures)
@@ -125,6 +130,11 @@ class MatrixFactorizationModel(
     recommend(productFeatures.lookup(product).head, userFeatures, num)
       .map(t => Rating(t._1, product, t._2))
 
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    SaveLoadV1_0.save(this, path)
+  }
+
   private def recommend(
       recommendToFeatures: Array[Double],
       recommendableFeatures: RDD[(Int, Array[Double])],
@@ -136,3 +146,57 @@ class MatrixFactorizationModel(
     scored.top(num)(Ordering.by(_._2))
   }
 }
+
+private object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
+
+  import org.apache.spark.mllib.util.Loader._
+
+  private[recommendation] object SaveLoadV1_0 {
+
+    private val thisFormatVersion = "1.0"
+
+    private val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel"
+
+    def save(model: MatrixFactorizationModel, path: String): Unit = {
+      val sc = model.userFeatures.sparkContext
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits.createDataFrame
+      val metadata = (thisClassName, thisFormatVersion, model.rank)
+      val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
+      metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+      model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
+      model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
+    }
+
+    def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+      val sqlContext = new SQLContext(sc)
+      val (className, formatVersion, metadata) = loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+      val rank = metadata.select("rank").map { case Row(r: Int) =>
+        r
+      }.first()
+      val userFeatures = sqlContext.parquetFile(userPath(path))
+        .map { case Row(id: Int, features: Seq[Double]) =>
+          (id, features.toArray)
+        }
+      val productFeatures = sqlContext.parquetFile(productPath(path))
+        .map { case Row(id: Int, features: Seq[Double]) =>
+          (id, features.toArray)
+        }
+      new MatrixFactorizationModel(rank, userFeatures, productFeatures)
+    }
+
+    private def userPath(path: String): String = {
+      new Path(dataPath(path), "user").toUri.toString
+    }
+
+    private def productPath(path: String): String = {
+      new Path(dataPath(path), "product").toUri.toString
+    }
+  }
+
+  override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+    SaveLoadV1_0.load(sc, path)
+  }
+}
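
For context, here is a minimal round-trip sketch of the API this patch adds. It is not part of the change: the app name, local master, toy ratings, and the /tmp/matrix-factorization-model path are illustrative assumptions, and calling load from user code assumes the loader object is publicly reachable (in this revision it is package-private).

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating}

object MFSaveLoadExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("MFSaveLoadExample").setMaster("local[2]"))

    // A tiny made-up ratings set, just enough to fit a low-rank model.
    val ratings = sc.parallelize(Seq(
      Rating(1, 10, 4.0), Rating(1, 20, 1.0),
      Rating(2, 10, 5.0), Rating(2, 30, 2.0),
      Rating(3, 20, 3.0), Rating(3, 30, 4.0)))

    val model = ALS.train(ratings, 2 /* rank */, 5 /* iterations */, 0.01 /* lambda */)

    // save() writes one JSON metadata record (class, version, rank) plus the
    // user and product factor RDDs as Parquet under the data directory of `path`.
    val path = "/tmp/matrix-factorization-model"  // hypothetical scratch location
    model.save(sc, path)

    // load() reads the same directory back into an equivalent model.
    val sameModel = MatrixFactorizationModel.load(sc, path)
    println(s"prediction for (user=1, product=10): ${sameModel.predict(1, 10)}")

    sc.stop()
  }
}

The on-disk layout follows the userPath/productPath helpers above: the factors land in the user and product subdirectories of the data path, with the single-row metadata JSON written alongside via metadataPath.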