Skip to content

Commit 5c299c5

Browse files
committed
[SPARK-5598][MLLIB] model save/load for ALS
Following apache#4233. jkbradley. Author: Xiangrui Meng <[email protected]>. Closes apache#4422 from mengxr/SPARK-5598 and squashes the following commits: a059394 [Xiangrui Meng] SaveLoad not extending Loader; 14b7ea6 [Xiangrui Meng] address comments; f487cb2 [Xiangrui Meng] add unit tests; 62fc43c [Xiangrui Meng] implement save/load for MFM.
1 parent 804949d commit 5c299c5

File tree

3 files changed

+100
-3
lines changed

3 files changed

+100
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.mllib.recommendation
1919

2020
import org.apache.spark.Logging
21-
import org.apache.spark.annotation.{DeveloperApi, Experimental}
21+
import org.apache.spark.annotation.DeveloperApi
2222
import org.apache.spark.api.java.JavaRDD
2323
import org.apache.spark.ml.recommendation.{ALS => NewALS}
2424
import org.apache.spark.rdd.RDD

mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@
1717

1818
package org.apache.spark.mllib.recommendation
1919

20+
import java.io.IOException
2021
import java.lang.{Integer => JavaInteger}
2122

23+
import org.apache.hadoop.fs.Path
2224
import org.jblas.DoubleMatrix
2325

24-
import org.apache.spark.Logging
26+
import org.apache.spark.{Logging, SparkContext}
2527
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
28+
import org.apache.spark.mllib.util.{Loader, Saveable}
2629
import org.apache.spark.rdd.RDD
30+
import org.apache.spark.sql.{Row, SQLContext}
2731
import org.apache.spark.storage.StorageLevel
2832

2933
/**
@@ -41,7 +45,8 @@ import org.apache.spark.storage.StorageLevel
4145
class MatrixFactorizationModel(
4246
val rank: Int,
4347
val userFeatures: RDD[(Int, Array[Double])],
44-
val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {
48+
val productFeatures: RDD[(Int, Array[Double])])
49+
extends Saveable with Serializable with Logging {
4550

4651
require(rank > 0)
4752
validateFeatures("User", userFeatures)
@@ -125,6 +130,12 @@ class MatrixFactorizationModel(
125130
recommend(productFeatures.lookup(product).head, userFeatures, num)
126131
.map(t => Rating(t._1, product, t._2))
127132

133+
/** Format version string reported by this model; currently "1.0". */
protected override val formatVersion: String = "1.0"

/**
 * Writes this model under `path` by delegating to
 * `MatrixFactorizationModel.SaveLoadV1_0.save`.
 *
 * @param sc   not read by this method; only `this` and `path` are handed to the writer
 * @param path location passed through unchanged to the version 1.0 writer
 */
override def save(sc: SparkContext, path: String): Unit =
  MatrixFactorizationModel.SaveLoadV1_0.save(this, path)
138+
128139
private def recommend(
129140
recommendToFeatures: Array[Double],
130141
recommendableFeatures: RDD[(Int, Array[Double])],
@@ -136,3 +147,70 @@ class MatrixFactorizationModel(
136147
scored.top(num)(Ordering.by(_._2))
137148
}
138149
}
150+
151+
/**
 * Companion object that reads back models written by `MatrixFactorizationModel.save`.
 */
object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {

  import org.apache.spark.mllib.util.Loader._

