Commit bdb0680

ogeagla authored and mengxr committed
[SPARK-5207] [MLLIB] StandardScalerModel mean and variance re-use
This seems complete; the duplication of tests for provided means/variances might be overkill, would appreciate some feedback.

Author: Octavian Geagla <[email protected]>

Closes apache#4140 from ogeagla/SPARK-5207 and squashes the following commits:

fa64dfa [Octavian Geagla] [SPARK-5207] [MLLIB] [WIP] change StandardScalerModel to take stddev instead of variance
9078fe0 [Octavian Geagla] [SPARK-5207] [MLLIB] [WIP] Incorporate code review feedback: change arg ordering, add dev api annotations, do better null checking, add another test and some doc for this.
997d2e0 [Octavian Geagla] [SPARK-5207] [MLLIB] [WIP] make withMean and withStd public, add constructor which uses defaults, un-refactor test class
64408a4 [Octavian Geagla] [SPARK-5207] [MLLIB] [WIP] change StandardScalerModel constructor to not be private to mllib, added tests for newly-exposed functionality
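For readers skimming the diff, a minimal sketch of the API surface this commit exposes (constructor and setter names are taken from the diff below; the std/mean values are illustrative placeholders, not real statistics):

import org.apache.spark.mllib.feature.StandardScalerModel
import org.apache.spark.mllib.linalg.Vectors

// Illustrative column statistics; in practice these come from a fitted StandardScaler.
val std = Vectors.dense(0.5, 2.0, 1.5)
val mean = Vectors.dense(1.0, -3.0, 0.0)

// Full constructor: std, mean, withStd, withMean.
val m1 = new StandardScalerModel(std, mean, true, false)

// Convenience constructors added by this commit: flags default from whether std/mean are provided.
val m2 = new StandardScalerModel(std, mean) // withStd = true, withMean = true
val m3 = new StandardScalerModel(std)       // mean is null, so withMean = false

// Developer-API setters return this.type, so they chain.
m2.setWithMean(false).setWithStd(true)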
1 parent 80bd715 commit bdb0680

File tree: 3 files changed (+267 −73 lines)


docs/mllib-feature-extraction.md

Lines changed: 8 additions & 3 deletions
@@ -240,11 +240,11 @@ following parameters in the constructor:
 
 * `withMean` False by default. Centers the data with mean before scaling. It will build a dense
 output, so this does not work on sparse input and will raise an exception.
-* `withStd` True by default. Scales the data to unit variance.
+* `withStd` True by default. Scales the data to unit standard deviation.
 
 We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScaler) method in
 `StandardScaler` which can take an input of `RDD[Vector]`, learn the summary statistics, and then
-return a model which can transform the input dataset into unit variance and/or zero mean features
+return a model which can transform the input dataset into unit standard deviation and/or zero mean features
 depending how we configure the `StandardScaler`.
 
 This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer)
@@ -257,7 +257,7 @@ for that feature.
 ### Example
 
 The example below demonstrates how to load a dataset in libsvm format, and standardize the features
-so that the new features have unit variance and/or zero mean.
+so that the new features have unit standard deviation and/or zero mean.
 
 <div class="codetabs">
 <div data-lang="scala">
@@ -271,6 +271,8 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
 
 val scaler1 = new StandardScaler().fit(data.map(x => x.features))
 val scaler2 = new StandardScaler(withMean = true, withStd = true).fit(data.map(x => x.features))
+// scaler3 is an identical model to scaler2, and will produce identical transformations
+val scaler3 = new StandardScalerModel(scaler2.std, scaler2.mean)
 
 // data1 will be unit variance.
 val data1 = data.map(x => (x.label, scaler1.transform(x.features)))
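As a quick sanity check (a hedged sketch, not part of the documented example): because scaler3 is constructed from scaler2's fitted statistics, the two models should produce the same output for any input vector:

// Assumes `data`, `scaler2`, and `scaler3` from the example above are in scope.
val v = data.first().features
val same = scaler2.transform(v).toArray.sameElements(scaler3.transform(v).toArray)
assert(same, "scaler2 and scaler3 should transform identically")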
@@ -294,6 +296,9 @@ features = data.map(lambda x: x.features)
 
 scaler1 = StandardScaler().fit(features)
 scaler2 = StandardScaler(withMean=True, withStd=True).fit(features)
+# scaler3 is an identical model to scaler2, and will produce identical transformations
+scaler3 = StandardScalerModel(scaler2.std, scaler2.mean)
+
 
 # data1 will be unit variance.
 data1 = label.zip(scaler1.transform(features))

mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala

Lines changed: 43 additions & 28 deletions
@@ -18,14 +18,14 @@
 package org.apache.spark.mllib.feature
 
 import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
 
 /**
  * :: Experimental ::
- * Standardizes features by removing the mean and scaling to unit variance using column summary
+ * Standardizes features by removing the mean and scaling to unit std using column summary
  * statistics on the samples in the training set.
  *
  * @param withMean False by default. Centers the data with mean before scaling. It will build a
@@ -52,36 +52,55 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging {
     val summary = data.treeAggregate(new MultivariateOnlineSummarizer)(
       (aggregator, data) => aggregator.add(data),
       (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
-    new StandardScalerModel(withMean, withStd, summary.mean, summary.variance)
+    new StandardScalerModel(
+      Vectors.dense(summary.variance.toArray.map(v => math.sqrt(v))),
+      summary.mean,
+      withStd,
+      withMean)
   }
 }
 
 /**
  * :: Experimental ::
 * Represents a StandardScaler model that can transform vectors.
 *
- * @param withMean whether to center the data before scaling
- * @param withStd whether to scale the data to have unit standard deviation
+ * @param std column standard deviation values
 * @param mean column mean values
- * @param variance column variance values
+ * @param withStd whether to scale the data to have unit standard deviation
+ * @param withMean whether to center the data before scaling
 */
 @Experimental
-class StandardScalerModel private[mllib] (
-    val withMean: Boolean,
-    val withStd: Boolean,
+class StandardScalerModel (
+    val std: Vector,
     val mean: Vector,
-    val variance: Vector) extends VectorTransformer {
-
-  require(mean.size == variance.size)
+    var withStd: Boolean,
+    var withMean: Boolean) extends VectorTransformer {
 
-  private lazy val factor: Array[Double] = {
-    val f = Array.ofDim[Double](variance.size)
-    var i = 0
-    while (i < f.size) {
-      f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0
-      i += 1
+  def this(std: Vector, mean: Vector) {
+    this(std, mean, withStd = std != null, withMean = mean != null)
+    require(this.withStd || this.withMean,
+      "at least one of std or mean vectors must be provided")
+    if (this.withStd && this.withMean) {
+      require(mean.size == std.size,
+        "mean and std vectors must have equal size if both are provided")
     }
-    f
+  }
+
+  def this(std: Vector) = this(std, null)
+
+  @DeveloperApi
+  def setWithMean(withMean: Boolean): this.type = {
+    require(!(withMean && this.mean == null), "cannot set withMean to true while mean is null")
+    this.withMean = withMean
+    this
+  }
+
+  @DeveloperApi
+  def setWithStd(withStd: Boolean): this.type = {
+    require(!(withStd && this.std == null),
+      "cannot set withStd to true while std is null")
+    this.withStd = withStd
+    this
   }
 
   // Since `shift` will be only used in `withMean` branch, we have it as
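Two details of this hunk are worth unpacking: fit() now takes the square root of each column variance once, so transform() never recomputes it, and the two-argument constructor derives the flags from the nullness of its arguments. A small sketch under those assumptions (the values are illustrative only):

import org.apache.spark.mllib.feature.StandardScalerModel
import org.apache.spark.mllib.linalg.Vectors

// The same element-wise conversion fit() applies to summary.variance.
val variance = Vectors.dense(4.0, 9.0, 0.0)
val std = Vectors.dense(variance.toArray.map(v => math.sqrt(v)))
// std is [2.0, 3.0, 0.0]; a zero-variance column keeps std 0.0.

// Both vectors provided, so the two-arg constructor sets
// withStd = true and withMean = true.
val model = new StandardScalerModel(std, Vectors.dense(1.0, 2.0, 3.0))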
@@ -93,8 +112,8 @@ class StandardScalerModel private[mllib] (
   * Applies standardization transformation on a vector.
   *
   * @param vector Vector to be standardized.
-   * @return Standardized vector. If the variance of a column is zero, it will return default `0.0`
-   *         for the column with zero variance.
+   * @return Standardized vector. If the std of a column is zero, it will return default `0.0`
+   *         for the column with zero std.
   */
  override def transform(vector: Vector): Vector = {
    require(mean.size == vector.size)
@@ -108,11 +127,9 @@ class StandardScalerModel private[mllib] (
         val values = vs.clone()
         val size = values.size
         if (withStd) {
-          // Having a local reference of `factor` to avoid overhead as the comment before.
-          val localFactor = factor
           var i = 0
           while (i < size) {
-            values(i) = (values(i) - localShift(i)) * localFactor(i)
+            values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0
             i += 1
           }
         } else {
@@ -126,15 +143,13 @@ class StandardScalerModel private[mllib] (
         case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
       }
     } else if (withStd) {
-      // Having a local reference of `factor` to avoid overhead as the comment before.
-      val localFactor = factor
       vector match {
         case DenseVector(vs) =>
           val values = vs.clone()
           val size = values.size
           var i = 0
           while (i < size) {
-            values(i) *= localFactor(i)
+            values(i) *= (if (std(i) != 0.0) 1.0 / std(i) else 0.0)
             i += 1
           }
           Vectors.dense(values)
@@ -145,7 +160,7 @@ class StandardScalerModel private[mllib] (
           val nnz = values.size
           var i = 0
           while (i < nnz) {
-            values(i) *= localFactor(indices(i))
+            values(i) *= (if (std(indices(i)) != 0.0) 1.0 / std(indices(i)) else 0.0)
             i += 1
           }
           Vectors.sparse(size, indices, values)
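The inline replacement for the old factor array reads, per column: scale by 1.0 / std(i), and emit 0.0 when std(i) is zero rather than dividing by zero. A standalone sketch of that rule (the helper name is hypothetical, not part of the patch):

// Hypothetical helper mirroring the per-column factor logic in transform().
def scaleFactor(std: Double): Double = if (std != 0.0) 1.0 / std else 0.0

assert(scaleFactor(2.0) == 0.5)
assert(scaleFactor(0.0) == 0.0) // a zero-std column collapses to 0.0 after scaling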
