@@ -36,20 +36,30 @@ import org.apache.spark.sql.types.{StructField, StructType}
 private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {
 
   /**
-   * Centers the data with mean before scaling.
+   * Whether to center the data with mean before scaling.
    * It will build a dense output, so this does not work on sparse input
    * and will raise an exception.
    * Default: false
    * @group param
    */
-  val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
+  val withMean: BooleanParam = new BooleanParam(this, "withMean",
+    "Whether to center data with mean")
+
+  /** @group getParam */
+  def getWithMean: Boolean = $(withMean)
 
   /**
-   * Scales the data to unit standard deviation.
+   * Whether to scale the data to unit standard deviation.
    * Default: true
    * @group param
    */
-  val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation")
+  val withStd: BooleanParam = new BooleanParam(this, "withStd",
+    "Whether to scale the data to unit standard deviation")
+
+  /** @group getParam */
+  def getWithStd: Boolean = $(withStd)
+
+  setDefault(withMean -> false, withStd -> true)
 }
 
 /**
5464
5565/**
@@ -63,8 +73,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
 
   def this() = this(Identifiable.randomUID("stdScal"))
 
-  setDefault(withMean -> false, withStd -> true)
-
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 
@@ -82,7 +90,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
     val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
     val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
     val scalerModel = scaler.fit(input)
-    copyValues(new StandardScalerModel(uid, scalerModel).setParent(this))
+    copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this))
   }
 
   override def transformSchema(schema: StructType): StructType = {
@@ -108,29 +116,19 @@ object StandardScaler extends DefaultParamsReadable[StandardScaler] {
 /**
  * :: Experimental ::
  * Model fitted by [[StandardScaler]].
+ *
+ * @param std Standard deviation of the StandardScalerModel
+ * @param mean Mean of the StandardScalerModel
  */
 @Experimental
 class StandardScalerModel private[ml] (
     override val uid: String,
-    scaler: feature.StandardScalerModel)
+    val std: Vector,
+    val mean: Vector)
   extends Model[StandardScalerModel] with StandardScalerParams with MLWritable {
 
   import StandardScalerModel._
 
-  /** Standard deviation of the StandardScalerModel */
-  val std: Vector = scaler.std
-
-  /** Mean of the StandardScalerModel */
-  val mean: Vector = scaler.mean
-
-  /** Whether to scale to unit standard deviation. */
-  @Since("1.6.0")
-  def getWithStd: Boolean = scaler.withStd
-
-  /** Whether to center data with mean. */
-  @Since("1.6.0")
-  def getWithMean: Boolean = scaler.withMean
-
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 
@@ -139,6 +137,7 @@ class StandardScalerModel private[ml] (
 
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
+    val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
     val scale = udf { scaler.transform _ }
     dataset.withColumn($(outputCol), scale(col($(inputCol))))
   }
@@ -154,7 +153,7 @@ class StandardScalerModel private[ml] (
   }
 
   override def copy(extra: ParamMap): StandardScalerModel = {
-    val copied = new StandardScalerModel(uid, scaler)
+    val copied = new StandardScalerModel(uid, std, mean)
     copyValues(copied, extra).setParent(parent)
   }
 
@@ -168,11 +167,11 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
   private[StandardScalerModel]
   class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter {
 
-    private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean)
+    private case class Data(std: Vector, mean: Vector)
 
     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveMetadata(instance, path, sc)
-      val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean)
+      val data = Data(instance.std, instance.mean)
       val dataPath = new Path(path, "data").toString
       sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
@@ -185,13 +184,10 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
     override def load(path: String): StandardScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) =
-        sqlContext.read.parquet(dataPath)
-          .select("std", "mean", "withStd", "withMean")
-          .head()
-      // This is very likely to change in the future because withStd and withMean should be params.
-      val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean)
-      val model = new StandardScalerModel(metadata.uid, oldModel)
+      val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath)
+        .select("std", "mean")
+        .head()
+      val model = new StandardScalerModel(metadata.uid, std, mean)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
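For context, a minimal usage sketch of the scaler after this refactor. It assumes the 1.6-era spark.ml API with the usual `setWithStd` / `setWithMean` setters and an existing `sqlContext`; the column names and data below are illustrative only.

```scala
import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.mllib.linalg.Vectors

// Illustrative input; any DataFrame with a Vector column works.
val df = sqlContext.createDataFrame(Seq(
  (0, Vectors.dense(1.0, 0.5, -1.0)),
  (1, Vectors.dense(2.0, 1.0, 1.0)),
  (2, Vectors.dense(4.0, 10.0, 2.0))
)).toDF("id", "features")

val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setWithStd(true)   // backed by the withStd Param (default true)
  .setWithMean(false) // backed by the withMean Param (default false)

val model = scaler.fit(df)

// After this change, std and mean are plain constructor vals on the model,
// and getWithStd / getWithMean read the shared Params rather than a
// wrapped mllib StandardScalerModel.
println(model.std)
println(model.mean)
model.transform(df).show()
```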