
Commit 997d2e0

[SPARK-5207] [MLLIB] [WIP] make withMean and withStd public, add constructor which uses defaults, un-refactor test class
1 parent 64408a4

2 files changed: 113 additions, 74 deletions

mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala

Lines changed: 6 additions & 2 deletions

@@ -70,11 +70,15 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging {
 class StandardScalerModel (
     val mean: Vector,
     val variance: Vector,
-    private var withMean: Boolean = false,
-    private var withStd: Boolean = true) extends VectorTransformer {
+    var withMean: Boolean,
+    var withStd: Boolean) extends VectorTransformer {

   require(mean.size == variance.size)

+  def this(mean: Vector, variance: Vector) {
+    this(mean, variance, false, true)
+  }
+
   def setWithMean(withMean: Boolean): this.type = {
     this.withMean = withMean
     this
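With this change, withMean and withStd become public vars on StandardScalerModel, and the new two-argument auxiliary constructor supplies the previous defaults (withMean = false, withStd = true). A minimal sketch of constructing a model directly from known statistics; the mean and variance values here are made up for illustration:

import org.apache.spark.mllib.feature.StandardScalerModel
import org.apache.spark.mllib.linalg.Vectors

// Hypothetical statistics, for illustration only.
val mean = Vectors.dense(1.0, 2.0, 3.0)
val variance = Vectors.dense(4.0, 5.0, 6.0)

// New auxiliary constructor: equivalent to withMean = false, withStd = true.
val defaultModel = new StandardScalerModel(mean, variance)

// Primary constructor with both flags spelled out.
val centeringModel = new StandardScalerModel(mean, variance, true, false)

// The flags remain settable after construction.
defaultModel.setWithMean(true)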

mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala

Lines changed: 107 additions & 72 deletions

@@ -54,45 +54,37 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     Vectors.dense(0.0, 1.9, 0.0)
   )

-  def validateConstant(data1: Array[Vector], data2: Array[Vector], data3: Array[Vector]) {
-    assert(data1.forall(_.toArray.forall(_ == 0.0)),
-      "The variance is zero, so the transformed result should be 0.0")
-    assert(data2.forall(_.toArray.forall(_ == 0.0)),
-      "The variance is zero, so the transformed result should be 0.0")
-    assert(data3.forall(_.toArray.forall(_ == 0.0)),
-      "The variance is zero, so the transformed result should be 0.0")
+  private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = {
+    data.treeAggregate(new MultivariateOnlineSummarizer)(
+      (aggregator, data) => aggregator.add(data),
+      (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
   }

-  def validateSparse(data: Array[Vector], dataRDD: RDD[Vector]) = {
+  test("Standardization with dense input when means and variances are provided") {

-    val summary = computeSummary(dataRDD)
+    val dataRDD = sc.parallelize(denseData, 3)

-    assert((sparseData, data, dataRDD.collect()).zipped.forall {
-      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
-      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
-      case _ => false
-    }, "The vector type should be preserved after standardization.")
+    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+    val standardizer2 = new StandardScaler()
+    val standardizer3 = new StandardScaler(withMean = true, withStd = false)

-    assert((data, dataRDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)

-    assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
-    assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+    val equivalentModel1 = new StandardScalerModel(model1.mean, model1.variance, true, true)
+    val equivalentModel2 = new StandardScalerModel(model2.mean, model2.variance)
+    val equivalentModel3 = new StandardScalerModel(model3.mean, model3.variance, true, false)

-    assert(data(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
-    assert(data(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
-  }
+    val data1 = denseData.map(equivalentModel1.transform)
+    val data2 = denseData.map(equivalentModel2.transform)
+    val data3 = denseData.map(equivalentModel3.transform)

-  def validateDense(
-      data1: Array[Vector],
-      data2: Array[Vector],
-      data3: Array[Vector],
-      refDataRDD: RDD[Vector],
-      data1RDD: RDD[Vector],
-      data2RDD: RDD[Vector],
-      data3RDD: RDD[Vector]
-    ) = {
-
-    val refSummary = computeSummary(refDataRDD)
+    val data1RDD = equivalentModel1.transform(dataRDD)
+    val data2RDD = equivalentModel2.transform(dataRDD)
+    val data3RDD = equivalentModel3.transform(dataRDD)
+
+    val summary = computeSummary(dataRDD)
     val summary1 = computeSummary(data1RDD)
     val summary2 = computeSummary(data2RDD)
     val summary3 = computeSummary(data3RDD)

@@ -126,7 +118,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)

     assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
-    assert(summary3.variance ~== refSummary.variance absTol 1E-5)
+    assert(summary3.variance ~== summary.variance absTol 1E-5)

     assert(data1(0) ~== Vectors.dense(-1.31527964, 1.023470449, 0.11637768424) absTol 1E-5)
     assert(data1(3) ~== Vectors.dense(1.637735298, 0.156973995, 1.32247368462) absTol 1E-5)

@@ -136,13 +128,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     assert(data3(5) ~== Vectors.dense(-0.58333333, 2.316666666, 0.18333333333) absTol 1E-5)
   }

-  private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = {
-    data.treeAggregate(new MultivariateOnlineSummarizer)(
-      (aggregator, data) => aggregator.add(data),
-      (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
-  }
-
-  test("Standardization with dense input when means and variances are provided") {
+  test("Standardization with dense input") {

     val dataRDD = sc.parallelize(denseData, 3)

@@ -154,42 +140,56 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = standardizer2.fit(dataRDD)
     val model3 = standardizer3.fit(dataRDD)

-    val equivalentModel1 = new StandardScalerModel(model1.mean, model1.variance, true, true)
-    val equivalentModel2 = new StandardScalerModel(model2.mean, model2.variance)
-    val equivalentModel3 = new StandardScalerModel(model3.mean, model3.variance, true, false)
+    val data1 = denseData.map(model1.transform)
+    val data2 = denseData.map(model2.transform)
+    val data3 = denseData.map(model3.transform)

-    val data1 = denseData.map(equivalentModel1.transform)
-    val data2 = denseData.map(equivalentModel2.transform)
-    val data3 = denseData.map(equivalentModel3.transform)
+    val data1RDD = model1.transform(dataRDD)
+    val data2RDD = model2.transform(dataRDD)
+    val data3RDD = model3.transform(dataRDD)

-    val data1RDD = equivalentModel1.transform(dataRDD)
-    val data2RDD = equivalentModel2.transform(dataRDD)
-    val data3RDD = equivalentModel3.transform(dataRDD)
+    val summary = computeSummary(dataRDD)
+    val summary1 = computeSummary(data1RDD)
+    val summary2 = computeSummary(data2RDD)
+    val summary3 = computeSummary(data3RDD)

-    validateDense(data1, data2, data3, dataRDD, data1RDD, data2RDD, data3RDD)
-  }
+    assert((denseData, data1, data1RDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")

-  test("Standardization with dense input") {
+    assert((denseData, data2, data2RDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")

-    val dataRDD = sc.parallelize(denseData, 3)
+    assert((denseData, data3, data3RDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")

-    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
-    val standardizer2 = new StandardScaler()
-    val standardizer3 = new StandardScaler(withMean = true, withStd = false)
+    assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+    assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+    assert((data3, data3RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))

-    val model1 = standardizer1.fit(dataRDD)
-    val model2 = standardizer2.fit(dataRDD)
-    val model3 = standardizer3.fit(dataRDD)
+    assert(summary1.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary1.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)

-    val data1 = denseData.map(model1.transform)
-    val data2 = denseData.map(model2.transform)
-    val data3 = denseData.map(model3.transform)
+    assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)

-    val data1RDD = model1.transform(dataRDD)
-    val data2RDD = model2.transform(dataRDD)
-    val data3RDD = model3.transform(dataRDD)
+    assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary3.variance ~== summary.variance absTol 1E-5)

-    validateDense(data1, data2, data3, dataRDD, data1RDD, data2RDD, data3RDD)
+    assert(data1(0) ~== Vectors.dense(-1.31527964, 1.023470449, 0.11637768424) absTol 1E-5)
+    assert(data1(3) ~== Vectors.dense(1.637735298, 0.156973995, 1.32247368462) absTol 1E-5)
+    assert(data2(4) ~== Vectors.dense(0.865538862, -0.22604255, 0.0) absTol 1E-5)
+    assert(data2(5) ~== Vectors.dense(0.0, 0.71580142, 0.0) absTol 1E-5)
+    assert(data3(1) ~== Vectors.dense(-0.58333333, -0.58333333, -2.8166666666) absTol 1E-5)
+    assert(data3(5) ~== Vectors.dense(-0.58333333, 2.316666666, 0.18333333333) absTol 1E-5)
   }

@@ -226,8 +226,21 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {

     val data2RDD = equivalentModel2.transform(dataRDD)

-    validateSparse(data2, data2RDD)
+    val summary = computeSummary(data2RDD)
+
+    assert((sparseData, data2, data2RDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")
+
+    assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+
+    assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)

+    assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
+    assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
   }

   test("Standardization with sparse input") {

@@ -258,7 +271,22 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {

     val data2RDD = model2.transform(dataRDD)

-    validateSparse(data2, data2RDD)
+
+    val summary = computeSummary(data2RDD)
+
+    assert((sparseData, data2, data2RDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")
+
+    assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+
+    assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+
+    assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
+    assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
   }

   test("Standardization with constant input when means and variances are provided") {

@@ -281,8 +309,12 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val data2 = constantData.map(equivalentModel2.transform)
     val data3 = constantData.map(equivalentModel3.transform)

-    validateConstant(data1, data2, data3)
-
+    assert(data1.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
+    assert(data2.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
+    assert(data3.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
   }

   test("Standardization with constant input") {

@@ -301,8 +333,11 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val data2 = constantData.map(model2.transform)
     val data3 = constantData.map(model3.transform)

-    validateConstant(data1, data2, data3)
-
+    assert(data1.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
+    assert(data2.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
+    assert(data3.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
   }
-
 }
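The core equivalence the reworked tests exercise can be condensed as below. This is only a sketch, assuming the suite's sc and denseData fixtures are in scope:

// Fit a scaler, then rebuild a model by hand from the fitted statistics;
// both models should produce identical transformed output.
val dataRDD = sc.parallelize(denseData, 3)
val fitted = new StandardScaler(withMean = true, withStd = true).fit(dataRDD)
val rebuilt = new StandardScalerModel(fitted.mean, fitted.variance, true, true)
val pairs = fitted.transform(dataRDD).collect().zip(rebuilt.transform(dataRDD).collect())
assert(pairs.forall { case (a, b) => a == b })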