  /**
   * Loads a model from `path`, dispatching on the class name and format version found in
   * the saved metadata.
   *
   * @param sc   Spark context used to read the metadata and the feature files
   * @param path location the model was saved to
   * @return the reconstructed [[MatrixFactorizationModel]]
   * @throws java.io.IOException if the (class, version) pair in the metadata is not supported
   */
  override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
    // Only the class name and version are needed to pick a reader; the reader itself
    // re-reads the metadata to obtain the rank.
    val (loadedClassName, formatVersion, _) = loadMetadata(sc, path)
    val classNameV1_0 = SaveLoadV1_0.thisClassName
    (loadedClassName, formatVersion) match {
      case (className, "1.0") if className == classNameV1_0 =>
        SaveLoadV1_0.load(sc, path)
      case _ =>
        // The trailing space after "with" keeps the message from reading "with(class: ...".
        throw new IOException("MatrixFactorizationModel.load did not recognize model with " +
          s"(class: $loadedClassName, version: $formatVersion). Supported:\n" +
          s" ($classNameV1_0, 1.0)")
    }
  }

  /** Reader and writer for format version "1.0". */
  private[recommendation]
  object SaveLoadV1_0 {

    private val thisFormatVersion = "1.0"

    private[recommendation]
    val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel"

    /**
     * Saves a [[MatrixFactorizationModel]]. The metadata (class name, format version, rank)
     * is written as JSON text to `metadataPath(path)`; user features are written as Parquet
     * to the `user` child of `dataPath(path)` and product features to its `product` child.
     *
     * @param model model to persist; its Spark context is taken from `model.userFeatures`
     * @param path  base location for the saved model
     */
    def save(model: MatrixFactorizationModel, path: String): Unit = {
      val sc = model.userFeatures.sparkContext
      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits.createDataFrame
      val metadata = (thisClassName, thisFormatVersion, model.rank)
      val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
      metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
      model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
      model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
    }

    /**
     * Loads a model previously written by `save`.
     *
     * Asserts that the stored class name and format version match this reader, reads the
     * rank from the metadata, and rebuilds both feature RDDs from their Parquet files.
     *
     * @param sc   Spark context used to read the files
     * @param path base location the model was saved to
     * @return the reconstructed [[MatrixFactorizationModel]]
     */
    def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
      val sqlContext = new SQLContext(sc)
      val (className, formatVersion, metadata) = loadMetadata(sc, path)
      assert(className == thisClassName)
      assert(formatVersion == thisFormatVersion)
      val rank = metadata.select("rank").first().getInt(0)
      val userFeatures = sqlContext.parquetFile(userPath(path))
        .map { case Row(id: Int, features: Seq[Double]) =>
          (id, features.toArray)
        }
      val productFeatures = sqlContext.parquetFile(productPath(path))
        .map { case Row(id: Int, features: Seq[Double]) =>
          (id, features.toArray)
        }
      new MatrixFactorizationModel(rank, userFeatures, productFeatures)
    }

    /** Location of the user-feature data: the `user` child of `dataPath(path)`. */
    private def userPath(path: String): String = {
      new Path(dataPath(path), "user").toUri.toString
    }

    /** Location of the product-feature data: the `product` child of `dataPath(path)`. */
    private def productPath(path: String): String = {
      new Path(dataPath(path), "product").toUri.toString
    }
  }
}

mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import org.scalatest.FunSuite
2222
import org.apache.spark.mllib.util.MLlibTestSparkContext
2323
import org.apache.spark.mllib.util.TestingUtils._
2424
import org.apache.spark.rdd.RDD
25+
import org.apache.spark.util.Utils
2526

2627
class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {
2728

@@ -53,4 +54,22 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext
5354
new MatrixFactorizationModel(rank, userFeatures, prodFeatures1)
5455
}
5556
}
57+
58+
// Round-trips a model through save and load, then checks that the rank and both
// feature sets come back unchanged.
test("save/load") {
  val original = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
  val scratchDir = Utils.createTempDir()
  val target = scratchDir.toURI.toString

  // Arrays compare by reference, so each feature vector is converted to a Seq
  // before the sets are compared.
  def asSet(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] =
    features.map { case (id, vec) => (id, vec.toSeq) }.collect().toSet

  try {
    original.save(sc, target)
    val restored = MatrixFactorizationModel.load(sc, target)
    assert(restored.rank === rank)
    assert(asSet(restored.userFeatures) === asSet(userFeatures))
    assert(asSet(restored.productFeatures) === asSet(prodFeatures))
  } finally {
    // Always remove the temporary directory, even when an assertion fails.
    Utils.deleteRecursively(scratchDir)
  }
}
5675
}

0 commit comments

Comments
 (0)