Commit dce7055

add save/load for k-means for SPARK-5986

1 parent 55b1b32

File tree

2 files changed: +109 −4 lines

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala

Lines changed: 59 additions & 3 deletions
@@ -17,15 +17,22 @@
 
 package org.apache.spark.mllib.clustering
 
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.util.{Loader, Saveable}
+import org.apache.spark.mllib.util.Loader._
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.SparkContext
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.Vector
 
 /**
  * A clustering model for K-means. Each point belongs to the cluster with the closest center.
  */
-class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {
+class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {
 
   /** Total number of clusters. */
   def k: Int = clusterCenters.length
@@ -58,4 +65,53 @@ class KMeansModel (val clusterCenters: Array[Vector]) extends Serializable {
 
   private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
     clusterCenters.map(new VectorWithNorm(_))
+
+  override def save(sc: SparkContext, path: String): Unit = {
+    KMeansModel.SaveLoadV1_0.save(sc, this, path)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+object KMeansModel extends Loader[KMeansModel] {
+  override def load(sc: SparkContext, path: String): KMeansModel = {
+    KMeansModel.SaveLoadV1_0.load(sc, path)
+  }
+
+  private[clustering]
+  object SaveLoadV1_0 {
+
+    private val thisFormatVersion = "1.0"
+
+    private[clustering]
+    val thisClassName = "org.apache.spark.mllib.clustering.KMeansModel"
+
+    /**
+     * Saves a [[KMeansModel]]: metadata (class name, format version, and k) is saved
+     * as JSON under `metadata/`, and the cluster centers are saved as Parquet under `data/`.
+     */
+    def save(sc: SparkContext, model: KMeansModel, path: String): Unit = {
+      val sqlContext = new SQLContext(sc)
+      val wrapper = new VectorUDT()
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+      val dataRDD = sc.parallelize(model.clusterCenters).map(wrapper.serialize)
+      sqlContext.createDataFrame(dataRDD, wrapper.sqlType).saveAsParquetFile(Loader.dataPath(path))
+    }
+
+    def load(sc: SparkContext, path: String): KMeansModel = {
+      implicit val formats = DefaultFormats
+      val sqlContext = new SQLContext(sc)
+      val wrapper = new VectorUDT()
+      val (className, formatVersion, metadata) = loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+      val k = (metadata \ "k").extract[Int]
+      val centroids = sqlContext.parquetFile(dataPath(path))
+      val localCentroids = centroids.collect()
+      assert(k == localCentroids.size)
+      new KMeansModel(localCentroids.map(wrapper.deserialize))
+    }
+  }
 }
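Not part of the diff: a minimal round-trip sketch of the API this commit adds, for context. It assumes a running SparkContext `sc`; the output path and the tiny dataset are illustrative.

    import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
    import org.apache.spark.mllib.linalg.Vectors

    // Train a small model, persist it, and read it back.
    val points = sc.parallelize(Seq(
      Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0), Vectors.dense(9.0, 9.0)))
    val model = KMeans.train(points, 2, 10)       // k = 2, maxIterations = 10

    model.save(sc, "/tmp/kmeans-model")           // writes metadata/ (JSON) and data/ (Parquet)
    val sameModel = KMeansModel.load(sc, "/tmp/kmeans-model")
    assert(model.k == sameModel.k)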

mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala

Lines changed: 50 additions & 1 deletion
@@ -21,7 +21,8 @@ import scala.util.Random
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.util.Utils
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 
@@ -257,6 +258,54 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
       assert(predicts(0) != predicts(3))
     }
   }
+
+  test("model save/load") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    Array(true, false).foreach { selector =>
+      val model = KMeansSuite.createModel(10, 3, selector)
+      // Save model, load it back, and compare.
+      try {
+        model.save(sc, path)
+        val sameModel = KMeansModel.load(sc, path)
+        KMeansSuite.checkEqual(model, sameModel, selector)
+      } finally {
+        Utils.deleteRecursively(tempDir)
+      }
+    }
+  }
+}
+
+object KMeansSuite extends FunSuite {
+  def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = {
+    val singlePoint = isSparse match {
+      case true =>
+        Vectors.sparse(dim, Array.empty[Int], Array.empty[Double])
+      case _ =>
+        Vectors.dense(Array.fill[Double](dim)(0.0))
+    }
+    new KMeansModel(Array.fill[Vector](k)(singlePoint))
+  }
+
+  def checkEqual(a: KMeansModel, b: KMeansModel, isSparse: Boolean): Unit = {
+    assert(a.k === b.k)
+    isSparse match {
+      case true =>
+        a.clusterCenters.zip(b.clusterCenters).foreach { case (pointA, pointB) =>
+          assert(pointA.asInstanceOf[SparseVector].size === pointB.asInstanceOf[SparseVector].size)
+          assert(
+            pointA.asInstanceOf[SparseVector].indices === pointB.asInstanceOf[SparseVector].indices)
+          assert(
+            pointA.asInstanceOf[SparseVector].values === pointB.asInstanceOf[SparseVector].values)
+        }
+      case _ =>
+        a.clusterCenters.zip(b.clusterCenters).foreach { case (pointA, pointB) =>
+          assert(
+            pointA.asInstanceOf[DenseVector].toArray === pointB.asInstanceOf[DenseVector].toArray)
+        }
+    }
+  }
 }
 
 class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {
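For reference (not in the commit), the json4s DSL used in `SaveLoadV1_0.save` renders the model metadata as a single compact JSON line. A standalone sketch of what ends up under `metadata/`, assuming k = 3:

    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods._

    // Same construction as in save(); prints:
    // {"class":"org.apache.spark.mllib.clustering.KMeansModel","version":"1.0","k":3}
    val metadata = compact(render(
      ("class" -> "org.apache.spark.mllib.clustering.KMeansModel") ~
        ("version" -> "1.0") ~ ("k" -> 3)))
    println(metadata)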
