@@ -90,7 +90,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
     val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
     val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
     val scalerModel = scaler.fit(input)
-    copyValues(new StandardScalerModel(uid, scalerModel).setParent(this))
+    copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this))
   }

   override def transformSchema(schema: StructType): StructType = {
@@ -116,21 +116,19 @@ object StandardScaler extends DefaultParamsReadable[StandardScaler] {
 /**
  * :: Experimental ::
  * Model fitted by [[StandardScaler]].
+ *
+ * @param std Standard deviation of the StandardScalerModel
+ * @param mean Mean of the StandardScalerModel
  */
 @Experimental
 class StandardScalerModel private[ml] (
     override val uid: String,
-    scaler: feature.StandardScalerModel)
+    val std: Vector,
+    val mean: Vector)
   extends Model[StandardScalerModel] with StandardScalerParams with MLWritable {

   import StandardScalerModel._

-  /** Standard deviation of the StandardScalerModel */
-  val std: Vector = scaler.std
-
-  /** Mean of the StandardScalerModel */
-  val mean: Vector = scaler.mean
-
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
@@ -139,6 +137,7 @@ class StandardScalerModel private[ml] (

   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
+    val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
     val scale = udf { scaler.transform _ }
     dataset.withColumn($(outputCol), scale(col($(inputCol))))
   }
@@ -154,7 +153,7 @@ class StandardScalerModel private[ml] (
   }

   override def copy(extra: ParamMap): StandardScalerModel = {
-    val copied = new StandardScalerModel(uid, scaler)
+    val copied = new StandardScalerModel(uid, std, mean)
     copyValues(copied, extra).setParent(parent)
   }

@@ -168,11 +167,11 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
   private[StandardScalerModel]
   class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter {

-    private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean)
+    private case class Data(std: Vector, mean: Vector)

     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveMetadata(instance, path, sc)
-      val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean)
+      val data = Data(instance.std, instance.mean)
       val dataPath = new Path(path, "data").toString
       sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
@@ -185,12 +184,10 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
     override def load(path: String): StandardScalerModel = {
      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
      val dataPath = new Path(path, "data").toString
-      val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) =
-        sqlContext.read.parquet(dataPath)
-          .select("std", "mean", "withStd", "withMean")
-          .head()
-      val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean)
-      val model = new StandardScalerModel(metadata.uid, oldModel)
+      val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath)
+        .select("std", "mean")
+        .head()
+      val model = new StandardScalerModel(metadata.uid, std, mean)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
0 commit comments