Commit 64408a4

[SPARK-5207] [MLLIB] [WIP] change StandardScalerModel constructor to not be private to mllib; added tests for newly-exposed functionality

1 parent 2eeada3 · commit 64408a4
2 files changed, +202 −78 lines


mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
(17 additions, 7 deletions)

```diff
@@ -53,28 +53,38 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging {
     val summary = data.treeAggregate(new MultivariateOnlineSummarizer)(
       (aggregator, data) => aggregator.add(data),
       (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
-    new StandardScalerModel(withMean, withStd, summary.mean, summary.variance)
+    new StandardScalerModel(summary.mean, summary.variance, withMean, withStd)
   }
 }
 
 /**
  * :: Experimental ::
  * Represents a StandardScaler model that can transform vectors.
  *
- * @param withMean whether to center the data before scaling
- * @param withStd whether to scale the data to have unit standard deviation
  * @param mean column mean values
  * @param variance column variance values
+ * @param withMean whether to center the data before scaling
+ * @param withStd whether to scale the data to have unit standard deviation
  */
 @Experimental
-class StandardScalerModel private[mllib] (
-    val withMean: Boolean,
-    val withStd: Boolean,
+class StandardScalerModel (
     val mean: Vector,
-    val variance: Vector) extends VectorTransformer {
+    val variance: Vector,
+    private var withMean: Boolean = false,
+    private var withStd: Boolean = true) extends VectorTransformer {
 
   require(mean.size == variance.size)
 
+  def setWithMean(withMean: Boolean): this.type = {
+    this.withMean = withMean
+    this
+  }
+
+  def setWithStd(withStd: Boolean): this.type = {
+    this.withStd = withStd
+    this
+  }
+
   private lazy val factor: Array[Double] = {
     val f = Array.ofDim[Double](variance.size)
     var i = 0
```
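The hunk above cuts off just as `factor` is computed, which is where the zero-variance convention spelled out in the test suite below takes effect: standardizing a zero-variance column is not well-defined, so transformed values come out as 0.0 rather than dividing by zero. A hedged sketch of that observable behavior, with assumed single-column statistics:

```scala
import org.apache.spark.mllib.feature.StandardScalerModel
import org.apache.spark.mllib.linalg.Vectors

// A column whose variance is zero: per the convention asserted by
// validateConstant in the test suite below, every transformed value
// in that column should come out as 0.0.
val constantModel = new StandardScalerModel(
  Vectors.dense(2.0), // column mean
  Vectors.dense(0.0), // zero column variance
  true, true)         // withMean, withStd

constantModel.transform(Vectors.dense(5.0)) // => [0.0]
```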

mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala
(185 additions, 71 deletions)

```diff
@@ -28,58 +28,88 @@ import org.apache.spark.rdd.RDD
 
 class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
 
-  private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = {
-    data.treeAggregate(new MultivariateOnlineSummarizer)(
-      (aggregator, data) => aggregator.add(data),
-      (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
+  // When the input data is all constant, the variance is zero. The standardization against
+  // zero variance is not well-defined, but we decide to just set it into zero here.
+  val constantData = Array(
+    Vectors.dense(2.0),
+    Vectors.dense(2.0),
+    Vectors.dense(2.0)
+  )
+
+  val sparseData = Array(
+    Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
+    Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))),
+    Vectors.sparse(3, Seq((1, -5.1))),
+    Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))),
+    Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))),
+    Vectors.sparse(3, Seq((1, 1.9)))
+  )
+
+  val denseData = Array(
+    Vectors.dense(-2.0, 2.3, 0),
+    Vectors.dense(0.0, -1.0, -3.0),
+    Vectors.dense(0.0, -5.1, 0.0),
+    Vectors.dense(3.8, 0.0, 1.9),
+    Vectors.dense(1.7, -0.6, 0.0),
+    Vectors.dense(0.0, 1.9, 0.0)
+  )
+
+  def validateConstant(data1: Array[Vector], data2: Array[Vector], data3: Array[Vector]) {
+    assert(data1.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
+    assert(data2.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
+    assert(data3.forall(_.toArray.forall(_ == 0.0)),
+      "The variance is zero, so the transformed result should be 0.0")
   }
 
-  test("Standardization with dense input") {
-    val data = Array(
-      Vectors.dense(-2.0, 2.3, 0),
-      Vectors.dense(0.0, -1.0, -3.0),
-      Vectors.dense(0.0, -5.1, 0.0),
-      Vectors.dense(3.8, 0.0, 1.9),
-      Vectors.dense(1.7, -0.6, 0.0),
-      Vectors.dense(0.0, 1.9, 0.0)
-    )
+  def validateSparse(data: Array[Vector], dataRDD: RDD[Vector]) = {
 
-    val dataRDD = sc.parallelize(data, 3)
+    val summary = computeSummary(dataRDD)
 
-    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
-    val standardizer2 = new StandardScaler()
-    val standardizer3 = new StandardScaler(withMean = true, withStd = false)
+    assert((sparseData, data, dataRDD.collect()).zipped.forall {
+      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
+      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
+      case _ => false
+    }, "The vector type should be preserved after standardization.")
 
-    val model1 = standardizer1.fit(dataRDD)
-    val model2 = standardizer2.fit(dataRDD)
-    val model3 = standardizer3.fit(dataRDD)
+    assert((data, dataRDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
 
-    val data1 = data.map(model1.transform)
-    val data2 = data.map(model2.transform)
-    val data3 = data.map(model3.transform)
+    assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
+    assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
 
-    val data1RDD = model1.transform(dataRDD)
-    val data2RDD = model2.transform(dataRDD)
-    val data3RDD = model3.transform(dataRDD)
+    assert(data(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
+    assert(data(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
+  }
 
-    val summary = computeSummary(dataRDD)
+  def validateDense(
+      data1: Array[Vector],
+      data2: Array[Vector],
+      data3: Array[Vector],
+      refDataRDD: RDD[Vector],
+      data1RDD: RDD[Vector],
+      data2RDD: RDD[Vector],
+      data3RDD: RDD[Vector]
+  ) = {
+
+    val refSummary = computeSummary(refDataRDD)
     val summary1 = computeSummary(data1RDD)
     val summary2 = computeSummary(data2RDD)
     val summary3 = computeSummary(data3RDD)
 
-    assert((data, data1, data1RDD.collect()).zipped.forall {
+    assert((denseData, data1, data1RDD.collect()).zipped.forall {
       case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
       case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
       case _ => false
     }, "The vector type should be preserved after standardization.")
 
-    assert((data, data2, data2RDD.collect()).zipped.forall {
+    assert((denseData, data2, data2RDD.collect()).zipped.forall {
       case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
       case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
       case _ => false
     }, "The vector type should be preserved after standardization.")
 
-    assert((data, data3, data3RDD.collect()).zipped.forall {
+    assert((denseData, data3, data3RDD.collect()).zipped.forall {
       case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
       case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
       case _ => false
@@ -96,7 +126,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
 
     assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
-    assert(summary3.variance ~== summary.variance absTol 1E-5)
+    assert(summary3.variance ~== refSummary.variance absTol 1E-5)
 
     assert(data1(0) ~== Vectors.dense(-1.31527964, 1.023470449, 0.11637768424) absTol 1E-5)
     assert(data1(3) ~== Vectors.dense(1.637735298, 0.156973995, 1.32247368462) absTol 1E-5)
@@ -106,18 +136,103 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     assert(data3(5) ~== Vectors.dense(-0.58333333, 2.316666666, 0.18333333333) absTol 1E-5)
   }
 
+  private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = {
+    data.treeAggregate(new MultivariateOnlineSummarizer)(
+      (aggregator, data) => aggregator.add(data),
+      (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
+  }
+
+  test("Standardization with dense input when means and variances are provided") {
+
+    val dataRDD = sc.parallelize(denseData, 3)
+
+    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+    val standardizer2 = new StandardScaler()
+    val standardizer3 = new StandardScaler(withMean = true, withStd = false)
+
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
+
+    val equivalentModel1 = new StandardScalerModel(model1.mean, model1.variance, true, true)
+    val equivalentModel2 = new StandardScalerModel(model2.mean, model2.variance)
+    val equivalentModel3 = new StandardScalerModel(model3.mean, model3.variance, true, false)
+
+    val data1 = denseData.map(equivalentModel1.transform)
+    val data2 = denseData.map(equivalentModel2.transform)
+    val data3 = denseData.map(equivalentModel3.transform)
+
+    val data1RDD = equivalentModel1.transform(dataRDD)
+    val data2RDD = equivalentModel2.transform(dataRDD)
+    val data3RDD = equivalentModel3.transform(dataRDD)
+
+    validateDense(data1, data2, data3, dataRDD, data1RDD, data2RDD, data3RDD)
+  }
+
+  test("Standardization with dense input") {
+
+    val dataRDD = sc.parallelize(denseData, 3)
+
+    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+    val standardizer2 = new StandardScaler()
+    val standardizer3 = new StandardScaler(withMean = true, withStd = false)
+
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
+
+    val data1 = denseData.map(model1.transform)
+    val data2 = denseData.map(model2.transform)
+    val data3 = denseData.map(model3.transform)
+
+    val data1RDD = model1.transform(dataRDD)
+    val data2RDD = model2.transform(dataRDD)
+    val data3RDD = model3.transform(dataRDD)
+
+    validateDense(data1, data2, data3, dataRDD, data1RDD, data2RDD, data3RDD)
+  }
+
+
+  test("Standardization with sparse input when means and variances are provided") {
+
+    val dataRDD = sc.parallelize(sparseData, 3)
+
+    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+    val standardizer2 = new StandardScaler()
+    val standardizer3 = new StandardScaler(withMean = true, withStd = false)
+
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
+
+    val equivalentModel1 = new StandardScalerModel(model1.mean, model1.variance, true, true)
+    val equivalentModel2 = new StandardScalerModel(model2.mean, model2.variance)
+    val equivalentModel3 = new StandardScalerModel(model3.mean, model3.variance, true, false)
+
+
+    val data2 = sparseData.map(equivalentModel2.transform)
+
+    withClue("Standardization with mean can not be applied on sparse input.") {
+      intercept[IllegalArgumentException] {
+        sparseData.map(equivalentModel1.transform)
+      }
+    }
+
+    withClue("Standardization with mean can not be applied on sparse input.") {
+      intercept[IllegalArgumentException] {
+        sparseData.map(equivalentModel3.transform)
+      }
+    }
+
+    val data2RDD = equivalentModel2.transform(dataRDD)
+
+    validateSparse(data2, data2RDD)
+
+  }
 
   test("Standardization with sparse input") {
-    val data = Array(
-      Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))),
-      Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))),
-      Vectors.sparse(3, Seq((1, -5.1))),
-      Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))),
-      Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))),
-      Vectors.sparse(3, Seq((1, 1.9)))
-    )
 
-    val dataRDD = sc.parallelize(data, 3)
+    val dataRDD = sc.parallelize(sparseData, 3)
 
     val standardizer1 = new StandardScaler(withMean = true, withStd = true)
     val standardizer2 = new StandardScaler()
@@ -127,49 +242,52 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = standardizer2.fit(dataRDD)
     val model3 = standardizer3.fit(dataRDD)
 
-    val data2 = data.map(model2.transform)
+    val data2 = sparseData.map(model2.transform)
 
     withClue("Standardization with mean can not be applied on sparse input.") {
       intercept[IllegalArgumentException] {
-        data.map(model1.transform)
+        sparseData.map(model1.transform)
      }
    }
 
     withClue("Standardization with mean can not be applied on sparse input.") {
       intercept[IllegalArgumentException] {
-        data.map(model3.transform)
      }
    }
 
     val data2RDD = model2.transform(dataRDD)
 
-    val summary2 = computeSummary(data2RDD)
+    validateSparse(data2, data2RDD)
+  }
 
-    assert((data, data2, data2RDD.collect()).zipped.forall {
-      case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true
-      case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true
-      case _ => false
-    }, "The vector type should be preserved after standardization.")
+  test("Standardization with constant input when means and variances are provided") {
 
-    assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5))
+    val dataRDD = sc.parallelize(constantData, 2)
 
-    assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5)
-    assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5)
+    val standardizer1 = new StandardScaler(withMean = true, withStd = true)
+    val standardizer2 = new StandardScaler(withMean = true, withStd = false)
+    val standardizer3 = new StandardScaler(withMean = false, withStd = true)
+
+    val model1 = standardizer1.fit(dataRDD)
+    val model2 = standardizer2.fit(dataRDD)
+    val model3 = standardizer3.fit(dataRDD)
+
+    val equivalentModel1 = new StandardScalerModel(model1.mean, model1.variance, true, true)
+    val equivalentModel2 = new StandardScalerModel(model2.mean, model2.variance, true, false)
+    val equivalentModel3 = new StandardScalerModel(model3.mean, model3.variance, false, true)
+
+    val data1 = constantData.map(equivalentModel1.transform)
+    val data2 = constantData.map(equivalentModel2.transform)
+    val data3 = constantData.map(equivalentModel3.transform)
+
+    validateConstant(data1, data2, data3)
 
-    assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5)
-    assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
   }
 
   test("Standardization with constant input") {
-    // When the input data is all constant, the variance is zero. The standardization against
-    // zero variance is not well-defined, but we decide to just set it into zero here.
-    val data = Array(
-      Vectors.dense(2.0),
-      Vectors.dense(2.0),
-      Vectors.dense(2.0)
-    )
 
-    val dataRDD = sc.parallelize(data, 2)
+    val dataRDD = sc.parallelize(constantData, 2)
 
     val standardizer1 = new StandardScaler(withMean = true, withStd = true)
     val standardizer2 = new StandardScaler(withMean = true, withStd = false)
@@ -179,16 +297,12 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = standardizer2.fit(dataRDD)
     val model3 = standardizer3.fit(dataRDD)
 
-    val data1 = data.map(model1.transform)
-    val data2 = data.map(model2.transform)
-    val data3 = data.map(model3.transform)
+    val data1 = constantData.map(model1.transform)
+    val data2 = constantData.map(model2.transform)
+    val data3 = constantData.map(model3.transform)
+
+    validateConstant(data1, data2, data3)
 
-    assert(data1.forall(_.toArray.forall(_ == 0.0)),
-      "The variance is zero, so the transformed result should be 0.0")
-    assert(data2.forall(_.toArray.forall(_ == 0.0)),
-      "The variance is zero, so the transformed result should be 0.0")
-    assert(data3.forall(_.toArray.forall(_ == 0.0)),
-      "The variance is zero, so the transformed result should be 0.0")
   }
 
 }
```
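The three "when means and variances are provided" tests all exercise the same round trip: fit a scaler, rebuild a model from the statistics the fitted model exposes, and require both to transform identically. A compact sketch of that pattern (`dataRDD` is assumed to be an `RDD[Vector]` of dense vectors, since `withMean` on sparse input throws `IllegalArgumentException`):

```scala
import org.apache.spark.mllib.feature.{StandardScaler, StandardScalerModel}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

def checkRoundTrip(dataRDD: RDD[Vector]): Unit = {
  val fitted = new StandardScaler(withMean = true, withStd = true).fit(dataRDD)

  // Rebuild from the statistics that the fitted model now exposes publicly.
  val rebuilt = new StandardScalerModel(fitted.mean, fitted.variance, true, true)

  // Both models share the same statistics, so their transforms should agree.
  dataRDD.collect().foreach { v =>
    assert(fitted.transform(v).toArray.sameElements(rebuilt.transform(v).toArray))
  }
}
```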
