
Commit 9ace2e5

yanboliang authored and mengxr committed
[SPARK-11852][ML] StandardScaler minor refactor
```withStd``` and ```withMean``` should be params of ```StandardScaler``` and ```StandardScalerModel```.

Author: Yanbo Liang <[email protected]>

Closes #9839 from yanboliang/standardScaler-refactor.
1 parent a66142d commit 9ace2e5

File tree

2 files changed: +32 −39 lines


mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala

Lines changed: 28 additions & 32 deletions
@@ -36,20 +36,30 @@ import org.apache.spark.sql.types.{StructField, StructType}
 private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {
 
   /**
-   * Centers the data with mean before scaling.
+   * Whether to center the data with mean before scaling.
    * It will build a dense output, so this does not work on sparse input
    * and will raise an exception.
    * Default: false
    * @group param
    */
-  val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
+  val withMean: BooleanParam = new BooleanParam(this, "withMean",
+    "Whether to center data with mean")
+
+  /** @group getParam */
+  def getWithMean: Boolean = $(withMean)
 
   /**
-   * Scales the data to unit standard deviation.
+   * Whether to scale the data to unit standard deviation.
    * Default: true
    * @group param
    */
-  val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation")
+  val withStd: BooleanParam = new BooleanParam(this, "withStd",
+    "Whether to scale the data to unit standard deviation")
+
+  /** @group getParam */
+  def getWithStd: Boolean = $(withStd)
+
+  setDefault(withMean -> false, withStd -> true)
 }
 
 /**
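
Note: with the flags promoted to shared Params and the defaults moved into the trait, the estimator and the model expose identical getters and defaults from one place. A minimal usage sketch of the resulting API (the column names here are hypothetical):

```scala
import org.apache.spark.ml.feature.StandardScaler

// Defaults come from setDefault in StandardScalerParams.
val scaler = new StandardScaler()
  .setInputCol("features")        // hypothetical column name
  .setOutputCol("scaledFeatures") // hypothetical column name

scaler.getWithMean  // false (default; true densifies output and fails on sparse input)
scaler.getWithStd   // true  (default)
```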
@@ -63,8 +73,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
 
   def this() = this(Identifiable.randomUID("stdScal"))
 
-  setDefault(withMean -> false, withStd -> true)
-
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 
@@ -82,7 +90,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
     val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
     val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
     val scalerModel = scaler.fit(input)
-    copyValues(new StandardScalerModel(uid, scalerModel).setParent(this))
+    copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this))
   }
 
   override def transformSchema(schema: StructType): StructType = {
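
Note: fit still delegates to mllib's feature.StandardScaler, but the wrapped model object is no longer retained; only its std and mean vectors are copied into the new spark.ml model. A sketch of the observable result, assuming df is a DataFrame with a Vector column named "features":

```scala
// `df` is an assumed DataFrame with a Vector column "features".
val model = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaled")
  .fit(df)

// std and mean are now public constructor vals rather than
// lookups on a wrapped mllib model:
println(model.std)   // per-feature standard deviations
println(model.mean)  // per-feature means
```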
@@ -108,29 +116,19 @@ object StandardScaler extends DefaultParamsReadable[StandardScaler] {
 /**
  * :: Experimental ::
  * Model fitted by [[StandardScaler]].
+ *
+ * @param std Standard deviation of the StandardScalerModel
+ * @param mean Mean of the StandardScalerModel
  */
 @Experimental
 class StandardScalerModel private[ml] (
     override val uid: String,
-    scaler: feature.StandardScalerModel)
+    val std: Vector,
+    val mean: Vector)
   extends Model[StandardScalerModel] with StandardScalerParams with MLWritable {
 
   import StandardScalerModel._
 
-  /** Standard deviation of the StandardScalerModel */
-  val std: Vector = scaler.std
-
-  /** Mean of the StandardScalerModel */
-  val mean: Vector = scaler.mean
-
-  /** Whether to scale to unit standard deviation. */
-  @Since("1.6.0")
-  def getWithStd: Boolean = scaler.withStd
-
-  /** Whether to center data with mean. */
-  @Since("1.6.0")
-  def getWithMean: Boolean = scaler.withMean
-
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 
@@ -139,6 +137,7 @@ class StandardScalerModel private[ml] (
 
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
+    val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
     val scale = udf { scaler.transform _ }
     dataset.withColumn($(outputCol), scale(col($(inputCol))))
   }
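
Note: transform now rebuilds the mllib StandardScalerModel on the fly from the stored vectors and the current param values. With both flags enabled, the underlying transformation per feature j is z_j = (x_j − mean_j) / std_j, which a standalone arithmetic sketch (illustrative values only) makes concrete:

```scala
// Standalone sketch of the standardization applied inside the udf,
// with both withMean and withStd enabled (illustrative values).
val x    = Array(3.0, 6.0)
val mean = Array(1.0, 2.0)
val std  = Array(2.0, 4.0)
val z = x.indices.map(j => (x(j) - mean(j)) / std(j))
// z == Vector(1.0, 1.0)
```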
@@ -154,7 +153,7 @@ class StandardScalerModel private[ml] (
   }
 
   override def copy(extra: ParamMap): StandardScalerModel = {
-    val copied = new StandardScalerModel(uid, scaler)
+    val copied = new StandardScalerModel(uid, std, mean)
     copyValues(copied, extra).setParent(parent)
   }
 
@@ -168,11 +167,11 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
   private[StandardScalerModel]
   class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter {
 
-    private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean)
+    private case class Data(std: Vector, mean: Vector)
 
     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveMetadata(instance, path, sc)
-      val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean)
+      val data = Data(instance.std, instance.mean)
       val dataPath = new Path(path, "data").toString
       sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
@@ -185,13 +184,10 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
     override def load(path: String): StandardScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) =
-        sqlContext.read.parquet(dataPath)
-          .select("std", "mean", "withStd", "withMean")
-          .head()
-      // This is very likely to change in the future because withStd and withMean should be params.
-      val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean)
-      val model = new StandardScalerModel(metadata.uid, oldModel)
+      val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath)
+        .select("std", "mean")
+        .head()
+      val model = new StandardScalerModel(metadata.uid, std, mean)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
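
Note: the Parquet data file now carries only std and mean; withStd and withMean travel in the params metadata that DefaultParamsWriter and DefaultParamsReader already handle, so getAndSetParams restores them on load. A hedged round-trip sketch (the save path is hypothetical):

```scala
// Round-trip sketch; "/tmp/scaler-model" is a hypothetical path.
model.write.overwrite().save("/tmp/scaler-model")
val loaded = org.apache.spark.ml.feature.StandardScalerModel.load("/tmp/scaler-model")

assert(loaded.std == model.std && loaded.mean == model.mean)  // from the data file
assert(loaded.getWithStd == model.getWithStd)                 // from params metadata
assert(loaded.getWithMean == model.getWithMean)
```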

mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala

Lines changed: 4 additions & 7 deletions
@@ -70,8 +70,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
 
   test("params") {
     ParamsSuite.checkParams(new StandardScaler)
-    val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0))
-    ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel))
+    ParamsSuite.checkParams(new StandardScalerModel("empty",
+      Vectors.dense(1.0), Vectors.dense(2.0)))
   }
 
   test("Standardization with default parameter") {
@@ -126,13 +126,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
   }
 
   test("StandardScalerModel read/write") {
-    val oldModel = new feature.StandardScalerModel(
-      Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true)
-    val instance = new StandardScalerModel("myStandardScalerModel", oldModel)
+    val instance = new StandardScalerModel("myStandardScalerModel",
+      Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0))
     val newInstance = testDefaultReadWrite(instance)
     assert(newInstance.std === instance.std)
     assert(newInstance.mean === instance.mean)
-    assert(newInstance.getWithStd === instance.getWithStd)
-    assert(newInstance.getWithMean === instance.getWithMean)
   }
 }
