
Commit f487cb2

add unit tests
1 parent 62fc43c commit f487cb2

2 files changed: 44 additions and 8 deletions

mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala

Lines changed: 25 additions & 8 deletions
@@ -17,14 +17,14 @@
 
 package org.apache.spark.mllib.recommendation
 
+import java.io.IOException
 import java.lang.{Integer => JavaInteger}
 
 import org.apache.hadoop.fs.Path
 import org.jblas.DoubleMatrix
 
 import org.apache.spark.{Logging, SparkContext}
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
-import org.apache.spark.mllib.recommendation.MatrixFactorizationModel.SaveLoadV1_0
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{Row, SQLContext}
@@ -130,9 +130,10 @@ class MatrixFactorizationModel(
     recommend(productFeatures.lookup(product).head, userFeatures, num)
       .map(t => Rating(t._1, product, t._2))
 
+  override val formatVersion: String = "1.0"
 
   override def save(sc: SparkContext, path: String): Unit = {
-    SaveLoadV1_0.save(this, path)
+    MatrixFactorizationModel.SaveLoadV1_0.save(this, path)
   }
 
   private def recommend(
@@ -151,12 +152,30 @@ private object MatrixFactorizationModel extends Loader[MatrixFactorizationModel]
 
   import org.apache.spark.mllib.util.Loader._
 
-  private object SaveLoadV1_0 {
+  override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+    val (loadedClassName, formatVersion, metadata) = loadMetadata(sc, path)
+    val classNameV1_0 = SaveLoadV1_0.thisClassName
+    (loadedClassName, formatVersion) match {
+      case (className, "1.0") if className == classNameV1_0 =>
+        SaveLoadV1_0.load(sc, path)
+      case _ =>
+        throw new IOException("" +
+          "MatrixFactorizationModel.load did not recognize model with" +
+          s"(class: $loadedClassName, version: $formatVersion). Supported:\n" +
+          s"  ($classNameV1_0, 1.0)")
+    }
+  }
+
+  private object SaveLoadV1_0 extends Loader[MatrixFactorizationModel] {
 
     private val thisFormatVersion = "1.0"
 
-    private val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel"
+    val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel"
 
+    /**
+     * Saves a [[MatrixFactorizationModel]], where user features are saved under `data/users` and
+     * product features are saved under `data/products`.
+     */
     def save(model: MatrixFactorizationModel, path: String): Unit = {
       val sc = model.userFeatures.sparkContext
       val sqlContext = new SQLContext(sc)
@@ -173,9 +192,7 @@ private object MatrixFactorizationModel extends Loader[MatrixFactorizationModel]
      val (className, formatVersion, metadata) = loadMetadata(sc, path)
      assert(className == thisClassName)
      assert(formatVersion == thisFormatVersion)
-      val rank = metadata.select("rank").map { case Row(r: Int) =>
-        r
-      }.first()
+      val rank = metadata.select("rank").first().getInt(0)
       val userFeatures = sqlContext.parquetFile(userPath(path))
         .map { case Row(id: Int, features: Seq[Double]) =>
           (id, features.toArray)
@@ -184,7 +201,7 @@ private object MatrixFactorizationModel extends Loader[MatrixFactorizationModel]
         .map { case Row(id: Int, features: Seq[Double]) =>
           (id, features.toArray)
         }
-      new MatrixFactorizationModel(r, userFeatures, productFeatures)
+      new MatrixFactorizationModel(rank, userFeatures, productFeatures)
     }
 
     private def userPath(path: String): String = {

mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala

Lines changed: 19 additions & 0 deletions
@@ -22,6 +22,7 @@ import org.scalatest.FunSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
 
 class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {
 
@@ -53,4 +54,22 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext
       new MatrixFactorizationModel(rank, userFeatures, prodFeatures1)
     }
   }
+
+  test("save/load") {
+    val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = {
+      features.mapValues(_.toSeq).collect().toSet
+    }
+    try {
+      model.save(sc, path)
+      val newModel = MatrixFactorizationModel.load(sc, path)
+      assert(newModel.rank === rank)
+      assert(collect(newModel.userFeatures) === collect(userFeatures))
+      assert(collect(newModel.productFeatures) === collect(prodFeatures))
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }
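
Taken together, the two files let callers persist a MatrixFactorizationModel and read it back through the companion object's load, which dispatches on the saved class name and format version ("1.0"). The sketch below walks the same round trip outside the test harness; the local SparkContext configuration, feature values, and temp-directory path are illustrative assumptions, not part of this commit.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel

// Illustrative round trip of the save/load API added above.
val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("mf-save-load"))
val rank = 2
val userFeatures = sc.parallelize(Seq((0, Array(1.0, 2.0)), (1, Array(3.0, 4.0))))
val prodFeatures = sc.parallelize(Seq((2, Array(5.0, 6.0))))
val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)

val path = java.nio.file.Files.createTempDirectory("mf-model").toUri.toString
model.save(sc, path)                                  // writes metadata plus data/users and data/products
val loaded = MatrixFactorizationModel.load(sc, path)  // checks class name and format version before loading
assert(loaded.rank == rank)
sc.stop()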
