1717
1818package org .apache .spark .mllib .recommendation
1919
20+ import java .io .IOException
2021import java .lang .{Integer => JavaInteger }
2122
2223import org .apache .hadoop .fs .Path
2324import org .jblas .DoubleMatrix
2425
2526import org .apache .spark .{Logging , SparkContext }
2627import org .apache .spark .api .java .{JavaPairRDD , JavaRDD }
27- import org .apache .spark .mllib .recommendation .MatrixFactorizationModel .SaveLoadV1_0
2828import org .apache .spark .mllib .util .{Loader , Saveable }
2929import org .apache .spark .rdd .RDD
3030import org .apache .spark .sql .{Row , SQLContext }
@@ -130,9 +130,10 @@ class MatrixFactorizationModel(
130130 recommend(productFeatures.lookup(product).head, userFeatures, num)
131131 .map(t => Rating (t._1, product, t._2))
132132
133+ override val formatVersion : String = " 1.0"
133134
134135 override def save (sc : SparkContext , path : String ): Unit = {
135- SaveLoadV1_0 .save(this , path)
136+ MatrixFactorizationModel . SaveLoadV1_0 .save(this , path)
136137 }
137138
138139 private def recommend (
@@ -151,12 +152,30 @@ private object MatrixFactorizationModel extends Loader[MatrixFactorizationModel]
151152
152153 import org .apache .spark .mllib .util .Loader ._
153154
154- private object SaveLoadV1_0 {
155+ override def load (sc : SparkContext , path : String ): MatrixFactorizationModel = {
156+ val (loadedClassName, formatVersion, metadata) = loadMetadata(sc, path)
157+ val classNameV1_0 = SaveLoadV1_0 .thisClassName
158+ (loadedClassName, formatVersion) match {
159+ case (className, " 1.0" ) if className == classNameV1_0 =>
160+ SaveLoadV1_0 .load(sc, path)
161+ case _ =>
162+ throw new IOException (" " +
163+ " MatrixFactorizationModel.load did not recognize model with" +
164+ s " (class: $loadedClassName, version: $formatVersion). Supported: \n " +
165+ s " ( $classNameV1_0, 1.0) " )
166+ }
167+ }
168+
169+ private object SaveLoadV1_0 extends Loader [MatrixFactorizationModel ] {
155170
156171 private val thisFormatVersion = " 1.0"
157172
158- private val thisClassName = " org.apache.spark.mllib.recommendation.MatrixFactorizationModel"
173+ val thisClassName = " org.apache.spark.mllib.recommendation.MatrixFactorizationModel"
159174
175+ /**
176+ * Saves a [[MatrixFactorizationModel ]], where user features are saved under `data/users` and
177+ * product features are saved under `data/products`.
178+ */
160179 def save (model : MatrixFactorizationModel , path : String ): Unit = {
161180 val sc = model.userFeatures.sparkContext
162181 val sqlContext = new SQLContext (sc)
@@ -173,9 +192,7 @@ private object MatrixFactorizationModel extends Loader[MatrixFactorizationModel]
173192 val (className, formatVersion, metadata) = loadMetadata(sc, path)
174193 assert(className == thisClassName)
175194 assert(formatVersion == thisFormatVersion)
176- val rank = metadata.select(" rank" ).map { case Row (r : Int ) =>
177- r
178- }.first()
195+ val rank = metadata.select(" rank" ).first().getInt(0 )
179196 val userFeatures = sqlContext.parquetFile(userPath(path))
180197 .map { case Row (id : Int , features : Seq [Double ]) =>
181198 (id, features.toArray)
@@ -184,7 +201,7 @@ private object MatrixFactorizationModel extends Loader[MatrixFactorizationModel]
184201 .map { case Row (id : Int , features : Seq [Double ]) =>
185202 (id, features.toArray)
186203 }
187- new MatrixFactorizationModel (r , userFeatures, productFeatures)
204+ new MatrixFactorizationModel (rank , userFeatures, productFeatures)
188205 }
189206
190207 private def userPath (path : String ): String = {
0 commit comments