-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-5207] [MLLIB] StandardScalerModel mean and variance re-use #4140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
64408a4
997d2e0
9078fe0
fa64dfa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,15 +18,15 @@ | |
| package org.apache.spark.mllib.feature | ||
|
|
||
| import org.apache.spark.Logging | ||
| import org.apache.spark.annotation.Experimental | ||
| import org.apache.spark.annotation.{DeveloperApi, Experimental} | ||
| import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} | ||
| import org.apache.spark.mllib.rdd.RDDFunctions._ | ||
| import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer | ||
| import org.apache.spark.rdd.RDD | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Standardizes features by removing the mean and scaling to unit variance using column summary | ||
| * Standardizes features by removing the mean and scaling to unit std using column summary | ||
| * statistics on the samples in the training set. | ||
| * | ||
| * @param withMean False by default. Centers the data with mean before scaling. It will build a | ||
|
|
@@ -53,36 +53,55 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { | |
| val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( | ||
| (aggregator, data) => aggregator.add(data), | ||
| (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) | ||
| new StandardScalerModel(withMean, withStd, summary.mean, summary.variance) | ||
| new StandardScalerModel( | ||
| Vectors.dense(summary.variance.toArray.map(v => math.sqrt(v))), | ||
| summary.mean, | ||
| withStd, | ||
| withMean) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * :: Experimental :: | ||
| * Represents a StandardScaler model that can transform vectors. | ||
| * | ||
| * @param withMean whether to center the data before scaling | ||
| * @param withStd whether to scale the data to have unit standard deviation | ||
| * @param std column standard deviation values | ||
| * @param mean column mean values | ||
| * @param variance column variance values | ||
| * @param withStd whether to scale the data to have unit standard deviation | ||
| * @param withMean whether to center the data before scaling | ||
| */ | ||
| @Experimental | ||
| class StandardScalerModel private[mllib] ( | ||
| val withMean: Boolean, | ||
| val withStd: Boolean, | ||
| class StandardScalerModel ( | ||
| val std: Vector, | ||
| val mean: Vector, | ||
| val variance: Vector) extends VectorTransformer { | ||
|
|
||
| require(mean.size == variance.size) | ||
| var withStd: Boolean, | ||
| var withMean: Boolean) extends VectorTransformer { | ||
|
|
||
| private lazy val factor: Array[Double] = { | ||
| val f = Array.ofDim[Double](variance.size) | ||
| var i = 0 | ||
| while (i < f.size) { | ||
| f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0 | ||
| i += 1 | ||
| def this(std: Vector, mean: Vector) { | ||
| this(std, mean, withStd = std != null, withMean = mean != null) | ||
| require(this.withStd || this.withMean, | ||
| "at least one of std or mean vectors must be provided") | ||
| if (this.withStd && this.withMean) { | ||
| require(mean.size == std.size, | ||
| "mean and std vectors must have equal size if both are provided") | ||
| } | ||
| f | ||
| } | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moving
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have a question about this API. If the default
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds reasonable for me. Although the changes will be larger, this will be more handy and save extra space if withMean is not used.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @mengxr Just to make sure I'm clear, are you suggesting changing the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In my opinion, taking |
||
| def this(std: Vector) = this(std, null) | ||
|
|
||
| @DeveloperApi | ||
| def setWithMean(withMean: Boolean): this.type = { | ||
| require(!(withMean && this.mean == null),"cannot set withMean to true while mean is null") | ||
| this.withMean = withMean | ||
| this | ||
| } | ||
|
|
||
| @DeveloperApi | ||
| def setWithStd(withStd: Boolean): this.type = { | ||
| require(!(withStd && this.std == null), | ||
| "cannot set withStd to true while std is null") | ||
| this.withStd = withStd | ||
| this | ||
| } | ||
|
|
||
| // Since `shift` will be only used in `withMean` branch, we have it as | ||
|
|
@@ -94,8 +113,8 @@ class StandardScalerModel private[mllib] ( | |
| * Applies standardization transformation on a vector. | ||
| * | ||
| * @param vector Vector to be standardized. | ||
| * @return Standardized vector. If the variance of a column is zero, it will return default `0.0` | ||
| * for the column with zero variance. | ||
| * @return Standardized vector. If the std of a column is zero, it will return default `0.0` | ||
| * for the column with zero std. | ||
| */ | ||
| override def transform(vector: Vector): Vector = { | ||
| require(mean.size == vector.size) | ||
|
|
@@ -109,11 +128,9 @@ class StandardScalerModel private[mllib] ( | |
| val values = vs.clone() | ||
| val size = values.size | ||
| if (withStd) { | ||
| // Having a local reference of `factor` to avoid overhead as the comment before. | ||
| val localFactor = factor | ||
| var i = 0 | ||
| while (i < size) { | ||
| values(i) = (values(i) - localShift(i)) * localFactor(i) | ||
| values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0 | ||
| i += 1 | ||
| } | ||
| } else { | ||
|
|
@@ -127,15 +144,13 @@ class StandardScalerModel private[mllib] ( | |
| case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) | ||
| } | ||
| } else if (withStd) { | ||
| // Having a local reference of `factor` to avoid overhead as the comment before. | ||
| val localFactor = factor | ||
| vector match { | ||
| case DenseVector(vs) => | ||
| val values = vs.clone() | ||
| val size = values.size | ||
| var i = 0 | ||
| while(i < size) { | ||
| values(i) *= localFactor(i) | ||
| values(i) *= (if (std(i) != 0.0) 1.0 / std(i) else 0.0) | ||
| i += 1 | ||
| } | ||
| Vectors.dense(values) | ||
|
|
@@ -146,7 +161,7 @@ class StandardScalerModel private[mllib] ( | |
| val nnz = values.size | ||
| var i = 0 | ||
| while (i < nnz) { | ||
| values(i) *= localFactor(indices(i)) | ||
| values(i) *= (if (std(indices(i)) != 0.0) 1.0 / std(indices(i)) else 0.0) | ||
| i += 1 | ||
| } | ||
| Vectors.sparse(size, indices, values) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default argument is not friendly for Java though; why don't we add another constructor which takes only mean and variance?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, users will want to know if
withMeanorwithStdis used, do we really need to have them as private variables?