
Commit fa64dfa

[SPARK-5207] [MLLIB] [WIP] change StandardScalerModel to take stddev instead of variance
1 parent: 9078fe0

3 files changed: +44 -49 lines

docs/mllib-feature-extraction.md

Lines changed: 5 additions & 5 deletions
@@ -240,11 +240,11 @@ following parameters in the constructor:
 
 * `withMean` False by default. Centers the data with mean before scaling. It will build a dense
 output, so this does not work on sparse input and will raise an exception.
-* `withStd` True by default. Scales the data to unit variance.
+* `withStd` True by default. Scales the data to unit standard deviation.
 
 We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScaler) method in
 `StandardScaler` which can take an input of `RDD[Vector]`, learn the summary statistics, and then
-return a model which can transform the input dataset into unit variance and/or zero mean features
+return a model which can transform the input dataset into unit standard deviation and/or zero mean features
 depending how we configure the `StandardScaler`.
 
 This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer)
@@ -257,7 +257,7 @@ for that feature.
 ### Example
 
 The example below demonstrates how to load a dataset in libsvm format, and standardize the features
-so that the new features have unit variance and/or zero mean.
+so that the new features have unit standard deviation and/or zero mean.
 
 <div class="codetabs">
 <div data-lang="scala">
@@ -272,7 +272,7 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
 val scaler1 = new StandardScaler().fit(data.map(x => x.features))
 val scaler2 = new StandardScaler(withMean = true, withStd = true).fit(data.map(x => x.features))
 // scaler3 is an identical model to scaler2, and will produce identical transformations
-val scaler3 = new StandardScalerModel(scaler2.variance, scaler2.mean)
+val scaler3 = new StandardScalerModel(scaler2.std, scaler2.mean)
 
 // data1 will be unit variance.
 val data1 = data.map(x => (x.label, scaler1.transform(x.features)))
@@ -297,7 +297,7 @@ features = data.map(lambda x: x.features)
 scaler1 = StandardScaler().fit(features)
 scaler2 = StandardScaler(withMean=True, withStd=True).fit(features)
 # scaler3 is an identical model to scaler2, and will produce identical transformations
-scaler3 = StandardScalerModel(scaler2.variance, scaler2.mean)
+scaler3 = StandardScalerModel(scaler2.std, scaler2.mean)
 
 
 # data1 will be unit variance.
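
With this change, `scaler3` in the example is built from `scaler2.std` rather than `scaler2.variance`. A minimal migration sketch for callers that still hold a per-column variance vector (the `oldVariance` and `mean` values below are illustrative, not from the commit): convert variances to standard deviations with an element-wise square root, exactly as `StandardScaler.fit` now does internally.

import org.apache.spark.mllib.feature.StandardScalerModel
import org.apache.spark.mllib.linalg.{Vector, Vectors}

// Illustrative caller-side values; real code would take these from a fitted
// model or a MultivariateOnlineSummarizer.
val oldVariance: Vector = Vectors.dense(4.0, 9.0, 0.25)
val mean: Vector = Vectors.dense(1.0, 2.0, 3.0)

// The model now expects standard deviations, so take the element-wise
// square root of the variances before constructing it.
val std: Vector = Vectors.dense(oldVariance.toArray.map(math.sqrt))
val scaler = new StandardScalerModel(std, mean)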

mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala

Lines changed: 24 additions & 29 deletions
@@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD
 
 /**
  * :: Experimental ::
- * Standardizes features by removing the mean and scaling to unit standard deviation using column summary
+ * Standardizes features by removing the mean and scaling to unit std using column summary
  * statistics on the samples in the training set.
  *
  * @param withMean False by default. Centers the data with mean before scaling. It will build a
@@ -53,33 +53,41 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging {
     val summary = data.treeAggregate(new MultivariateOnlineSummarizer)(
       (aggregator, data) => aggregator.add(data),
       (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
-    new StandardScalerModel(summary.variance, summary.mean, withStd, withMean)
+    new StandardScalerModel(
+      Vectors.dense(summary.variance.toArray.map(v => math.sqrt(v))),
+      summary.mean,
+      withStd,
+      withMean)
   }
 }
 
 /**
  * :: Experimental ::
  * Represents a StandardScaler model that can transform vectors.
  *
- * @param variance column variance values
+ * @param std column standard deviation values
  * @param mean column mean values
  * @param withStd whether to scale the data to have unit standard deviation
  * @param withMean whether to center the data before scaling
  */
 @Experimental
 class StandardScalerModel (
-    val variance: Vector,
+    val std: Vector,
     val mean: Vector,
     var withStd: Boolean,
     var withMean: Boolean) extends VectorTransformer {
 
-  def this(variance: Vector, mean: Vector) {
-    this(variance, mean, withStd = variance != null, withMean = mean != null)
-    require(this.withStd || this.withMean, "at least one of variance or mean vectors must be provided")
-    if (this.withStd && this.withMean) require(mean.size == variance.size, "mean and variance vectors must have equal size if both are provided")
+  def this(std: Vector, mean: Vector) {
+    this(std, mean, withStd = std != null, withMean = mean != null)
+    require(this.withStd || this.withMean,
+      "at least one of std or mean vectors must be provided")
+    if (this.withStd && this.withMean) {
+      require(mean.size == std.size,
+        "mean and std vectors must have equal size if both are provided")
+    }
   }
 
-  def this(variance: Vector) = this(variance, null)
+  def this(std: Vector) = this(std, null)
 
   @DeveloperApi
   def setWithMean(withMean: Boolean): this.type = {
@@ -90,21 +98,12 @@ class StandardScalerModel (
 
   @DeveloperApi
   def setWithStd(withStd: Boolean): this.type = {
-    require(!(withStd && this.variance == null), "cannot set withStd to true while variance is null")
+    require(!(withStd && this.std == null),
+      "cannot set withStd to true while std is null")
     this.withStd = withStd
     this
   }
 
-  private lazy val factor: Array[Double] = {
-    val f = Array.ofDim[Double](variance.size)
-    var i = 0
-    while (i < f.size) {
-      f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0
-      i += 1
-    }
-    f
-  }
-
   // Since `shift` will be only used in `withMean` branch, we have it as
   // `lazy val` so it will be evaluated in that branch. Note that we don't
   // want to create this array multiple times in `transform` function.
@@ -114,8 +113,8 @@ class StandardScalerModel (
    * Applies standardization transformation on a vector.
    *
    * @param vector Vector to be standardized.
-   * @return Standardized vector. If the standard deviation of a column is zero, it will return default `0.0`
-   *         for the column with zero standard deviation.
+   * @return Standardized vector. If the std of a column is zero, it will return default `0.0`
+   *         for the column with zero std.
    */
   override def transform(vector: Vector): Vector = {
     require(mean.size == vector.size)
@@ -129,11 +128,9 @@ class StandardScalerModel (
           val values = vs.clone()
           val size = values.size
           if (withStd) {
-            // Having a local reference of `factor` to avoid overhead as the comment before.
-            val localFactor = factor
            var i = 0
            while (i < size) {
-              values(i) = (values(i) - localShift(i)) * localFactor(i)
+              values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0
              i += 1
            }
          } else {
@@ -147,15 +144,13 @@ class StandardScalerModel (
         case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
       }
     } else if (withStd) {
-      // Having a local reference of `factor` to avoid overhead as the comment before.
-      val localFactor = factor
       vector match {
         case DenseVector(vs) =>
           val values = vs.clone()
           val size = values.size
           var i = 0
           while(i < size) {
-            values(i) *= localFactor(i)
+            values(i) *= (if (std(i) != 0.0) 1.0 / std(i) else 0.0)
             i += 1
           }
           Vectors.dense(values)
@@ -166,7 +161,7 @@ class StandardScalerModel (
           val nnz = values.size
           var i = 0
           while (i < nnz) {
-            values(i) *= localFactor(indices(i))
+            values(i) *= (if (std(indices(i)) != 0.0) 1.0 / std(indices(i)) else 0.0)
             i += 1
           }
           Vectors.sparse(size, indices, values)
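
The rewritten transform drops the cached `factor` array and divides by `std(i)` directly, returning 0.0 for zero-std columns. As a standalone illustration of that per-element rule (a sketch, not code from the commit; plain Scala arrays stand in for MLlib vectors):

// Subtract the column mean, then scale by 1 / std, emitting 0.0 for columns
// whose standard deviation is zero — the same rule as the dense branch above.
def standardize(values: Array[Double], mean: Array[Double], std: Array[Double]): Array[Double] = {
  require(values.length == mean.length && mean.length == std.length)
  Array.tabulate(values.length) { i =>
    if (std(i) != 0.0) (values(i) - mean(i)) * (1.0 / std(i)) else 0.0
  }
}

// Example: the zero-std column (index 2) maps to 0.0 rather than dividing by zero.
// standardize(Array(3.0, 5.0, 7.0), Array(1.0, 2.0, 7.0), Array(2.0, 3.0, 0.0))
//   == Array(1.0, 1.0, 0.0)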

mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala

Lines changed: 15 additions & 15 deletions
@@ -60,7 +60,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
       (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
   }
 
-  test("Standardization with dense input when means and variances are provided") {
+  test("Standardization with dense input when means and stds are provided") {
 
     val dataRDD = sc.parallelize(denseData, 3)
 
@@ -72,9 +72,9 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = standardizer2.fit(dataRDD)
     val model3 = standardizer3.fit(dataRDD)
 
-    val equivalentModel1 = new StandardScalerModel(model1.variance, model1.mean)
-    val equivalentModel2 = new StandardScalerModel(model2.variance, model2.mean, true, false)
-    val equivalentModel3 = new StandardScalerModel(model3.variance, model3.mean, false, true)
+    val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean)
+    val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false)
+    val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true)
 
     val data1 = denseData.map(equivalentModel1.transform)
     val data2 = denseData.map(equivalentModel2.transform)
@@ -193,7 +193,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
   }
 
 
-  test("Standardization with sparse input when means and variances are provided") {
+  test("Standardization with sparse input when means and stds are provided") {
 
     val dataRDD = sc.parallelize(sparseData, 3)
 
@@ -205,9 +205,9 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = standardizer2.fit(dataRDD)
     val model3 = standardizer3.fit(dataRDD)
 
-    val equivalentModel1 = new StandardScalerModel(model1.variance, model1.mean)
-    val equivalentModel2 = new StandardScalerModel(model2.variance, model2.mean, true, false)
-    val equivalentModel3 = new StandardScalerModel(model3.variance, model3.mean, false, true)
+    val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean)
+    val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false)
+    val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true)
 
     val data2 = sparseData.map(equivalentModel2.transform)
 
@@ -288,7 +288,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5)
   }
 
-  test("Standardization with constant input when means and variances are provided") {
+  test("Standardization with constant input when means and stds are provided") {
 
     val dataRDD = sc.parallelize(constantData, 2)
 
@@ -300,9 +300,9 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
     val model2 = standardizer2.fit(dataRDD)
     val model3 = standardizer3.fit(dataRDD)
 
-    val equivalentModel1 = new StandardScalerModel(model1.variance, model1.mean)
-    val equivalentModel2 = new StandardScalerModel(model2.variance, model2.mean, true, false)
-    val equivalentModel3 = new StandardScalerModel(model3.variance, model3.mean, false, true)
+    val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean)
+    val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false)
+    val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true)
 
     val data1 = constantData.map(equivalentModel1.transform)
     val data2 = constantData.map(equivalentModel2.transform)
@@ -342,12 +342,12 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
 
   test("StandardScalerModel argument nulls are properly handled") {
 
-    withClue("model needs at least one of variance or mean vectors") {
+    withClue("model needs at least one of std or mean vectors") {
       intercept[IllegalArgumentException] {
         val model = new StandardScalerModel(null, null)
       }
     }
-    withClue("model needs variance to set withStd to true") {
+    withClue("model needs std to set withStd to true") {
       intercept[IllegalArgumentException] {
        val model = new StandardScalerModel(null, Vectors.dense(0.0))
        model.setWithStd(true)
@@ -359,7 +359,7 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext {
         model.setWithMean(true)
       }
     }
-    withClue("model needs variance and mean vectors to be equal size when both are provided") {
+    withClue("model needs std and mean vectors to be equal size when both are provided") {
       intercept[IllegalArgumentException] {
         val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0,1.0))
       }
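
The renamed `withClue` messages mirror the constructor checks added in `StandardScalerModel`. A short usage sketch of that validation (assuming a build with this commit applied; not part of the test file):

import org.apache.spark.mllib.feature.StandardScalerModel
import org.apache.spark.mllib.linalg.Vectors

// Valid: std only — withStd is inferred as true, withMean as false.
val stdOnly = new StandardScalerModel(Vectors.dense(1.0, 2.0))

// Valid: std and mean of equal size.
val both = new StandardScalerModel(Vectors.dense(1.0, 2.0), Vectors.dense(0.0, 0.5))

// Throws IllegalArgumentException: "at least one of std or mean vectors must be provided".
// new StandardScalerModel(null, null)

// Throws IllegalArgumentException: "mean and std vectors must have equal size if both are provided".
// new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0, 1.0))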
