
Commit a8a402e

refactor ml.StandardScalerModel construction
1 parent 76ef338

File tree

2 files changed: 31 additions & 33 deletions


mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala

Lines changed: 14 additions & 17 deletions
@@ -90,7 +90,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel]
     val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
     val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
     val scalerModel = scaler.fit(input)
-    copyValues(new StandardScalerModel(uid, scalerModel).setParent(this))
+    copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this))
   }

   override def transformSchema(schema: StructType): StructType = {
@@ -116,21 +116,19 @@ object StandardScaler extends DefaultParamsReadable[StandardScaler] {
 /**
  * :: Experimental ::
  * Model fitted by [[StandardScaler]].
+ *
+ * @param std Standard deviation of the StandardScalerModel
+ * @param mean Mean of the StandardScalerModel
  */
 @Experimental
 class StandardScalerModel private[ml] (
     override val uid: String,
-    scaler: feature.StandardScalerModel)
+    val std: Vector,
+    val mean: Vector)
   extends Model[StandardScalerModel] with StandardScalerParams with MLWritable {

   import StandardScalerModel._

-  /** Standard deviation of the StandardScalerModel */
-  val std: Vector = scaler.std
-
-  /** Mean of the StandardScalerModel */
-  val mean: Vector = scaler.mean
-
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)

@@ -139,6 +137,7 @@ class StandardScalerModel private[ml] (

   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
+    val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
     val scale = udf { scaler.transform _ }
     dataset.withColumn($(outputCol), scale(col($(inputCol))))
   }
@@ -154,7 +153,7 @@ class StandardScalerModel private[ml] (
   }

   override def copy(extra: ParamMap): StandardScalerModel = {
-    val copied = new StandardScalerModel(uid, scaler)
+    val copied = new StandardScalerModel(uid, std, mean)
     copyValues(copied, extra).setParent(parent)
   }

@@ -168,11 +167,11 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
   private[StandardScalerModel]
   class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter {

-    private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean)
+    private case class Data(std: Vector, mean: Vector)

     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveMetadata(instance, path, sc)
-      val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean)
+      val data = Data(instance.std, instance.mean)
       val dataPath = new Path(path, "data").toString
       sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
@@ -185,12 +184,10 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
     override def load(path: String): StandardScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) =
-        sqlContext.read.parquet(dataPath)
-          .select("std", "mean", "withStd", "withMean")
-          .head()
-      val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean)
-      val model = new StandardScalerModel(metadata.uid, oldModel)
+      val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath)
+        .select("std", "mean")
+        .head()
+      val model = new StandardScalerModel(metadata.uid, std, mean)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
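
Net effect of this file's change: StandardScalerModel stores std and mean directly as constructor vals and builds the mllib-side feature.StandardScalerModel only transiently inside transform(). Since withStd and withMean are Params, they are persisted in the metadata written by DefaultParamsWriter.saveMetadata, so the data file needs only the two vectors. A minimal round-trip sketch under those assumptions (the save path and the DataFrame df are hypothetical; the constructor is private[ml], so user code obtains a model via fit()):

  import org.apache.spark.ml.feature.{StandardScaler, StandardScalerModel}

  // df is a hypothetical DataFrame with a "features" vector column
  val model: StandardScalerModel = new StandardScaler()
    .setInputCol("features")
    .setOutputCol("scaled")
    .fit(df)

  // Only std and mean go into the parquet data file; withStd/withMean
  // are restored from the params metadata on load.
  model.write.overwrite().save("/tmp/scaler")            // hypothetical path
  val loaded = StandardScalerModel.load("/tmp/scaler")
  assert(loaded.std == model.std && loaded.mean == model.mean)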

mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala

Lines changed: 17 additions & 16 deletions
@@ -70,8 +70,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext

   test("params") {
     ParamsSuite.checkParams(new StandardScaler)
-    val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0))
-    ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel))
+    ParamsSuite.checkParams(new StandardScalerModel("empty",
+      Vectors.dense(1.0), Vectors.dense(2.0)))
   }

   test("Standardization with default parameter") {
@@ -116,19 +116,20 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
     assertResult(standardScaler3.transform(df3))
   }

-  test("read/write") {
-    def checkModelData(model1: StandardScalerModel, model2: StandardScalerModel): Unit = {
-      assert(model1.mean === model2.mean)
-      assert(model1.std === model2.std)
-    }
-    val allParams: Map[String, Any] = Map(
-      "inputCol" -> "features",
-      "outputCol" -> "standardized_features",
-      "withMean" -> true,
-      "withStd" -> true
-    )
-    val df = sqlContext.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected")
-    val standardScaler = new StandardScaler()
-    testEstimatorAndModelReadWrite(standardScaler, df, allParams, checkModelData)
+  test("StandardScaler read/write") {
+    val t = new StandardScaler()
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setWithMean(true)
+      .setWithStd(true)
+    testDefaultReadWrite(t)
+  }
+
+  test("StandardScalerModel read/write") {
+    val instance = new StandardScalerModel("myStandardScalerModel",
+      Vectors.dense(0.5, 1.2), Vectors.dense(1.0, 10.0))
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.std === instance.std)
+    assert(newInstance.mean === instance.mean)
   }
 }
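
The test split mirrors the refactor: estimator params round-trip through testDefaultReadWrite, while the model test checks that the std and mean vectors survive reload. The public surface is otherwise unchanged; a short usage sketch, where df and the column names are hypothetical:

  val scaler = new StandardScaler()
    .setInputCol("features")
    .setOutputCol("scaled")
    .setWithMean(true)
    .setWithStd(true)
  val model = scaler.fit(df)   // internally: new StandardScalerModel(uid, std, mean)
  println(model.std)           // per-feature standard deviations from fitting
  println(model.mean)          // per-feature means from fitting
  val scaled = model.transform(df)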
