1717
1818package org .apache .spark .mllib .recommendation
1919
20+ import java .io .IOException
2021import java .lang .{Integer => JavaInteger }
2122
23+ import org .apache .hadoop .fs .Path
2224import org .jblas .DoubleMatrix
2325
24- import org .apache .spark .Logging
26+ import org .apache .spark .{ Logging , SparkContext }
2527import org .apache .spark .api .java .{JavaPairRDD , JavaRDD }
28+ import org .apache .spark .mllib .util .{Loader , Saveable }
2629import org .apache .spark .rdd .RDD
30+ import org .apache .spark .sql .{Row , SQLContext }
2731import org .apache .spark .storage .StorageLevel
2832
2933/**
@@ -41,7 +45,8 @@ import org.apache.spark.storage.StorageLevel
4145class MatrixFactorizationModel (
4246 val rank : Int ,
4347 val userFeatures : RDD [(Int , Array [Double ])],
44- val productFeatures : RDD [(Int , Array [Double ])]) extends Serializable with Logging {
48+ val productFeatures : RDD [(Int , Array [Double ])])
49+ extends Saveable with Serializable with Logging {
4550
4651 require(rank > 0 )
4752 validateFeatures(" User" , userFeatures)
@@ -125,6 +130,12 @@ class MatrixFactorizationModel(
125130 recommend(productFeatures.lookup(product).head, userFeatures, num)
126131 .map(t => Rating (t._1, product, t._2))
127132
// Format version this model declares for persistence; it has the same value ("1.0")
// as the version handled by MatrixFactorizationModel.SaveLoadV1_0.
override protected val formatVersion: String = "1.0"
134+
/**
 * Persists this model under `path` by delegating to
 * `MatrixFactorizationModel.SaveLoadV1_0.save`.
 *
 * @param sc   Spark context required by the overridden signature; it is not forwarded
 *             to the delegate, which receives only the model and the path.
 * @param path destination path handed to the version 1.0 writer.
 */
override def save(sc: SparkContext, path: String): Unit =
  MatrixFactorizationModel.SaveLoadV1_0.save(model = this, path = path)
138+
128139 private def recommend (
129140 recommendToFeatures : Array [Double ],
130141 recommendableFeatures : RDD [(Int , Array [Double ])],
@@ -136,3 +147,70 @@ class MatrixFactorizationModel(
136147 scored.top(num)(Ordering .by(_._2))
137148 }
138149}
150+
/**
 * Companion object providing persistence support for [[MatrixFactorizationModel]].
 *
 * `load` inspects the stored metadata and dispatches to the reader that matches the
 * recorded (class name, format version) pair. Only format version 1.0 is handled here.
 */
object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {

  import org.apache.spark.mllib.util.Loader._

  /**
   * Loads a model previously written at `path`.
   *
   * @param sc   Spark context used to read the metadata and the feature data.
   * @param path location the model was saved to.
   * @return the reconstructed model.
   * @throws IOException if the stored (class name, format version) pair is not supported.
   */
  override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
    // Only the class name and version are needed to pick a reader; the reader re-reads
    // the metadata itself, so the third component is ignored here.
    val (loadedClassName, formatVersion, _) = loadMetadata(sc, path)
    val classNameV1_0 = SaveLoadV1_0.thisClassName
    (loadedClassName, formatVersion) match {
      case (`classNameV1_0`, "1.0") =>
        SaveLoadV1_0.load(sc, path)
      case _ =>
        throw new IOException(
          "MatrixFactorizationModel.load did not recognize model with " +
            s"(class: $loadedClassName, version: $formatVersion). Supported:\n" +
            s"  ($classNameV1_0, 1.0)")
    }
  }

  /** Reader/writer for format version 1.0. */
  private[recommendation]
  object SaveLoadV1_0 {

    private val thisFormatVersion = "1.0"

    private[recommendation]
    val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel"

    /**
     * Saves a [[MatrixFactorizationModel]].
     *
     * The metadata (class name, format version, rank) is written as JSON text at
     * `metadataPath(path)`. User features are written as Parquet to the `user` child of
     * `dataPath(path)` and product features to its `product` child.
     *
     * @param model model to persist; its `userFeatures` RDD supplies the Spark context.
     * @param path  destination path.
     */
    def save(model: MatrixFactorizationModel, path: String): Unit = {
      val sc = model.userFeatures.sparkContext
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits.createDataFrame
      val metadata = (thisClassName, thisFormatVersion, model.rank)
      // A single partition keeps the one-row metadata in a single output part.
      val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
      metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
      model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
      model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
    }

    /**
     * Loads a model written by [[save]].
     *
     * @param sc   Spark context used for reading.
     * @param path location the model was saved to.
     * @return the model rebuilt from the stored rank and feature tables.
     */
    def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
      val sqlContext = new SQLContext(sc)
      val (className, formatVersion, metadata) = loadMetadata(sc, path)
      assert(className == thisClassName,
        s"Expected class name $thisClassName but found $className")
      assert(formatVersion == thisFormatVersion,
        s"Expected format version $thisFormatVersion but found $formatVersion")
      val rank = metadata.select("rank").first().getInt(0)
      val userFeatures = loadFeatures(sqlContext, userPath(path))
      val productFeatures = loadFeatures(sqlContext, productPath(path))
      new MatrixFactorizationModel(rank, userFeatures, productFeatures)
    }

    /** Reads an (id, features) Parquet table back into an RDD of (id, feature array). */
    private def loadFeatures(
        sqlContext: SQLContext,
        featuresPath: String): RDD[(Int, Array[Double])] = {
      sqlContext.parquetFile(featuresPath).map {
        // The element type of `features` is erased at runtime, so only the Seq shape is
        // actually checked by this pattern.
        case Row(id: Int, features: Seq[Double]) => (id, features.toArray)
      }
    }

    /** Location of the user feature table: the `user` child of `dataPath(path)`. */
    private def userPath(path: String): String = {
      new Path(dataPath(path), "user").toUri.toString
    }

    /** Location of the product feature table: the `product` child of `dataPath(path)`. */
    private def productPath(path: String): String = {
      new Path(dataPath(path), "product").toUri.toString
    }
  }
}
0 commit comments