@@ -54,45 +54,37 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     Vectors.dense(0.0, 1.9, 0.0)
   )
 
-  def validateConstant(data1: Array[Vector], data2: Array[Vector], data3: Array[Vector]) {
-    assert(data1.forall(_.toArray.forall(_ == 0.0)),
-      "The variance is zero, so the transformed result should be 0.0")
-    assert(data2.forall(_.toArray.forall(_ == 0.0)),
-      "The variance is zero, so the transformed result should be 0.0")
-    assert(data3.forall(_.toArray.forall(_ == 0.0)),
-      "The variance is zero, so the transformed result should be 0.0")
+  private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = {
+    data.treeAggregate(new MultivariateOnlineSummarizer)(
+      (aggregator, data) => aggregator.add(data),
+      (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
   }
 
-  def validateSparse(data: Array[Vector], dataRDD: RDD[Vector]) = {
+  test("Standardization with dense input when means and variances are provided") {
 
-    val summary = computeSummary(dataRDD)
+    val dataRDD = sc.parallelize(denseData, 3)
 
-    assert((sparseData, data, dataRDD.collect()).zipped.forall {
-      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
-      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
-      case _ => false
-    }, "The vector type should be preserved after standardization.")
+    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+    val standardizer2 = new StandardScaler()
+    val standardizer3 = new StandardScaler(withMean = true, withStd = false)
 
-    assert((data, dataRDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
 
-    assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
-    assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+    val equivalentModel1 = new StandardScalerModel(model1.mean, model1.variance, true, true)
+    val equivalentModel2 = new StandardScalerModel(model2.mean, model2.variance)
+    val equivalentModel3 = new StandardScalerModel(model3.mean, model3.variance, true, false)
 
-    assert(data(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
-    assert(data(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
-  }
+    val data1 = denseData.map(equivalentModel1.transform)
+    val data2 = denseData.map(equivalentModel2.transform)
+    val data3 = denseData.map(equivalentModel3.transform)
 
-  def validateDense(
-      data1: Array[Vector],
-      data2: Array[Vector],
-      data3: Array[Vector],
-      refDataRDD: RDD[Vector],
-      data1RDD: RDD[Vector],
-      data2RDD: RDD[Vector],
-      data3RDD: RDD[Vector]
-  ) = {
-
-    val refSummary = computeSummary(refDataRDD)
+    val data1RDD = equivalentModel1.transform(dataRDD)
+    val data2RDD = equivalentModel2.transform(dataRDD)
+    val data3RDD = equivalentModel3.transform(dataRDD)
+
+    val summary = computeSummary(dataRDD)
     val summary1 = computeSummary(data1RDD)
     val summary2 = computeSummary(data2RDD)
     val summary3 = computeSummary(data3RDD)
@@ -126,7 +118,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
 
     assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
-    assert(summary3.variance ~== refSummary.variance absTol 1E-5)
+    assert(summary3.variance ~== summary.variance absTol 1E-5)
 
     assert(data1(0) ~== Vectors.dense(-1.31527964, 1.023470449, 0.11637768424) absTol 1E-5)
     assert(data1(3) ~== Vectors.dense(1.637735298, 0.156973995, 1.32247368462) absTol 1E-5)
@@ -136,13 +128,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     assert(data3(5) ~== Vectors.dense(-0.58333333, 2.316666666, 0.18333333333) absTol 1E-5)
   }
 
-  private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = {
-    data.treeAggregate(new MultivariateOnlineSummarizer)(
-      (aggregator, data) => aggregator.add(data),
-      (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
-  }
-
-  test("Standardization with dense input when means and variances are provided") {
+  test("Standardization with dense input") {
 
     val dataRDD = sc.parallelize(denseData, 3)
 
@@ -154,42 +140,56 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = standardizer2.fit(dataRDD)
     val model3 = standardizer3.fit(dataRDD)
 
-    val equivalentModel1 = new StandardScalerModel(model1.mean, model1.variance, true, true)
-    val equivalentModel2 = new StandardScalerModel(model2.mean, model2.variance)
-    val equivalentModel3 = new StandardScalerModel(model3.mean, model3.variance, true, false)
+    val data1 = denseData.map(model1.transform)
+    val data2 = denseData.map(model2.transform)
+    val data3 = denseData.map(model3.transform)
 
-    val data1 = denseData.map(equivalentModel1.transform)
-    val data2 = denseData.map(equivalentModel2.transform)
-    val data3 = denseData.map(equivalentModel3.transform)
+    val data1RDD = model1.transform(dataRDD)
+    val data2RDD = model2.transform(dataRDD)
+    val data3RDD = model3.transform(dataRDD)
 
-    val data1RDD = equivalentModel1.transform(dataRDD)
-    val data2RDD = equivalentModel2.transform(dataRDD)
-    val data3RDD = equivalentModel3.transform(dataRDD)
+    val summary = computeSummary(dataRDD)
+    val summary1 = computeSummary(data1RDD)
+    val summary2 = computeSummary(data2RDD)
+    val summary3 = computeSummary(data3RDD)
 
-    validateDense(data1, data2, data3, dataRDD, data1RDD, data2RDD, data3RDD)
-  }
+    assert((denseData, data1, data1RDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")
 
-  test("Standardization with dense input") {
+    assert((denseData, data2, data2RDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")
 
-    val dataRDD = sc.parallelize(denseData, 3)
+    assert((denseData, data3, data3RDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")
 
-    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
-    val standardizer2 = new StandardScaler()
-    val standardizer3 = new StandardScaler(withMean = true, withStd = false)
+    assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+    assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+    assert((data3, data3RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
 
-    val model1 = standardizer1.fit(dataRDD)
-    val model2 = standardizer2.fit(dataRDD)
-    val model3 = standardizer3.fit(dataRDD)
+    assert(summary1.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary1.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
 
-    val data1 = denseData.map(model1.transform)
-    val data2 = denseData.map(model2.transform)
-    val data3 = denseData.map(model3.transform)
+    assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
 
-    val data1RDD = model1.transform(dataRDD)
-    val data2RDD = model2.transform(dataRDD)
-    val data3RDD = model3.transform(dataRDD)
+    assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary3.variance ~== summary.variance absTol 1E-5)
 
-    validateDense(data1, data2, data3, dataRDD, data1RDD, data2RDD, data3RDD)
+    assert(data1(0) ~== Vectors.dense(-1.31527964, 1.023470449, 0.11637768424) absTol 1E-5)
+    assert(data1(3) ~== Vectors.dense(1.637735298, 0.156973995, 1.32247368462) absTol 1E-5)
+    assert(data2(4) ~== Vectors.dense(0.865538862, -0.22604255, 0.0) absTol 1E-5)
+    assert(data2(5) ~== Vectors.dense(0.0, 0.71580142, 0.0) absTol 1E-5)
+    assert(data3(1) ~== Vectors.dense(-0.58333333, -0.58333333, -2.8166666666) absTol 1E-5)
+    assert(data3(5) ~== Vectors.dense(-0.58333333, 2.316666666, 0.18333333333) absTol 1E-5)
   }
 
 
@@ -226,8 +226,21 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
 
     val data2RDD = equivalentModel2.transform(dataRDD)
 
-    validateSparse(data2, data2RDD)
+    val summary = computeSummary(data2RDD)
+
+    assert((sparseData, data2, data2RDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")
+
+    assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+
+    assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
 
+    assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
+    assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
   }
 
   test("Standardization with sparse input") {
@@ -258,7 +271,22 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
 
     val data2RDD = model2.transform(dataRDD)
 
-    validateSparse(data2, data2RDD)
+
+    val summary = computeSummary(data2RDD)
+
+    assert((sparseData, data2, data2RDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")
+
+    assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+
+    assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+
+    assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
+    assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
   }
 
   test("Standardization with constant input when means and variances are provided") {
@@ -281,8 +309,12 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val data2 = constantData.map(equivalentModel2.transform)
     val data3 = constantData.map(equivalentModel3.transform)
 
-    validateConstant(data1, data2, data3)
-
+    assert(data1.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
+    assert(data2.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
+    assert(data3.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
   }
 
   test("Standardization with constant input") {
@@ -301,8 +333,11 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val data2 = constantData.map(model2.transform)
     val data3 = constantData.map(model3.transform)
 
-    validateConstant(data1, data2, data3)
-
+    assert(data1.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
+    assert(data2.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
+    assert(data3.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
   }
-
 }