
Commit 6dd74a0

rewrite some functions and classes

1 parent cd390fd · commit 6dd74a0

2 files changed (+18, −26 lines)


mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala

Lines changed: 10 additions & 11 deletions
@@ -22,9 +22,8 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.util.{Loader, Saveable}
-import org.apache.spark.mllib.util.Loader._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.SQLContext
@@ -79,11 +78,11 @@ object KMeansModel extends Loader[KMeansModel] {
     KMeansModel.SaveLoadV1_0.load(sc, path)
   }
 
-  case class IndexedPoint(id: Int, point: Vector)
+  private case class Cluster(id: Int, point: Vector)
 
-  object IndexedPoint {
-    def apply(r: Row): IndexedPoint = {
-      IndexedPoint(r.getInt(0), r.getAs[Vector](1))
+  private object Cluster {
+    def apply(r: Row): Cluster = {
+      Cluster(r.getInt(0), r.getAs[Vector](1))
     }
   }
 
@@ -102,21 +101,21 @@ object KMeansModel extends Loader[KMeansModel] {
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
       val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
-        IndexedPoint(id, point)
+        Cluster(id, point)
       }.toDF()
       dataRDD.saveAsParquetFile(Loader.dataPath(path))
     }
 
     def load(sc: SparkContext, path: String): KMeansModel = {
       implicit val formats = DefaultFormats
       val sqlContext = new SQLContext(sc)
-      val (className, formatVersion, metadata) = loadMetadata(sc, path)
+      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
       assert(className == thisClassName)
       assert(formatVersion == thisFormatVersion)
       val k = (metadata \ "k").extract[Int]
-      val centriods = sqlContext.parquetFile(dataPath(path))
-      Loader.checkSchema[IndexedPoint](centriods.schema)
-      val localCentriods = centriods.map(IndexedPoint.apply).collect()
+      val centriods = sqlContext.parquetFile(Loader.dataPath(path))
+      Loader.checkSchema[Cluster](centriods.schema)
+      val localCentriods = centriods.map(Cluster.apply).collect()
       assert(k == localCentriods.size)
       new KMeansModel(localCentriods.sortBy(_.id).map(_.point))
     }
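
For context, a minimal usage sketch of the save/load round trip that the rewritten Cluster helpers implement. This is not part of the commit: it assumes a live SparkContext named sc and a scratch directory string named path, and the training parameters are arbitrary.

import org.apache.spark.mllib.clustering.{KMeans, KMeansModel}
import org.apache.spark.mllib.linalg.Vectors

// Train a tiny model (sc and path are assumed to exist in the session).
val data = sc.parallelize(Seq(
  Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
  Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)))
val model = KMeans.train(data, 2, 10)

// save writes JSON metadata plus the centers as Parquet rows of Cluster(id, point);
// load reads the rows back and sorts them by id before rebuilding the model.
model.save(sc, path)
val sameModel = KMeansModel.load(sc, path)
assert(model.k == sameModel.k)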

mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala

Lines changed: 8 additions & 15 deletions
@@ -269,7 +269,7 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
     try {
       model.save(sc, path)
       val sameModel = KMeansModel.load(sc, path)
-      KMeansSuite.checkEqual(model, sameModel, selector)
+      KMeansSuite.checkEqual(model, sameModel)
     } finally {
       Utils.deleteRecursively(tempDir)
     }
@@ -288,22 +288,15 @@ object KMeansSuite extends FunSuite {
     new KMeansModel(Array.fill[Vector](k)(singlePoint))
   }
 
-  def checkEqual(a: KMeansModel, b: KMeansModel, isSparse: Boolean): Unit = {
+  def checkEqual(a: KMeansModel, b: KMeansModel): Unit = {
     assert(a.k === b.k)
-    isSparse match {
-      case true =>
-        a.clusterCenters.zip(b.clusterCenters).foreach { case (pointA, pointB) =>
-          assert(pointA.asInstanceOf[SparseVector].size === pointB.asInstanceOf[SparseVector].size)
-          assert(
-            pointA.asInstanceOf[SparseVector].indices === pointB.asInstanceOf[SparseVector].indices)
-          assert(
-            pointA.asInstanceOf[SparseVector].values === pointB.asInstanceOf[SparseVector].values)
-        }
+    a.clusterCenters.zip(b.clusterCenters).foreach {
+      case (ca: SparseVector, cb: SparseVector) =>
+        assert(ca === cb)
+      case (ca: DenseVector, cb: DenseVector) =>
+        assert(ca === cb)
       case _ =>
-        a.clusterCenters.zip(b.clusterCenters).foreach { case (pointA, pointB) =>
-          assert(
-            pointA.asInstanceOf[DenseVector].toArray === pointB.asInstanceOf[DenseVector].toArray)
-        }
+        throw new AssertionError("checkEqual failed since the two clusters were not identical.\n")
     }
   }
 }
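
To see why the rewritten checkEqual can drop the isSparse flag, here is a standalone sketch of the same idea in plain Scala, with no test framework; the helper name sameCenter is hypothetical. Pattern matching on the runtime vector subtype replaces the boolean plus the asInstanceOf casts, and a pair with mixed representations falls through to the failure case instead of a ClassCastException.

import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}

// Compare two cluster centers only when their representations agree,
// mirroring the match in the new checkEqual.
def sameCenter(a: Vector, b: Vector): Boolean = (a, b) match {
  case (ca: SparseVector, cb: SparseVector) => ca == cb
  case (ca: DenseVector, cb: DenseVector)   => ca == cb
  case _                                    => false  // mixed dense/sparse pair
}

val dense  = Vectors.dense(1.0, 0.0, 2.0)
val sparse = Vectors.sparse(3, Seq((0, 1.0), (2, 2.0)))
assert(sameCenter(dense, dense))
assert(!sameCenter(dense, sparse))  // same values, different representation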
