1818package org .apache .spark .mllib .feature
1919
2020import org .apache .spark .Logging
21- import org .apache .spark .annotation .Experimental
21+ import org .apache .spark .annotation .{ DeveloperApi , Experimental }
2222import org .apache .spark .mllib .linalg .{DenseVector , SparseVector , Vector , Vectors }
2323import org .apache .spark .mllib .stat .MultivariateOnlineSummarizer
2424import org .apache .spark .rdd .RDD
2525
/**
 * :: Experimental ::
 * Standardizes features by removing the mean and scaling to unit std using column summary
 * statistics on the samples in the training set.
 *
 * @param withMean False by default. Centers the data with mean before scaling. It will build a
@@ -52,36 +52,55 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging {
5252 val summary = data.treeAggregate(new MultivariateOnlineSummarizer )(
5353 (aggregator, data) => aggregator.add(data),
5454 (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
55- new StandardScalerModel (withMean, withStd, summary.mean, summary.variance)
55+ new StandardScalerModel (
56+ Vectors .dense(summary.variance.toArray.map(v => math.sqrt(v))),
57+ summary.mean,
58+ withStd,
59+ withMean)
5660 }
5761}
5862
/**
 * :: Experimental ::
 * Represents a StandardScaler model that can transform vectors.
 *
 * @param std column standard deviation values
 * @param mean column mean values
 * @param withStd whether to scale the data to have unit standard deviation
 * @param withMean whether to center the data before scaling
 */
6872@ Experimental
69- class StandardScalerModel private [mllib] (
70- val withMean : Boolean ,
71- val withStd : Boolean ,
73+ class StandardScalerModel (
74+ val std : Vector ,
7275 val mean : Vector ,
73- val variance : Vector ) extends VectorTransformer {
74-
75- require(mean.size == variance.size)
76+ var withStd : Boolean ,
77+ var withMean : Boolean ) extends VectorTransformer {
7678
77- private lazy val factor : Array [Double ] = {
78- val f = Array .ofDim[Double ](variance.size)
79- var i = 0
80- while (i < f.size) {
81- f(i) = if (variance(i) != 0.0 ) 1.0 / math.sqrt(variance(i)) else 0.0
82- i += 1
/**
 * Convenience constructor for a model built from precomputed column statistics.
 *
 * Passing `null` for either vector disables the corresponding transformation:
 * `withStd` / `withMean` are derived from the nullness of `std` / `mean`.
 * At least one of the two vectors must be non-null, and when both are given
 * they must agree in size.
 *
 * @param std column standard deviation values, or null to skip scaling
 * @param mean column mean values, or null to skip centering
 */
def this(std: Vector, mean: Vector) {
  this(std, mean, withStd = std != null, withMean = mean != null)
  require(this.withStd || this.withMean,
    "at least one of std or mean vectors must be provided")
  if (this.withStd && this.withMean) {
    require(mean.size == std.size,
      "mean and std vectors must have equal size if both are provided")
  }
}
88+
89+ def this (std : Vector ) = this (std, null )
90+
/**
 * :: DeveloperApi ::
 * Toggles mean-centering on this model.
 *
 * @param withMean whether the transformation should subtract the column means
 * @return this model, to allow chained configuration calls
 */
@DeveloperApi
def setWithMean(withMean: Boolean): this.type = {
  // Centering cannot be enabled when no mean vector was supplied.
  require(!(withMean && this.mean == null), "cannot set withMean to true while mean is null")
  this.withMean = withMean
  this
}
97+
/**
 * :: DeveloperApi ::
 * Toggles unit-standard-deviation scaling on this model.
 *
 * @param withStd whether the transformation should divide by the column stds
 * @return this model, to allow chained configuration calls
 */
@DeveloperApi
def setWithStd(withStd: Boolean): this.type = {
  // Scaling cannot be enabled when no std vector was supplied.
  require(!(withStd && this.std == null),
    "cannot set withStd to true while std is null")
  this.withStd = withStd
  this
}
86105
87106 // Since `shift` will be only used in `withMean` branch, we have it as
@@ -93,8 +112,8 @@ class StandardScalerModel private[mllib] (
 * Applies standardization transformation on a vector.
 *
 * @param vector Vector to be standardized.
 * @return Standardized vector. If the std of a column is zero, it will return default `0.0`
 *         for the column with zero std.
 */
99118 override def transform (vector : Vector ): Vector = {
100119 require(mean.size == vector.size)
@@ -108,11 +127,9 @@ class StandardScalerModel private[mllib] (
108127 val values = vs.clone()
109128 val size = values.size
110129 if (withStd) {
111- // Having a local reference of `factor` to avoid overhead as the comment before.
112- val localFactor = factor
113130 var i = 0
114131 while (i < size) {
115- values(i) = ( values(i) - localShift(i)) * localFactor(i)
132+ values(i) = if (std(i) != 0.0 ) ( values(i) - localShift(i)) * ( 1.0 / std(i)) else 0.0
116133 i += 1
117134 }
118135 } else {
@@ -126,15 +143,13 @@ class StandardScalerModel private[mllib] (
126143 case v => throw new IllegalArgumentException (" Do not support vector type " + v.getClass)
127144 }
128145 } else if (withStd) {
129- // Having a local reference of `factor` to avoid overhead as the comment before.
130- val localFactor = factor
131146 vector match {
132147 case DenseVector (vs) =>
133148 val values = vs.clone()
134149 val size = values.size
135150 var i = 0
136151 while (i < size) {
137- values(i) *= localFactor(i )
152+ values(i) *= ( if (std(i) != 0.0 ) 1.0 / std(i) else 0.0 )
138153 i += 1
139154 }
140155 Vectors .dense(values)
@@ -145,7 +160,7 @@ class StandardScalerModel private[mllib] (
145160 val nnz = values.size
146161 var i = 0
147162 while (i < nnz) {
148- values(i) *= localFactor( indices(i))
163+ values(i) *= ( if (std( indices(i)) != 0.0 ) 1.0 / std(indices(i)) else 0.0 )
149164 i += 1
150165 }
151166 Vectors .sparse(size, indices, values)
0 commit comments