@@ -28,58 +28,88 @@ import org.apache.spark.rdd.RDD
 
 class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
 
-  private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = {
-    data.treeAggregate(new MultivariateOnlineSummarizer)(
-      (aggregator, data) => aggregator.add(data),
-      (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
+  // When the input data is all constant, the variance is zero. The standardization against
+  // zero variance is not well-defined, but we decide to just set it into zero here.
+  val constantData = Array(
+    Vectors.dense(2.0),
+    Vectors.dense(2.0),
+    Vectors.dense(2.0)
+  )
+
+  val sparseData = Array(
+    Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
+    Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))),
+    Vectors.sparse(3, Seq((1, -5.1))),
+    Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))),
+    Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))),
+    Vectors.sparse(3, Seq((1, 1.9)))
+  )
+
+  val denseData = Array(
+    Vectors.dense(-2.0, 2.3, 0),
+    Vectors.dense(0.0, -1.0, -3.0),
+    Vectors.dense(0.0, -5.1, 0.0),
+    Vectors.dense(3.8, 0.0, 1.9),
+    Vectors.dense(1.7, -0.6, 0.0),
+    Vectors.dense(0.0, 1.9, 0.0)
+  )
+
+  def validateConstant(data1: Array[Vector], data2: Array[Vector], data3: Array[Vector]) {
+    assert(data1.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
+    assert(data2.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
+    assert(data3.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
   }
 
-  test("Standardization with dense input") {
-    val data = Array(
-      Vectors.dense(-2.0, 2.3, 0),
-      Vectors.dense(0.0, -1.0, -3.0),
-      Vectors.dense(0.0, -5.1, 0.0),
-      Vectors.dense(3.8, 0.0, 1.9),
-      Vectors.dense(1.7, -0.6, 0.0),
-      Vectors.dense(0.0, 1.9, 0.0)
-    )
+  def validateSparse(data: Array[Vector], dataRDD: RDD[Vector]) = {
 
-    val dataRDD = sc.parallelize(data, 3)
+    val summary = computeSummary(dataRDD)
 
-    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
-    val standardizer2 = new StandardScaler()
-    val standardizer3 = new StandardScaler(withMean = true, withStd = false)
+    assert((sparseData, data, dataRDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")
 
-    val model1 = standardizer1.fit(dataRDD)
-    val model2 = standardizer2.fit(dataRDD)
-    val model3 = standardizer3.fit(dataRDD)
+    assert((data, dataRDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
 
-    val data1 = data.map(model1.transform)
-    val data2 = data.map(model2.transform)
-    val data3 = data.map(model3.transform)
+    assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
 
-    val data1RDD = model1.transform(dataRDD)
-    val data2RDD = model2.transform(dataRDD)
-    val data3RDD = model3.transform(dataRDD)
+    assert(data(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
+    assert(data(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
+  }
 
-    val summary = computeSummary(dataRDD)
+  def validateDense(
+      data1: Array[Vector],
+      data2: Array[Vector],
+      data3: Array[Vector],
+      refDataRDD: RDD[Vector],
+      data1RDD: RDD[Vector],
+      data2RDD: RDD[Vector],
+      data3RDD: RDD[Vector]
+  ) = {
+
+    val refSummary = computeSummary(refDataRDD)
     val summary1 = computeSummary(data1RDD)
     val summary2 = computeSummary(data2RDD)
     val summary3 = computeSummary(data3RDD)
 
-    assert((data, data1, data1RDD.collect()).zipped.forall {
+    assert((denseData, data1, data1RDD.collect()).zipped.forall {
       case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
       case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
       case _ => false
     }, "The vector type should be preserved after standardization.")
 
-    assert((data, data2, data2RDD.collect()).zipped.forall {
+    assert((denseData, data2, data2RDD.collect()).zipped.forall {
       case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
       case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
       case _ => false
     }, "The vector type should be preserved after standardization.")
 
-    assert((data, data3, data3RDD.collect()).zipped.forall {
+    assert((denseData, data3, data3RDD.collect()).zipped.forall {
       case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
       case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
       case _ => false
@@ -96,7 +126,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
 
     assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
-    assert(summary3.variance ~== summary.variance absTol 1E-5)
+    assert(summary3.variance ~== refSummary.variance absTol 1E-5)
 
     assert(data1(0) ~== Vectors.dense(-1.31527964, 1.023470449, 0.11637768424) absTol 1E-5)
     assert(data1(3) ~== Vectors.dense(1.637735298, 0.156973995, 1.32247368462) absTol 1E-5)
@@ -106,18 +136,103 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     assert(data3(5) ~== Vectors.dense(-0.58333333, 2.316666666, 0.18333333333) absTol 1E-5)
   }
 
+  private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = {
+    data.treeAggregate(new MultivariateOnlineSummarizer)(
+      (aggregator, data) => aggregator.add(data),
+      (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
+  }
+
+  test("Standardization with dense input when means and variances are provided") {
+
+    val dataRDD = sc.parallelize(denseData, 3)
+
+    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+    val standardizer2 = new StandardScaler()
+    val standardizer3 = new StandardScaler(withMean = true, withStd = false)
+
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
+
+    val equivalentModel1 = new StandardScalerModel(model1.mean, model1.variance, true, true)
+    val equivalentModel2 = new StandardScalerModel(model2.mean, model2.variance)
+    val equivalentModel3 = new StandardScalerModel(model3.mean, model3.variance, true, false)
+
+    val data1 = denseData.map(equivalentModel1.transform)
+    val data2 = denseData.map(equivalentModel2.transform)
+    val data3 = denseData.map(equivalentModel3.transform)
+
+    val data1RDD = equivalentModel1.transform(dataRDD)
+    val data2RDD = equivalentModel2.transform(dataRDD)
+    val data3RDD = equivalentModel3.transform(dataRDD)
+
+    validateDense(data1, data2, data3, dataRDD, data1RDD, data2RDD, data3RDD)
+  }
+
+  test("Standardization with dense input") {
+
+    val dataRDD = sc.parallelize(denseData, 3)
+
+    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+    val standardizer2 = new StandardScaler()
+    val standardizer3 = new StandardScaler(withMean = true, withStd = false)
+
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
+
+    val data1 = denseData.map(model1.transform)
+    val data2 = denseData.map(model2.transform)
+    val data3 = denseData.map(model3.transform)
+
+    val data1RDD = model1.transform(dataRDD)
+    val data2RDD = model2.transform(dataRDD)
+    val data3RDD = model3.transform(dataRDD)
+
+    validateDense(data1, data2, data3, dataRDD, data1RDD, data2RDD, data3RDD)
+  }
+
+
+  test("Standardization with sparse input when means and variances are provided") {
+
+    val dataRDD = sc.parallelize(sparseData, 3)
+
+    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+    val standardizer2 = new StandardScaler()
+    val standardizer3 = new StandardScaler(withMean = true, withStd = false)
+
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
+
+    val equivalentModel1 = new StandardScalerModel(model1.mean, model1.variance, true, true)
+    val equivalentModel2 = new StandardScalerModel(model2.mean, model2.variance)
+    val equivalentModel3 = new StandardScalerModel(model3.mean, model3.variance, true, false)
+
+
+    val data2 = sparseData.map(equivalentModel2.transform)
+
+    withClue("Standardization with mean can not be applied on sparse input.") {
+      intercept[IllegalArgumentException] {
+        sparseData.map(equivalentModel1.transform)
+      }
+    }
+
+    withClue("Standardization with mean can not be applied on sparse input.") {
+      intercept[IllegalArgumentException] {
+        sparseData.map(equivalentModel3.transform)
+      }
+    }
+
+    val data2RDD = equivalentModel2.transform(dataRDD)
+
+    validateSparse(data2, data2RDD)
+
+  }
 
   test("Standardization with sparse input") {
-    val data = Array(
-      Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
-      Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))),
-      Vectors.sparse(3, Seq((1, -5.1))),
-      Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))),
-      Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))),
-      Vectors.sparse(3, Seq((1, 1.9)))
-    )
 
-    val dataRDD = sc.parallelize(data, 3)
+    val dataRDD = sc.parallelize(sparseData, 3)
 
     val standardizer1 = new StandardScaler(withMean = true, withStd = true)
     val standardizer2 = new StandardScaler()
@@ -127,49 +242,52 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = standardizer2.fit(dataRDD)
     val model3 = standardizer3.fit(dataRDD)
 
-    val data2 = data.map(model2.transform)
+    val data2 = sparseData.map(model2.transform)
 
     withClue("Standardization with mean can not be applied on sparse input.") {
       intercept[IllegalArgumentException] {
-        data.map(model1.transform)
+        sparseData.map(model1.transform)
       }
     }
 
     withClue("Standardization with mean can not be applied on sparse input.") {
       intercept[IllegalArgumentException] {
-        data.map(model3.transform)
+        sparseData.map(model3.transform)
      }
     }
 
     val data2RDD = model2.transform(dataRDD)
 
-    val summary2 = computeSummary(data2RDD)
+    validateSparse(data2, data2RDD)
+  }
 
-    assert((data, data2, data2RDD.collect()).zipped.forall {
-      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
-      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
-      case _ => false
-    }, "The vector type should be preserved after standardization.")
+  test("Standardization with constant input when means and variances are provided") {
 
-    assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+    val dataRDD = sc.parallelize(constantData, 2)
 
-    assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
-    assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+    val standardizer2 = new StandardScaler(withMean = true, withStd = false)
+    val standardizer3 = new StandardScaler(withMean = false, withStd = true)
+
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
+
+    val equivalentModel1 = new StandardScalerModel(model1.mean, model1.variance, true, true)
+    val equivalentModel2 = new StandardScalerModel(model2.mean, model2.variance, true, false)
+    val equivalentModel3 = new StandardScalerModel(model3.mean, model3.variance, false, true)
+
+    val data1 = constantData.map(equivalentModel1.transform)
+    val data2 = constantData.map(equivalentModel2.transform)
+    val data3 = constantData.map(equivalentModel3.transform)
+
+    validateConstant(data1, data2, data3)
 
-    assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
-    assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
   }
 
   test("Standardization with constant input") {
-    // When the input data is all constant, the variance is zero. The standardization against
-    // zero variance is not well-defined, but we decide to just set it into zero here.
-    val data = Array(
-      Vectors.dense(2.0),
-      Vectors.dense(2.0),
-      Vectors.dense(2.0)
-    )
 
-    val dataRDD = sc.parallelize(data, 2)
+    val dataRDD = sc.parallelize(constantData, 2)
 
     val standardizer1 = new StandardScaler(withMean = true, withStd = true)
     val standardizer2 = new StandardScaler(withMean = true, withStd = false)
@@ -179,16 +297,12 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = standardizer2.fit(dataRDD)
     val model3 = standardizer3.fit(dataRDD)
 
-    val data1 = data.map(model1.transform)
-    val data2 = data.map(model2.transform)
-    val data3 = data.map(model3.transform)
+    val data1 = constantData.map(model1.transform)
+    val data2 = constantData.map(model2.transform)
+    val data3 = constantData.map(model3.transform)
+
+    validateConstant(data1, data2, data3)
 
-    assert(data1.forall(_.toArray.forall(_ == 0.0)),
-      "The variance is zero, so the transformed result should be 0.0")
-    assert(data2.forall(_.toArray.forall(_ == 0.0)),
-      "The variance is zero, so the transformed result should be 0.0")
-    assert(data3.forall(_.toArray.forall(_ == 0.0)),
-      "The variance is zero, so the transformed result should be 0.0")
   }
 
 }
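
The comment moved to the top of the suite documents a convention rather than a derivation: standardizing a feature whose variance is zero is not well-defined, and the suite pins the result to 0.0, which is exactly what validateConstant asserts. A minimal standalone sketch of that convention, for illustration only (the standardize helper below is hypothetical and is not MLlib's implementation):

    // Illustrative sketch of the zero-variance convention checked by validateConstant.
    // `standardize` is a hypothetical helper, not an MLlib API.
    def standardize(x: Double, mean: Double, variance: Double): Double =
      if (variance == 0.0) 0.0 // ill-defined case: pin the output to zero
      else (x - mean) / math.sqrt(variance)

    // A constant column (all 2.0) has mean 2.0 and variance 0.0, so every
    // transformed value becomes 0.0, matching the "constant input" tests.
    assert(standardize(2.0, 2.0, 0.0) == 0.0)
    assert(standardize(14.0, 10.0, 4.0) == 2.0) // ordinary case: (14 - 10) / 2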
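The point of the new "when means and variances are provided" tests is that a StandardScalerModel can now be constructed directly from precomputed statistics, with no fitting pass over the data. A hedged usage sketch based only on the constructor shape this diff exercises, StandardScalerModel(mean, variance, withMean, withStd); the statistics below are made up for illustration:

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.feature.StandardScalerModel

    // Hypothetical, externally computed per-feature statistics.
    val mean = Vectors.dense(0.0, 10.0)
    val variance = Vectors.dense(1.0, 4.0)

    // Build the model directly from the provided mean and variance, just as the
    // tests build their equivalentModel* instances after fitting.
    val model = new StandardScalerModel(mean, variance, true, true)

    // (x - mean) / sqrt(variance): expect Vectors.dense(1.0, 2.0).
    val scaled = model.transform(Vectors.dense(1.0, 14.0))

The tests confirm this path is equivalent to fitting: a model built from model1.mean and model1.variance must transform denseData to the same vectors as model1 itself, down to the 1E-5 tolerance used throughout the suite.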