Commit 62fc43c

implement save/load for MFM
1 parent 4d8d070 commit 62fc43c

2 files changed, +63 -3 lines changed

mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
1 addition & 1 deletion

@@ -18,7 +18,7 @@
 package org.apache.spark.mllib.recommendation
 
 import org.apache.spark.Logging
-import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.ml.recommendation.{ALS => NewALS}
 import org.apache.spark.rdd.RDD
mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
62 additions & 2 deletions

@@ -19,11 +19,15 @@ package org.apache.spark.mllib.recommendation
 
 import java.lang.{Integer => JavaInteger}
 
+import org.apache.hadoop.fs.Path
 import org.jblas.DoubleMatrix
 
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.mllib.recommendation.MatrixFactorizationModel.SaveLoadV1_0
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -41,7 +45,8 @@ import org.apache.spark.storage.StorageLevel
 class MatrixFactorizationModel(
     val rank: Int,
     val userFeatures: RDD[(Int, Array[Double])],
-    val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {
+    val productFeatures: RDD[(Int, Array[Double])])
+  extends Saveable with Serializable with Logging {
 
   require(rank > 0)
   validateFeatures("User", userFeatures)
@@ -125,6 +130,11 @@ class MatrixFactorizationModel(
     recommend(productFeatures.lookup(product).head, userFeatures, num)
       .map(t => Rating(t._1, product, t._2))
 
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    SaveLoadV1_0.save(this, path)
+  }
+
   private def recommend(
       recommendToFeatures: Array[Double],
       recommendableFeatures: RDD[(Int, Array[Double])],
@@ -136,3 +146,53 @@ class MatrixFactorizationModel(
     scored.top(num)(Ordering.by(_._2))
   }
 }
+
+private object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
+
+  import org.apache.spark.mllib.util.Loader._
+
+  private object SaveLoadV1_0 {
+
+    private val thisFormatVersion = "1.0"
+
+    private val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel"
+
+    def save(model: MatrixFactorizationModel, path: String): Unit = {
+      val sc = model.userFeatures.sparkContext
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits.createDataFrame
+      val metadata = (thisClassName, thisFormatVersion, model.rank)
+      val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
+      metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+      model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
+      model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
+    }
+
+    override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+      val sqlContext = new SQLContext(sc)
+      val (className, formatVersion, metadata) = loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+      val rank = metadata.select("rank").map { case Row(r: Int) =>
+        r
+      }.first()
+      val userFeatures = sqlContext.parquetFile(userPath(path))
+        .map { case Row(id: Int, features: Seq[Double]) =>
+          (id, features.toArray)
+        }
+      val productFeatures = sqlContext.parquetFile(productPath(path))
+        .map { case Row(id: Int, features: Seq[Double]) =>
+          (id, features.toArray)
+        }
+      new MatrixFactorizationModel(rank, userFeatures, productFeatures)
+    }
+
+    private def userPath(path: String): String = {
+      new Path(dataPath(path), "user").toUri.toString
+    }
+
+    private def productPath(path: String): String = {
+      new Path(dataPath(path), "product").toUri.toString
+    }
+  }
+}
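
Taken together, the new code hooks MatrixFactorizationModel into MLlib's Saveable/Loader convention: save writes a small JSON metadata record (class name, format version, rank) and stores the user and product feature RDDs as Parquet under the model path, while load reads them back and rebuilds the model. A minimal usage sketch for the spark-shell, assuming a live SparkContext sc; the toy ratings, the rank/iteration values, and the /tmp/mfmModel path are illustrative, and the public MatrixFactorizationModel.load call reflects the eventual user-facing API rather than this intermediate commit, where the companion object is still private:

import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating}

// Train a small model on toy data (any RDD[Rating] works).
val ratings = sc.parallelize(Seq(
  Rating(1, 1, 5.0), Rating(1, 2, 1.0),
  Rating(2, 1, 4.0), Rating(2, 2, 2.0)))
val model = ALS.train(ratings, 10, 5)  // rank = 10, iterations = 5

// Persist: JSON metadata plus user/product features as Parquet under the path
// (the user/product subdirectories come from the userPath/productPath helpers above).
model.save(sc, "/tmp/mfmModel")

// Reload later, possibly in a different application, and use it as before.
val sameModel = MatrixFactorizationModel.load(sc, "/tmp/mfmModel")
println(sameModel.predict(1, 2))

Storing the features as Parquet keeps the on-disk format self-describing, so the loader can rebuild the (Int, Array[Double]) RDDs with a plain Row pattern match, as the load method in the diff does.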
