From dcab97aa1dbf4573c3fcb379c3152bca1b0837e6 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 8 Apr 2015 13:04:51 -0700 Subject: [PATCH 01/21] add codegen for shared params --- .../ml/param/shared/SharedParamCodeGen.scala | 159 ++++++++++++++ .../spark/ml/param/shared/sharedParams.scala | 204 ++++++++++++++++++ .../apache/spark/ml/param/sharedParams.scala | 151 ------------- 3 files changed, 363 insertions(+), 151 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala new file mode 100644 index 000000000000..5403012c03a8 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import java.io.PrintWriter + +import scala.reflect.ClassTag + +/** + * Code generator for shared params (sharedParams.scala). Run under the Spark folder with + * {{{ + * build/sbt "mllib/runMain org.apache.spark.ml.param.shared.SharedParamCodeGen" + * }}} + */ +private[shared] object SharedParamCodeGen { + + def main(args: Array[String]): Unit = { + val params = Seq( + ParamDesc[Double]("regParam", "regularization parameter"), + ParamDesc[Int]("maxIter", "max number of iterations"), + ParamDesc[String]("featuresCol", "features column name"), + ParamDesc[String]("labelCol", "label column name"), + ParamDesc[String]("predictionCol", "prediction column name"), + ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name"), + ParamDesc[String]( + "probabilityCol", "column name for predicted class conditional probabilities"), + ParamDesc[Double]("threshold", "threshold in prediction"), + ParamDesc[String]("inputCol", "input column name"), + ParamDesc[String]("outputCol", "output column name"), + ParamDesc[Int]("checkpointInterval", "checkpoint interval")) + + val code = genSharedParams(params) + val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala" + val writer = new PrintWriter(file) + writer.write(code) + writer.close() + } + + /** Description of a param. 
*/ + private case class ParamDesc[T: ClassTag](name: String, doc: String) { + require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") + require(doc.nonEmpty) // TODO: more rigorous on doc + + def paramTypeName: String = { + val c = implicitly[ClassTag[T]].runtimeClass + c match { + case _ if c == classOf[Int] => "IntParam" + case _ if c == classOf[Long] => "LongParam" + case _ if c == classOf[Float] => "FloatParam" + case _ if c == classOf[Double] => "DoubleParam" + case _ if c == classOf[Boolean] => "BooleanParam" + case _ => s"Param[${getTypeString(c)}]" + } + } + + def valueTypeName: String = { + val c = implicitly[ClassTag[T]].runtimeClass + getTypeString(c) + } + + private def getTypeString(c: Class[_]): String = { + c match { + case _ if c == classOf[Int] => "Int" + case _ if c == classOf[Long] => "Long" + case _ if c == classOf[Float] => "Float" + case _ if c == classOf[Double] => "Double" + case _ if c == classOf[Boolean] => "Boolean" + case _ if c == classOf[String] => "String" + case _ if c.isArray => s"Array[${getTypeString(c.getComponentType)}]" + } + } + } + + /** Generates the HasParam trait code for the input param. */ + private def genHasParamTrait(param: ParamDesc[_]): String = { + val name = param.name + val Name = name(0).toUpper +: name.substring(1) + val Param = param.paramTypeName + val T = param.valueTypeName + val doc = param.doc + + s""" + |/** + | * :: DeveloperApi :: + | * Trait for shared param $name. + | */ + |@DeveloperApi + |trait Has$Name extends Params { + | /** + | * Param for $doc. + | * @group param + | */ + | final val $name: $Param = new $Param(this, "$name", "$doc") + | + | /** @group getParam */ + | final def get$Name: $T = get($name) + | + | /** @group setParam */ + | protected def set$Name(value: $T): this.type = set($name, value) + |} + """.stripMargin + } + + /** Generates Scala source code for the input params with header. */ + private def genSharedParams(params: Seq[ParamDesc[_]]): String = { + val header = + """ + |/* + | * Licensed to the Apache Software Foundation (ASF) under one or more + | * contributor license agreements. See the NOTICE file distributed with + | * this work for additional information regarding copyright ownership. + | * The ASF licenses this file to You under the Apache License, Version 2.0 + | * (the "License"); you may not use this file except in compliance with + | * the License. You may obtain a copy of the License at + | * + | * http://www.apache.org/licenses/LICENSE-2.0 + | * + | * Unless required by applicable law or agreed to in writing, software + | * distributed under the License is distributed on an "AS IS" BASIS, + | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + | * See the License for the specific language governing permissions and + | * limitations under the License. + | */ + | + |package org.apache.spark.ml.param.shared + | + |import org.apache.spark.annotation.DeveloperApi + |import org.apache.spark.ml.param._ + | + |// DO NOT MODIFY THIS FILE! It was generated by SharedParamCodeGen. 
+ | + |// scalastyle:off + """.stripMargin + + val footer = + """ + |// scalastyle:on + """.stripMargin + + val traits = params.map(genHasParamTrait).mkString + + header + traits + footer + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala new file mode 100644 index 000000000000..f26a6773ed0d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -0,0 +1,204 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param.shared + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.ml.param._ + +// DO NOT MODIFY THIS FILE! It was generated by SharedParamCodeGen. + +// scalastyle:off + +/** + * :: DeveloperApi :: + * Trait for shared param regParam. + */ +@DeveloperApi +trait HasRegParam extends Params { + /** + * Param for regularization parameter. + * @group param + */ + final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") + + /** @group getParam */ + final def getRegParam: Double = get(regParam) +} + +/** + * :: DeveloperApi :: + * Trait for shared param maxIter. + */ +@DeveloperApi +trait HasMaxIter extends Params { + /** + * Param for max number of iterations. + * @group param + */ + final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + + /** @group getParam */ + final def getMaxIter: Int = get(maxIter) +} + +/** + * :: DeveloperApi :: + * Trait for shared param featuresCol. + */ +@DeveloperApi +trait HasFeaturesCol extends Params { + /** + * Param for features column name. + * @group param + */ + final val featuresCol: Param[String] = new Param[String](this, "featuresCol", "features column name") + + /** @group getParam */ + final def getFeaturesCol: String = get(featuresCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param labelCol. + */ +@DeveloperApi +trait HasLabelCol extends Params { + /** + * Param for label column name. + * @group param + */ + final val labelCol: Param[String] = new Param[String](this, "labelCol", "label column name") + + /** @group getParam */ + final def getLabelCol: String = get(labelCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param predictionCol. + */ +@DeveloperApi +trait HasPredictionCol extends Params { + /** + * Param for prediction column name. + * @group param + */ + final val predictionCol: Param[String] = new Param[String](this, "predictionCol", "prediction column name") + + /** @group getParam */ + final def getPredictionCol: String = get(predictionCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param rawPredictionCol. 
+ */ +@DeveloperApi +trait HasRawPredictionCol extends Params { + /** + * Param for raw prediction (a.k.a. confidence) column name. + * @group param + */ + final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name") + + /** @group getParam */ + final def getRawPredictionCol: String = get(rawPredictionCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param probabilityCol. + */ +@DeveloperApi +trait HasProbabilityCol extends Params { + /** + * Param for column name for predicted class conditional probabilities. + * @group param + */ + final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "column name for predicted class conditional probabilities") + + /** @group getParam */ + final def getProbabilityCol: String = get(probabilityCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param threshold. + */ +@DeveloperApi +trait HasThreshold extends Params { + /** + * Param for threshold in prediction. + * @group param + */ + final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") + + /** @group getParam */ + final def getThreshold: Double = get(threshold) +} + +/** + * :: DeveloperApi :: + * Trait for shared param inputCol. + */ +@DeveloperApi +trait HasInputCol extends Params { + /** + * Param for input column name. + * @group param + */ + final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name") + + /** @group getParam */ + final def getInputCol: String = get(inputCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param outputCol. + */ +@DeveloperApi +trait HasOutputCol extends Params { + /** + * Param for output column name. + * @group param + */ + final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name") + + /** @group getParam */ + final def getOutputCol: String = get(outputCol) +} + +/** + * :: DeveloperApi :: + * Trait for shared param checkpointInterval. + */ +@DeveloperApi +trait HasCheckpointInterval extends Params { + /** + * Param for checkpoint interval. + * @group param + */ + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") + + /** @group getParam */ + final def getCheckpointInterval: Int = get(checkpointInterval) +} + +// scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala deleted file mode 100644 index 5d660d1e151a..000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ /dev/null @@ -1,151 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.ml.param - -/* NOTE TO DEVELOPERS: - * If you mix these parameter traits into your algorithm, please add a setter method as well - * so that users may use a builder pattern: - * val myLearner = new MyLearner().setParam1(x).setParam2(y)... - */ - -private[ml] trait HasRegParam extends Params { - /** - * param for regularization parameter - * @group param - */ - val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") - - /** @group getParam */ - def getRegParam: Double = get(regParam) -} - -private[ml] trait HasMaxIter extends Params { - /** - * param for max number of iterations - * @group param - */ - val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") - - /** @group getParam */ - def getMaxIter: Int = get(maxIter) -} - -private[ml] trait HasFeaturesCol extends Params { - /** - * param for features column name - * @group param - */ - val featuresCol: Param[String] = - new Param(this, "featuresCol", "features column name", Some("features")) - - /** @group getParam */ - def getFeaturesCol: String = get(featuresCol) -} - -private[ml] trait HasLabelCol extends Params { - /** - * param for label column name - * @group param - */ - val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label")) - - /** @group getParam */ - def getLabelCol: String = get(labelCol) -} - -private[ml] trait HasPredictionCol extends Params { - /** - * param for prediction column name - * @group param - */ - val predictionCol: Param[String] = - new Param(this, "predictionCol", "prediction column name", Some("prediction")) - - /** @group getParam */ - def getPredictionCol: String = get(predictionCol) -} - -private[ml] trait HasRawPredictionCol extends Params { - /** - * param for raw prediction column name - * @group param - */ - val rawPredictionCol: Param[String] = - new Param(this, "rawPredictionCol", "raw prediction (a.k.a. 
confidence) column name", - Some("rawPrediction")) - - /** @group getParam */ - def getRawPredictionCol: String = get(rawPredictionCol) -} - -private[ml] trait HasProbabilityCol extends Params { - /** - * param for predicted class conditional probabilities column name - * @group param - */ - val probabilityCol: Param[String] = - new Param(this, "probabilityCol", "column name for predicted class conditional probabilities", - Some("probability")) - - /** @group getParam */ - def getProbabilityCol: String = get(probabilityCol) -} - -private[ml] trait HasThreshold extends Params { - /** - * param for threshold in (binary) prediction - * @group param - */ - val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") - - /** @group getParam */ - def getThreshold: Double = get(threshold) -} - -private[ml] trait HasInputCol extends Params { - /** - * param for input column name - * @group param - */ - val inputCol: Param[String] = new Param(this, "inputCol", "input column name") - - /** @group getParam */ - def getInputCol: String = get(inputCol) -} - -private[ml] trait HasOutputCol extends Params { - /** - * param for output column name - * @group param - */ - val outputCol: Param[String] = new Param(this, "outputCol", "output column name") - - /** @group getParam */ - def getOutputCol: String = get(outputCol) -} - -private[ml] trait HasCheckpointInterval extends Params { - /** - * param for checkpoint interval - * @group param - */ - val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") - - /** @group getParam */ - def getCheckpointInterval: Int = get(checkpointInterval) -} From abb7a3bcabe9091d5b5ad493660daaa4c671f613 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 8 Apr 2015 13:56:08 -0700 Subject: [PATCH 02/21] update default values handling --- .../org/apache/spark/ml/Transformer.scala | 3 +- .../spark/ml/classification/Classifier.scala | 8 +- .../classification/LogisticRegression.scala | 14 ++- .../ProbabilisticClassifier.scala | 3 +- .../BinaryClassificationEvaluator.scala | 5 +- .../apache/spark/ml/feature/HashingTF.scala | 4 +- .../apache/spark/ml/feature/Normalizer.scala | 5 +- .../spark/ml/feature/StandardScaler.scala | 1 + .../apache/spark/ml/feature/Tokenizer.scala | 10 +- .../spark/ml/impl/estimator/Predictor.scala | 1 + .../org/apache/spark/ml/param/params.scala | 91 ++++++++++--------- .../ml/param/shared/SharedParamCodeGen.scala | 3 - .../apache/spark/ml/recommendation/ALS.scala | 29 +++--- .../ml/regression/LinearRegression.scala | 6 +- .../spark/ml/tuning/CrossValidator.scala | 6 +- 15 files changed, 102 insertions(+), 87 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 9a5848684b17..9b5d1dd7a45a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -22,6 +22,7 @@ import scala.annotation.varargs import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -65,7 +66,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O /** @group setParam */ def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T] - /** @group setParam */ + /** @goup setParam */ def 
setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T] /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index c5fc89f93543..9b5b3a3d5527 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -19,7 +19,8 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} -import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol} +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.functions._ import org.apache.spark.sql.DataFrame @@ -36,6 +37,8 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} private[spark] trait ClassifierParams extends PredictorParams with HasRawPredictionCol { + setDefault(rawPredictionCol, "rawPrediction") + override protected def validateAndTransformSchema( schema: StructType, paramMap: ParamMap, @@ -67,8 +70,7 @@ private[spark] abstract class Classifier[ with ClassifierParams { /** @group setParam */ - def setRawPredictionCol(value: String): E = - set(rawPredictionCol, value).asInstanceOf[E] + def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] // TODO: defaultEvaluator (follow-up PR) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 49c00f77480e..ef672abc22f5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel @@ -31,7 +31,11 @@ import org.apache.spark.storage.StorageLevel * Params for logistic regression. 
*/ private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams - with HasRegParam with HasMaxIter with HasThreshold + with HasRegParam with HasMaxIter with HasThreshold { + + setDefault(regParam -> 0.1, maxIter -> 100, threshold -> 0.5) +} + /** @@ -45,10 +49,6 @@ class LogisticRegression extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] with LogisticRegressionParams { - setRegParam(0.1) - setMaxIter(100) - setThreshold(0.5) - /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) @@ -96,8 +96,6 @@ class LogisticRegressionModel private[ml] ( extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with LogisticRegressionParams { - setThreshold(0.5) - /** @group setParam */ def setThreshold(value: Double): this.type = set(threshold, value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index bd8caac85598..715141098a23 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -18,7 +18,8 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} -import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params} +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 2360f4479f1c..99853df12d76 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.Evaluator import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} @@ -40,7 +41,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params * @group param */ val metricName: Param[String] = new Param(this, "metricName", - "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC")) + "metric name in evaluation (areaUnderROC|areaUnderPR)") /** @group getParam */ def getMetricName: String = get(metricName) @@ -51,7 +52,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params /** @group setParam */ def setScoreCol(value: String): this.type = set(rawPredictionCol, value) - /** @group setParam */ + /** @goup setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index fc4e12773c46..6c02eb674ad3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -31,11 +31,13 @@ import org.apache.spark.sql.types.DataType @AlphaComponent class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { + setDefault(numFeatures -> (1 << 18)) + /** * number of features * @group param */ - val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18)) + val numFeatures = new IntParam(this, "numFeatures", "number of features") /** @group getParam */ def getNumFeatures: Int = get(numFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 05f91dc9105f..5a94fa488ab6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -31,11 +31,13 @@ import org.apache.spark.sql.types.DataType @AlphaComponent class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { + setDefault(p -> 2.0) + /** * Normalization in L^p^ space, p = 2 by default. * @group param */ - val p = new DoubleParam(this, "p", "the p norm value", Some(2)) + val p = new DoubleParam(this, "p", "the p norm value") /** @group getParam */ def getP: Double = get(p) @@ -50,4 +52,3 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { override protected def outputDataType: DataType = new VectorUDT() } - diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 1142aa4f8e73..0952300bdd69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 68401e36950b..76a71f30b262 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -52,11 +52,13 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { @AlphaComponent class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] { + setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+") + /** * param for minimum token length, default is one to avoid returning empty strings * @group param */ - val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length", Some(1)) + val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length") /** @group setParam */ def setMinTokenLength(value: Int): this.type = set(minTokenLength, value) @@ -68,8 +70,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize * param sets regex as splitting on gaps (true) or matching tokens (false) * @group param */ - val gaps: BooleanParam = new BooleanParam( - this, "gaps", "Set regex to match gaps or tokens", Some(false)) + val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens") /** @group setParam */ def 
setGaps(value: Boolean): this.type = set(gaps, value) @@ -81,8 +82,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize * param sets regex pattern used by tokenizer * @group param */ - val pattern: Param[String] = new Param( - this, "pattern", "regex pattern used for tokenizing", Some("\\p{L}+|[^\\p{L}\\s]+")) + val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing") /** @group setParam */ def setPattern(value: String): this.type = set(pattern, value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala index dfb89cc8d4af..f910d9a088fb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.impl.estimator import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{VectorUDT, Vector} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 17ece897a6c5..0d414395bad0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -26,7 +26,6 @@ import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.Identifiable import org.apache.spark.sql.types.{DataType, StructField, StructType} - /** * :: AlphaComponent :: * A param with self-contained documentation and optionally default value. Primitive-typed param @@ -38,12 +37,7 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType} * @tparam T param value type */ @AlphaComponent -class Param[T] ( - val parent: Params, - val name: String, - val doc: String, - val defaultValue: Option[T] = None) - extends Serializable { +class Param[T] (val parent: Params, val name: String, val doc: String) extends Serializable { /** * Creates a param pair with the given value (for Java). @@ -55,58 +49,42 @@ class Param[T] ( */ def ->(value: T): ParamPair[T] = ParamPair(this, value) - override def toString: String = { - if (defaultValue.isDefined) { - s"$name: $doc (default: ${defaultValue.get})" - } else { - s"$name: $doc" - } - } + override def toString: String = s"$name: $doc" } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... /** Specialized version of [[Param[Double]]] for Java. */ -class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double]) - extends Param[Double](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class DoubleParam(parent: Params, name: String, doc: String) + extends Param[Double](parent, name, doc) { override def w(value: Double): ParamPair[Double] = super.w(value) } /** Specialized version of [[Param[Int]]] for Java. 
*/ -class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int]) - extends Param[Int](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class IntParam(parent: Params, name: String, doc: String) + extends Param[Int](parent, name, doc) { override def w(value: Int): ParamPair[Int] = super.w(value) } /** Specialized version of [[Param[Float]]] for Java. */ -class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float]) - extends Param[Float](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class FloatParam(parent: Params, name: String, doc: String) + extends Param[Float](parent, name, doc) { override def w(value: Float): ParamPair[Float] = super.w(value) } /** Specialized version of [[Param[Long]]] for Java. */ -class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long]) - extends Param[Long](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class LongParam(parent: Params, name: String, doc: String) + extends Param[Long](parent, name, doc) { override def w(value: Long): ParamPair[Long] = super.w(value) } /** Specialized version of [[Param[Boolean]]] for Java. */ -class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean]) - extends Param[Boolean](parent, name, doc, defaultValue) { - - def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) +class BooleanParam(parent: Params, name: String, doc: String) + extends Param[Boolean](parent, name, doc) { override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } @@ -124,7 +102,10 @@ case class ParamPair[T](param: Param[T], value: T) @AlphaComponent trait Params extends Identifiable with Serializable { - /** Returns all params. */ + /** + * Returns all params. The default implementation uses Java reflection to list all public methods + * that have return type [[Param]]. + */ def params: Array[Param[_]] = { val methods = this.getClass.getMethods methods.filter { m => @@ -159,7 +140,7 @@ trait Params extends Identifiable with Serializable { } /** Gets a param by its name. */ - private[ml] def getParam(paramName: String): Param[Any] = { + protected final def getParam(paramName: String): Param[Any] = { val m = this.getClass.getMethod(paramName) assert(Modifier.isPublic(m.getModifiers) && classOf[Param[_]].isAssignableFrom(m.getReturnType) && @@ -170,7 +151,7 @@ trait Params extends Identifiable with Serializable { /** * Sets a parameter in the embedded param map. */ - protected def set[T](param: Param[T], value: T): this.type = { + protected final def set[T](param: Param[T], value: T): this.type = { require(param.parent.eq(this)) paramMap.put(param.asInstanceOf[Param[Any]], value) this @@ -179,14 +160,14 @@ trait Params extends Identifiable with Serializable { /** * Sets a parameter (by name) in the embedded param map. */ - private[ml] def set(param: String, value: Any): this.type = { + protected final def set(param: String, value: Any): this.type = { set(getParam(param), value) } /** * Gets the value of a parameter in the embedded param map. */ - protected def get[T](param: Param[T]): T = { + protected final def get[T](param: Param[T]): T = { require(param.parent.eq(this)) paramMap(param) } @@ -194,7 +175,33 @@ trait Params extends Identifiable with Serializable { /** * Internal param map. 
*/ - protected val paramMap: ParamMap = ParamMap.empty + protected final val paramMap: ParamMap = ParamMap.empty + + /** + * Internal param map for default values. + */ + protected final val defaultValues: ParamMap = ParamMap.empty + + /** + * Sets a default value. + */ + protected final def setDefault[T](param: Param[T], value: T): this.type = { + require(param.parent.eq(this)) + defaultValues.put(param, value) + this + } + + protected final def setDefault(paramPairs: ParamPair[_]*): this.type = { + paramPairs.foreach { p => + setDefault(p.param.asInstanceOf[Param[Any]], p.value) + } + this + } + + protected final def getDefault[T](param: Param[T]): Option[T] = { + require(param.parent.eq(this)) + defaultValues.get(param) + } /** * Check whether the given schema contains an input column. @@ -283,9 +290,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten * Optionally returns the value associated with a param or its default. */ def get[T](param: Param[T]): Option[T] = { - map.get(param.asInstanceOf[Param[Any]]) - .orElse(param.defaultValue) - .asInstanceOf[Option[T]] + map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala index 5403012c03a8..e8bb296a732a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala @@ -109,9 +109,6 @@ private[shared] object SharedParamCodeGen { | | /** @group getParam */ | final def get$Name: $T = get($name) - | - | /** @group setParam */ - | protected def set$Name(value: $T): this.type = set($name, value) |} """.stripMargin } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 52c9e95d6012..4edd947654c7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -34,6 +34,7 @@ import org.apache.spark.{Logging, Partitioner} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -50,11 +51,14 @@ import org.apache.spark.util.random.XORShiftRandom private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam with HasPredictionCol with HasCheckpointInterval { + setDefault(rank -> 10, numUserBlocks -> 10, numItemBlocks -> 10, implicitPrefs -> false, + alpha -> 1.0, userCol -> "user", itemCol -> "item", ratingCol -> "rating", nonnegative -> false) + /** * Param for rank of the matrix factorization. * @group param */ - val rank = new IntParam(this, "rank", "rank of the factorization", Some(10)) + val rank = new IntParam(this, "rank", "rank of the factorization") /** @group getParam */ def getRank: Int = get(rank) @@ -63,7 +67,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * Param for number of user blocks. 
* @group param */ - val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10)) + val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks") /** @group getParam */ def getNumUserBlocks: Int = get(numUserBlocks) @@ -73,7 +77,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * @group param */ val numItemBlocks = - new IntParam(this, "numItemBlocks", "number of item blocks", Some(10)) + new IntParam(this, "numItemBlocks", "number of item blocks") /** @group getParam */ def getNumItemBlocks: Int = get(numItemBlocks) @@ -83,7 +87,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * @group param */ val implicitPrefs = - new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false)) + new BooleanParam(this, "implicitPrefs", "whether to use implicit preference") /** @group getParam */ def getImplicitPrefs: Boolean = get(implicitPrefs) @@ -92,7 +96,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * Param for the alpha parameter in the implicit preference formulation. * @group param */ - val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0)) + val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference") /** @group getParam */ def getAlpha: Double = get(alpha) @@ -101,7 +105,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * Param for the column name for user ids. * @group param */ - val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user")) + val userCol = new Param[String](this, "userCol", "column name for user ids") /** @group getParam */ def getUserCol: String = get(userCol) @@ -110,8 +114,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * Param for the column name for item ids. * @group param */ - val itemCol = - new Param[String](this, "itemCol", "column name for item ids", Some("item")) + val itemCol = new Param[String](this, "itemCol", "column name for item ids") /** @group getParam */ def getItemCol: String = get(itemCol) @@ -120,7 +123,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * Param for the column name for ratings. 
* @group param */ - val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating")) + val ratingCol = new Param[String](this, "ratingCol", "column name for ratings") /** @group getParam */ def getRatingCol: String = get(ratingCol) @@ -130,7 +133,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * @group param */ val nonnegative = new BooleanParam( - this, "nonnegative", "whether to use nonnegative constraint for least squares", Some(false)) + this, "nonnegative", "whether to use nonnegative constraint for least squares") /** @group getParam */ val getNonnegative: Boolean = get(nonnegative) @@ -253,6 +256,9 @@ class ALS extends Estimator[ALSModel] with ALSParams { /** @group setParam */ def setRatingCol(value: String): this.type = set(ratingCol, value) + /** @group setParam */ + def setNonnegative(value: Boolean): this.type = set(nonnegative, value) + /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) @@ -262,9 +268,6 @@ class ALS extends Estimator[ALSModel] with ALSParams { /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) - /** @group setParam */ - def setNonnegative(value: Boolean): this.type = set(nonnegative, value) - /** @group setParam */ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 65f6627a0c35..f0b422316c4b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -18,7 +18,8 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam} +import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.sql.DataFrame @@ -41,8 +42,7 @@ private[regression] trait LinearRegressionParams extends RegressorParams class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] with LinearRegressionParams { - setRegParam(0.1) - setMaxIter(100) + setDefault(regParam -> 0.1, maxIter -> 100) /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 2eb1dac56f1e..d1be3ec437b5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -31,6 +31,9 @@ import org.apache.spark.sql.types.StructType * Params for [[CrossValidator]] and [[CrossValidatorModel]]. 
*/ private[ml] trait CrossValidatorParams extends Params { + + setDefault(numFolds -> 3) + /** * param for the estimator to be cross-validated * @group param @@ -63,8 +66,7 @@ private[ml] trait CrossValidatorParams extends Params { * param for number of folds for cross validation * @group param */ - val numFolds: IntParam = - new IntParam(this, "numFolds", "number of folds for cross validation", Some(3)) + val numFolds: IntParam = new IntParam(this, "numFolds", "number of folds for cross validation") /** @group getParam */ def getNumFolds: Int = get(numFolds) From 1c72579a731d61e2bbbccd00b4bc9358618a58da Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 8 Apr 2015 14:01:31 -0700 Subject: [PATCH 03/21] pass test compile --- mllib/src/main/scala/org/apache/spark/ml/Transformer.scala | 2 +- .../ml/evaluation/BinaryClassificationEvaluator.scala | 2 +- .../src/main/scala/org/apache/spark/ml/param/params.scala | 2 +- .../scala/org/apache/spark/ml/recommendation/ALS.scala | 6 +++--- .../test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 7 +++---- .../test/scala/org/apache/spark/ml/param/TestParams.scala | 4 +++- 6 files changed, 12 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 9b5d1dd7a45a..d5e4c6ac1c71 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -66,7 +66,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O /** @group setParam */ def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T] - /** @goup setParam */ + /** @group setParam */ def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T] /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 99853df12d76..b27472966a3c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -52,7 +52,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params /** @group setParam */ def setScoreCol(value: String): this.type = set(rawPredictionCol, value) - /** @goup setParam */ + /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 0d414395bad0..5089ec90f11b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -140,7 +140,7 @@ trait Params extends Identifiable with Serializable { } /** Gets a param by its name. 
*/ - protected final def getParam(paramName: String): Param[Any] = { + def getParam(paramName: String): Param[Any] = { val m = this.getClass.getMethod(paramName) assert(Modifier.isPublic(m.getModifiers) && classOf[Param[_]].isAssignableFrom(m.getReturnType) && diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 4edd947654c7..1b53c38dd8c3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -256,9 +256,6 @@ class ALS extends Estimator[ALSModel] with ALSParams { /** @group setParam */ def setRatingCol(value: String): this.type = set(ratingCol, value) - /** @group setParam */ - def setNonnegative(value: Boolean): this.type = set(nonnegative, value) - /** @group setParam */ def setPredictionCol(value: String): this.type = set(predictionCol, value) @@ -268,6 +265,9 @@ class ALS extends Estimator[ALSModel] with ALSParams { /** @group setParam */ def setRegParam(value: Double): this.type = set(regParam, value) + /** @group setParam */ + def setNonnegative(value: Boolean): this.type = set(nonnegative, value) + /** @group setParam */ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 1ce298761237..856cb3ce8f4a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -27,10 +27,10 @@ class ParamsSuite extends FunSuite { test("param") { assert(maxIter.name === "maxIter") assert(maxIter.doc === "max number of iterations") - assert(maxIter.defaultValue.get === 100) assert(maxIter.parent.eq(solver)) - assert(maxIter.toString === "maxIter: max number of iterations (default: 100)") - assert(inputCol.defaultValue === None) + assert(maxIter.toString === "maxIter: max number of iterations") + assert(solver.getMaxIter === 10) + assert(!solver.isSet(inputCol)) } test("param pair") { @@ -47,7 +47,6 @@ class ParamsSuite extends FunSuite { val map0 = ParamMap.empty assert(!map0.contains(maxIter)) - assert(map0(maxIter) === maxIter.defaultValue.get) map0.put(maxIter, 10) assert(map0.contains(maxIter)) assert(map0(maxIter) === 10) diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 1a65883d78a7..2b99247bd349 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -20,7 +20,9 @@ package org.apache.spark.ml.param /** A subclass of Params for testing. 
*/ class TestParams extends Params { - val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100)) + setDefault(maxIter -> 10) + + val maxIter = new IntParam(this, "maxIter", "max number of iterations") def setMaxIter(value: Int): this.type = { set(maxIter, value); this } def getMaxIter: Int = get(maxIter) From d9302b82eb2c325b15f20dd1b30bdf4c7cdc58a6 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 8 Apr 2015 14:24:31 -0700 Subject: [PATCH 04/21] generate default values --- .../ml/param/shared/SharedParamCodeGen.scala | 31 +++++++++++++----- .../spark/ml/param/shared/sharedParams.scala | 32 ++++++++++++++++--- 2 files changed, 50 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala index e8bb296a732a..d3c92bd4c113 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala @@ -33,12 +33,13 @@ private[shared] object SharedParamCodeGen { val params = Seq( ParamDesc[Double]("regParam", "regularization parameter"), ParamDesc[Int]("maxIter", "max number of iterations"), - ParamDesc[String]("featuresCol", "features column name"), - ParamDesc[String]("labelCol", "label column name"), - ParamDesc[String]("predictionCol", "prediction column name"), - ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name"), - ParamDesc[String]( - "probabilityCol", "column name for predicted class conditional probabilities"), + ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")), + ParamDesc[String]("labelCol", "label column name", Some("\"label\"")), + ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")), + ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", + Some("\"rawPrediction\"")), + ParamDesc[String]("probabilityCol", + "column name for predicted class conditional probabilities", Some("\"probability\"")), ParamDesc[Double]("threshold", "threshold in prediction"), ParamDesc[String]("inputCol", "input column name"), ParamDesc[String]("outputCol", "output column name"), @@ -52,7 +53,11 @@ private[shared] object SharedParamCodeGen { } /** Description of a param. */ - private case class ParamDesc[T: ClassTag](name: String, doc: String) { + private case class ParamDesc[T: ClassTag]( + name: String, + doc: String, + defaultValueStr: Option[String] = None) { + require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.") require(doc.nonEmpty) // TODO: more rigorous on doc @@ -93,14 +98,24 @@ private[shared] object SharedParamCodeGen { val Param = param.paramTypeName val T = param.valueTypeName val doc = param.doc + val defaultValue = param.defaultValueStr + val defaultValueDoc = defaultValue.map { v => + s" (default: $v)" + }.getOrElse("") + val setDefault = defaultValue.map { v => + s""" + | setDefault($name, $v) + """.stripMargin + }.getOrElse("") s""" |/** | * :: DeveloperApi :: - | * Trait for shared param $name. + | * Trait for shared param $name$defaultValueDoc. | */ |@DeveloperApi |trait Has$Name extends Params { + |$setDefault | /** | * Param for $doc. 
| * @group param diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index f26a6773ed0d..7e9886d347b8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -31,6 +31,7 @@ import org.apache.spark.ml.param._ */ @DeveloperApi trait HasRegParam extends Params { + /** * Param for regularization parameter. * @group param @@ -47,6 +48,7 @@ trait HasRegParam extends Params { */ @DeveloperApi trait HasMaxIter extends Params { + /** * Param for max number of iterations. * @group param @@ -59,10 +61,13 @@ trait HasMaxIter extends Params { /** * :: DeveloperApi :: - * Trait for shared param featuresCol. + * Trait for shared param featuresCol (default: "features"). */ @DeveloperApi trait HasFeaturesCol extends Params { + + setDefault(featuresCol, "features") + /** * Param for features column name. * @group param @@ -75,10 +80,13 @@ trait HasFeaturesCol extends Params { /** * :: DeveloperApi :: - * Trait for shared param labelCol. + * Trait for shared param labelCol (default: "label"). */ @DeveloperApi trait HasLabelCol extends Params { + + setDefault(labelCol, "label") + /** * Param for label column name. * @group param @@ -91,10 +99,13 @@ trait HasLabelCol extends Params { /** * :: DeveloperApi :: - * Trait for shared param predictionCol. + * Trait for shared param predictionCol (default: "prediction"). */ @DeveloperApi trait HasPredictionCol extends Params { + + setDefault(predictionCol, "prediction") + /** * Param for prediction column name. * @group param @@ -107,10 +118,13 @@ trait HasPredictionCol extends Params { /** * :: DeveloperApi :: - * Trait for shared param rawPredictionCol. + * Trait for shared param rawPredictionCol (default: "rawPrediction"). */ @DeveloperApi trait HasRawPredictionCol extends Params { + + setDefault(rawPredictionCol, "rawPrediction") + /** * Param for raw prediction (a.k.a. confidence) column name. * @group param @@ -123,10 +137,13 @@ trait HasRawPredictionCol extends Params { /** * :: DeveloperApi :: - * Trait for shared param probabilityCol. + * Trait for shared param probabilityCol (default: "probability"). */ @DeveloperApi trait HasProbabilityCol extends Params { + + setDefault(probabilityCol, "probability") + /** * Param for column name for predicted class conditional probabilities. * @group param @@ -143,6 +160,7 @@ trait HasProbabilityCol extends Params { */ @DeveloperApi trait HasThreshold extends Params { + /** * Param for threshold in prediction. * @group param @@ -159,6 +177,7 @@ trait HasThreshold extends Params { */ @DeveloperApi trait HasInputCol extends Params { + /** * Param for input column name. * @group param @@ -175,6 +194,7 @@ trait HasInputCol extends Params { */ @DeveloperApi trait HasOutputCol extends Params { + /** * Param for output column name. * @group param @@ -191,6 +211,7 @@ trait HasOutputCol extends Params { */ @DeveloperApi trait HasCheckpointInterval extends Params { + /** * Param for checkpoint interval. 
* @group param @@ -202,3 +223,4 @@ trait HasCheckpointInterval extends Params { } // scalastyle:on + \ No newline at end of file From a9dbf59a960e1a35c0cbdadebefa65b721e73bf5 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 8 Apr 2015 14:31:12 -0700 Subject: [PATCH 05/21] minor updates --- .../spark/ml/classification/Classifier.scala | 2 -- .../BinaryClassificationEvaluator.scala | 2 ++ .../ml/param/shared/SharedParamCodeGen.scala | 14 +++----- .../spark/ml/param/shared/sharedParams.scala | 35 +++++++++---------- .../apache/spark/ml/recommendation/ALS.scala | 5 +-- 5 files changed, 26 insertions(+), 32 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 9b5b3a3d5527..caac2d8166cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -37,8 +37,6 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} private[spark] trait ClassifierParams extends PredictorParams with HasRawPredictionCol { - setDefault(rawPredictionCol, "rawPrediction") - override protected def validateAndTransformSchema( schema: StructType, paramMap: ParamMap, diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index b27472966a3c..91a18c9cc918 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -36,6 +36,8 @@ import org.apache.spark.sql.types.DoubleType class BinaryClassificationEvaluator extends Evaluator with Params with HasRawPredictionCol with HasLabelCol { + setDefault(metricName -> "areaUnderROC") + /** * param for metric name in evaluation * @group param diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala index d3c92bd4c113..588f8823f22b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala @@ -105,7 +105,7 @@ private[shared] object SharedParamCodeGen { val setDefault = defaultValue.map { v => s""" | setDefault($name, $v) - """.stripMargin + |""".stripMargin }.getOrElse("") s""" @@ -125,14 +125,13 @@ private[shared] object SharedParamCodeGen { | /** @group getParam */ | final def get$Name: $T = get($name) |} - """.stripMargin + |""".stripMargin } /** Generates Scala source code for the input params with header. */ private def genSharedParams(params: Seq[ParamDesc[_]]): String = { val header = - """ - |/* + """/* | * Licensed to the Apache Software Foundation (ASF) under one or more | * contributor license agreements. See the NOTICE file distributed with | * this work for additional information regarding copyright ownership. @@ -157,12 +156,9 @@ private[shared] object SharedParamCodeGen { |// DO NOT MODIFY THIS FILE! It was generated by SharedParamCodeGen. 
| |// scalastyle:off - """.stripMargin + |""".stripMargin - val footer = - """ - |// scalastyle:on - """.stripMargin + val footer = "// scalastyle:on\n" val traits = params.map(genHasParamTrait).mkString diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 7e9886d347b8..9bf0ff467091 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with @@ -24,7 +23,7 @@ import org.apache.spark.ml.param._ // DO NOT MODIFY THIS FILE! It was generated by SharedParamCodeGen. // scalastyle:off - + /** * :: DeveloperApi :: * Trait for shared param regParam. @@ -41,7 +40,7 @@ trait HasRegParam extends Params { /** @group getParam */ final def getRegParam: Double = get(regParam) } - + /** * :: DeveloperApi :: * Trait for shared param maxIter. @@ -58,7 +57,7 @@ trait HasMaxIter extends Params { /** @group getParam */ final def getMaxIter: Int = get(maxIter) } - + /** * :: DeveloperApi :: * Trait for shared param featuresCol (default: "features"). @@ -67,7 +66,7 @@ trait HasMaxIter extends Params { trait HasFeaturesCol extends Params { setDefault(featuresCol, "features") - + /** * Param for features column name. * @group param @@ -77,7 +76,7 @@ trait HasFeaturesCol extends Params { /** @group getParam */ final def getFeaturesCol: String = get(featuresCol) } - + /** * :: DeveloperApi :: * Trait for shared param labelCol (default: "label"). @@ -86,7 +85,7 @@ trait HasFeaturesCol extends Params { trait HasLabelCol extends Params { setDefault(labelCol, "label") - + /** * Param for label column name. * @group param @@ -96,7 +95,7 @@ trait HasLabelCol extends Params { /** @group getParam */ final def getLabelCol: String = get(labelCol) } - + /** * :: DeveloperApi :: * Trait for shared param predictionCol (default: "prediction"). @@ -105,7 +104,7 @@ trait HasLabelCol extends Params { trait HasPredictionCol extends Params { setDefault(predictionCol, "prediction") - + /** * Param for prediction column name. * @group param @@ -115,7 +114,7 @@ trait HasPredictionCol extends Params { /** @group getParam */ final def getPredictionCol: String = get(predictionCol) } - + /** * :: DeveloperApi :: * Trait for shared param rawPredictionCol (default: "rawPrediction"). @@ -124,7 +123,7 @@ trait HasPredictionCol extends Params { trait HasRawPredictionCol extends Params { setDefault(rawPredictionCol, "rawPrediction") - + /** * Param for raw prediction (a.k.a. confidence) column name. * @group param @@ -134,7 +133,7 @@ trait HasRawPredictionCol extends Params { /** @group getParam */ final def getRawPredictionCol: String = get(rawPredictionCol) } - + /** * :: DeveloperApi :: * Trait for shared param probabilityCol (default: "probability"). @@ -143,7 +142,7 @@ trait HasRawPredictionCol extends Params { trait HasProbabilityCol extends Params { setDefault(probabilityCol, "probability") - + /** * Param for column name for predicted class conditional probabilities. * @group param @@ -153,7 +152,7 @@ trait HasProbabilityCol extends Params { /** @group getParam */ final def getProbabilityCol: String = get(probabilityCol) } - + /** * :: DeveloperApi :: * Trait for shared param threshold. 
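Note that in the regenerated traits above, the emitted setDefault(...) call still precedes the param val it references. Scala evaluates trait body statements in declaration order, so the val is still null when setDefault runs, and the require(param.parent.eq(this)) check dereferences that null param; the next commit ("setDefault after param") fixes this by moving the call below the declaration. A minimal standalone sketch of the failure mode, using simplified stand-ins rather than the Spark ML classes:

    // Simplified stand-ins for Param/Params, only to demonstrate trait
    // initialization order; these are not the Spark ML implementations.
    class Param[T](val parent: AnyRef, val name: String)

    trait MiniParams {
      private val defaults = scala.collection.mutable.Map.empty[Param[_], Any]

      protected def setDefault[T](param: Param[T], value: T): this.type = {
        // NullPointerException here when param has not been initialized yet
        require(param.parent.eq(this))
        defaults(param) = value
        this
      }
    }

    trait HasFoo extends MiniParams {
      setDefault(foo, "bar") // runs first; foo is still null at this point
      val foo: Param[String] = new Param(this, "foo")
    }

    object InitOrderDemo extends App {
      new HasFoo {} // throws NullPointerException from setDefault
    }

Moving setDefault(foo, "bar") below the val foo declaration makes the call safe, which is exactly the reordering the generator is changed to emit.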
@@ -170,7 +169,7 @@ trait HasThreshold extends Params { /** @group getParam */ final def getThreshold: Double = get(threshold) } - + /** * :: DeveloperApi :: * Trait for shared param inputCol. @@ -187,7 +186,7 @@ trait HasInputCol extends Params { /** @group getParam */ final def getInputCol: String = get(inputCol) } - + /** * :: DeveloperApi :: * Trait for shared param outputCol. @@ -204,7 +203,7 @@ trait HasOutputCol extends Params { /** @group getParam */ final def getOutputCol: String = get(outputCol) } - + /** * :: DeveloperApi :: * Trait for shared param checkpointInterval. @@ -221,6 +220,4 @@ trait HasCheckpointInterval extends Params { /** @group getParam */ final def getCheckpointInterval: Int = get(checkpointInterval) } - // scalastyle:on - \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 1b53c38dd8c3..4c40944f525b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -51,8 +51,9 @@ import org.apache.spark.util.random.XORShiftRandom private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam with HasPredictionCol with HasCheckpointInterval { - setDefault(rank -> 10, numUserBlocks -> 10, numItemBlocks -> 10, implicitPrefs -> false, - alpha -> 1.0, userCol -> "user", itemCol -> "item", ratingCol -> "rating", nonnegative -> false) + setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, + implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", + ratingCol -> "rating", nonnegative -> false) /** * Param for rank of the matrix factorization. From 0d3fc5ba1fb0cb15bcb2bda20ebf1d98cfd48243 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 8 Apr 2015 14:49:04 -0700 Subject: [PATCH 06/21] setDefault after param --- .../org/apache/spark/ml/param/params.scala | 3 +++ .../ml/param/shared/SharedParamCodeGen.scala | 8 ++++---- .../spark/ml/param/shared/sharedParams.scala | 20 +++++++++---------- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 5089ec90f11b..4d9eb3c83d2a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -186,6 +186,9 @@ trait Params extends Identifiable with Serializable { * Sets a default value. 
*/ protected final def setDefault[T](param: Param[T], value: T): this.type = { + println(s"param: $param") + println(param.parent) + println(value) require(param.parent.eq(this)) defaultValues.put(param, value) this diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala index 588f8823f22b..279b51d662f5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala @@ -104,8 +104,8 @@ private[shared] object SharedParamCodeGen { }.getOrElse("") val setDefault = defaultValue.map { v => s""" - | setDefault($name, $v) - |""".stripMargin + | setDefault($name, $v) + |""".stripMargin }.getOrElse("") s""" @@ -115,13 +115,13 @@ private[shared] object SharedParamCodeGen { | */ |@DeveloperApi |trait Has$Name extends Params { - |$setDefault + | | /** | * Param for $doc. | * @group param | */ | final val $name: $Param = new $Param(this, "$name", "$doc") - | + |$setDefault | /** @group getParam */ | final def get$Name: $T = get($name) |} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 9bf0ff467091..5fbf49af507f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -65,14 +65,14 @@ trait HasMaxIter extends Params { @DeveloperApi trait HasFeaturesCol extends Params { - setDefault(featuresCol, "features") - /** * Param for features column name. * @group param */ final val featuresCol: Param[String] = new Param[String](this, "featuresCol", "features column name") + setDefault(featuresCol, "features") + /** @group getParam */ final def getFeaturesCol: String = get(featuresCol) } @@ -84,14 +84,14 @@ trait HasFeaturesCol extends Params { @DeveloperApi trait HasLabelCol extends Params { - setDefault(labelCol, "label") - /** * Param for label column name. * @group param */ final val labelCol: Param[String] = new Param[String](this, "labelCol", "label column name") + setDefault(labelCol, "label") + /** @group getParam */ final def getLabelCol: String = get(labelCol) } @@ -103,14 +103,14 @@ trait HasLabelCol extends Params { @DeveloperApi trait HasPredictionCol extends Params { - setDefault(predictionCol, "prediction") - /** * Param for prediction column name. * @group param */ final val predictionCol: Param[String] = new Param[String](this, "predictionCol", "prediction column name") + setDefault(predictionCol, "prediction") + /** @group getParam */ final def getPredictionCol: String = get(predictionCol) } @@ -122,14 +122,14 @@ trait HasPredictionCol extends Params { @DeveloperApi trait HasRawPredictionCol extends Params { - setDefault(rawPredictionCol, "rawPrediction") - /** * Param for raw prediction (a.k.a. confidence) column name. * @group param */ final val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name") + setDefault(rawPredictionCol, "rawPrediction") + /** @group getParam */ final def getRawPredictionCol: String = get(rawPredictionCol) } @@ -141,14 +141,14 @@ trait HasRawPredictionCol extends Params { @DeveloperApi trait HasProbabilityCol extends Params { - setDefault(probabilityCol, "probability") - /** * Param for column name for predicted class conditional probabilities. 
* @group param */ final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "column name for predicted class conditional probabilities") + setDefault(probabilityCol, "probability") + /** @group getParam */ final def getProbabilityCol: String = get(probabilityCol) } From eeeffe89409e7b97ad0e091e18f1084621543422 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 8 Apr 2015 15:33:19 -0700 Subject: [PATCH 07/21] map ++ paramMap => extractValues --- .../scala/org/apache/spark/ml/Estimator.scala | 2 +- .../scala/org/apache/spark/ml/Pipeline.scala | 8 ++++---- .../scala/org/apache/spark/ml/Transformer.scala | 4 ++-- .../spark/ml/classification/Classifier.scala | 4 ++-- .../ml/classification/LogisticRegression.scala | 6 +++--- .../ProbabilisticClassifier.scala | 4 ++-- .../BinaryClassificationEvaluator.scala | 6 +++--- .../spark/ml/feature/StandardScaler.scala | 8 ++++---- .../spark/ml/impl/estimator/Predictor.scala | 8 ++++---- .../org/apache/spark/ml/param/params.scala | 17 +++++++++-------- .../apache/spark/ml/recommendation/ALS.scala | 6 +++--- .../spark/ml/regression/LinearRegression.scala | 2 +- .../apache/spark/ml/tuning/CrossValidator.scala | 4 ++-- 13 files changed, 40 insertions(+), 39 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index eff7ef925dfb..d6b3503ebdd9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -40,7 +40,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { */ @varargs def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = { - val map = new ParamMap().put(paramPairs: _*) + val map = ParamMap(paramPairs: _*) fit(dataset, map) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index c4a36103303a..87fee34b5e22 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -98,7 +98,7 @@ class Pipeline extends Estimator[PipelineModel] { */ override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) val theStages = map(stages) // Search for the last estimator. 
var indexOfLastEstimator = -1 @@ -135,7 +135,7 @@ class Pipeline extends Estimator[PipelineModel] { } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) val theStages = map(stages) require(theStages.toSet.size == theStages.size, "Cannot have duplicate components in a pipeline.") @@ -174,14 +174,14 @@ class PipelineModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = (fittingParamMap ++ this.paramMap) ++ paramMap + val map = fittingParamMap ++ extractValues(paramMap) transformSchema(dataset.schema, map, logging = true) stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map)) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = (fittingParamMap ++ this.paramMap) ++ paramMap + val map = fittingParamMap ++ extractValues(paramMap) stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index d5e4c6ac1c71..e2ae1b93897e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -87,7 +87,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O protected def validateInputType(inputType: DataType): Unit = {} override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) val inputType = schema(map(inputCol)).dataType validateInputType(inputType) if (schema.fieldNames.contains(map(outputCol))) { @@ -100,7 +100,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) dataset.withColumn(map(outputCol), callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol)))) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index caac2d8166cf..3e89f9276ac0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -43,7 +43,7 @@ private[spark] trait ClassifierParams extends PredictorParams fitting: Boolean, featuresDataType: DataType): StructType = { val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT) } } @@ -109,7 +109,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index ef672abc22f5..a87c007af0ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -117,7 +117,7 @@ class LogisticRegressionModel private[ml] ( // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) // Output selected columns only. // This is a bit complicated since it tries to avoid repeated computation. @@ -178,7 +178,7 @@ class LogisticRegressionModel private[ml] ( * The behavior of this can be adjusted using [[threshold]]. */ override protected def predict(features: Vector): Double = { - if (score(features) > paramMap(threshold)) 1 else 0 + if (score(features) > getThreshold) 1 else 0 } override protected def predictProbabilities(features: Vector): Vector = { @@ -193,7 +193,7 @@ class LogisticRegressionModel private[ml] ( override protected def copy(): LogisticRegressionModel = { val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept) - Params.inheritValues(this.paramMap, this, m) + Params.inheritValues(this.extractValues(), this, m) m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 715141098a23..941e0297c7f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -38,7 +38,7 @@ private[classification] trait ProbabilisticClassifierParams fitting: Boolean, featuresDataType: DataType): StructType = { val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT) } } @@ -103,7 +103,7 @@ private[spark] abstract class ProbabilisticClassificationModel[ // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 91a18c9cc918..b71368b6f420 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -36,8 +36,6 @@ import org.apache.spark.sql.types.DoubleType class BinaryClassificationEvaluator extends Evaluator with Params with HasRawPredictionCol with HasLabelCol { - setDefault(metricName -> "areaUnderROC") - /** * param for metric name in evaluation * @group param @@ -57,8 +55,10 @@ class BinaryClassificationEvaluator extends Evaluator with Params /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) + setDefault(metricName -> "areaUnderROC") + override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) val schema = dataset.schema checkInputColumn(schema, map(rawPredictionCol), new VectorUDT) diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 0952300bdd69..29ad3b3dfdb4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -48,7 +48,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler().fit(input) val model = new StandardScalerModel(this, map, scaler) @@ -57,7 +57,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${map(inputCol)} must be a vector column") @@ -87,13 +87,13 @@ class StandardScalerModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) val scale = udf((v: Vector) => { scaler.transform(v) } : Vector) dataset.withColumn(map(outputCol), scale(col(map(inputCol)))) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${map(inputCol)} must be a vector column") diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala index f910d9a088fb..82dfb6cdec5e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala @@ -54,7 +54,7 @@ private[spark] trait PredictorParams extends Params paramMap: ParamMap, fitting: Boolean, featuresDataType: DataType): StructType = { - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector checkInputColumn(schema, map(featuresCol), featuresDataType) if (fitting) { @@ -99,7 +99,7 @@ private[spark] abstract class Predictor[ // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) val model = train(dataset, map) Params.inheritValues(map, this, model) // copy params to model model @@ -142,7 +142,7 @@ private[spark] abstract class Predictor[ * and put it in an RDD with strong types. 
*/ protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = { - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) dataset.select(map(labelCol), map(featuresCol)) .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) @@ -202,7 +202,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 4d9eb3c83d2a..e0fad294fb3d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -136,7 +136,7 @@ trait Params extends Identifiable with Serializable { /** Checks whether a param is explicitly set. */ def isSet(param: Param[_]): Boolean = { require(param.parent.eq(this)) - paramMap.contains(param) + values.contains(param) } /** Gets a param by its name. */ @@ -153,7 +153,7 @@ trait Params extends Identifiable with Serializable { */ protected final def set[T](param: Param[T], value: T): this.type = { require(param.parent.eq(this)) - paramMap.put(param.asInstanceOf[Param[Any]], value) + values.put(param.asInstanceOf[Param[Any]], value) this } @@ -169,26 +169,23 @@ trait Params extends Identifiable with Serializable { */ protected final def get[T](param: Param[T]): T = { require(param.parent.eq(this)) - paramMap(param) + values(param) } /** * Internal param map. */ - protected final val paramMap: ParamMap = ParamMap.empty + private val values: ParamMap = ParamMap.empty /** * Internal param map for default values. */ - protected final val defaultValues: ParamMap = ParamMap.empty + private val defaultValues: ParamMap = ParamMap.empty /** * Sets a default value. */ protected final def setDefault[T](param: Param[T], value: T): this.type = { - println(s"param: $param") - println(param.parent) - println(value) require(param.parent.eq(this)) defaultValues.put(param, value) this @@ -206,6 +203,10 @@ trait Params extends Identifiable with Serializable { defaultValues.get(param) } + protected final def extractValues(extraValues: ParamMap = ParamMap.empty): ParamMap = { + defaultValues ++ values ++ extraValues + } + /** * Check whether the given schema contains an input column. * @param colName Parameter name for the input column. 
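The new extractValues helper makes the lookup precedence explicit: defaultValues ++ values ++ extraValues, where later operands of ++ win on key conflicts, so values passed in at call time override explicitly set params, which in turn override defaults. A rough sketch of the same merge semantics with plain Scala maps (illustrative keys and values, not the ParamMap API):

    // Plain immutable maps stand in for ParamMap; ++ keeps the rightmost
    // value for a duplicated key, mirroring extractValues' merge order.
    object ExtractValuesDemo extends App {
      val defaultValues = Map("maxIter" -> 100.0, "regParam" -> 0.0) // setDefault(...)
      val values        = Map("regParam" -> 0.1)                     // set(...)
      val extraValues   = Map("maxIter" -> 10.0)                     // caller's paramMap

      val merged = defaultValues ++ values ++ extraValues

      assert(merged("maxIter") == 10.0)  // caller's value beats the default
      assert(merged("regParam") == 0.1)  // explicitly set value beats the default
      println(merged)
    }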
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 4c40944f525b..e3b8c92ee4e8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -146,7 +146,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * @return output schema */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) assert(schema(map(userCol)).dataType == IntegerType) assert(schema(map(itemCol)).dataType== IntegerType) val ratingType = schema(map(ratingCol)).dataType @@ -175,7 +175,7 @@ class ALSModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { import dataset.sqlContext.implicits._ - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) val users = userFactors.toDF("id", "features") val items = itemFactors.toDF("id", "features") @@ -287,7 +287,7 @@ class ALS extends Estimator[ALSModel] with ALSParams { setCheckpointInterval(10) override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) val ratings = dataset .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType)) .map { row => diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f0b422316c4b..8c7e61c60dd1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -93,7 +93,7 @@ class LinearRegressionModel private[ml] ( override protected def copy(): LinearRegressionModel = { val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept) - Params.inheritValues(this.paramMap, this, m) + Params.inheritValues(extractValues(), this, m) m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index d1be3ec437b5..440e5be23181 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -94,7 +94,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP def setNumFolds(value: Int): this.type = set(numFolds, value) override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) val schema = dataset.schema transformSchema(dataset.schema, paramMap, logging = true) val sqlCtx = dataset.sqlContext @@ -132,7 +132,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractValues(paramMap) map(estimator).transformSchema(schema, paramMap) } } From 0d9594e41d69fc74facb95d672612b22c4f5c44c Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 8 Apr 2015 15:43:59 -0700 Subject: [PATCH 08/21] add getOrElse to ParamMap --- .../org/apache/spark/ml/param/params.scala | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index e0fad294fb3d..6e15919383eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -104,7 +104,7 @@ trait Params extends Identifiable with Serializable { /** * Returns all params. The default implementation uses Java reflection to list all public methods - * that have return type [[Param]]. + * that return [[Param]] and have no arguments. */ def params: Array[Param[_]] = { val methods = this.getClass.getMethods @@ -264,12 +264,13 @@ private[spark] object Params { * A param to value map. */ @AlphaComponent -class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { +final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) + extends Serializable { /** * Creates an empty param map. */ - def this() = this(mutable.Map.empty[Param[Any], Any]) + def this() = this(mutable.Map.empty) /** * Puts a (param, value) pair (overwrites if the input param exists). @@ -291,21 +292,25 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Optionally returns the value associated with a param or its default. + * Optionally returns the value associated with a param. */ def get[T](param: Param[T]): Option[T] = { map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] } + /** + * Returns the value associated with a param or a default value. + */ + def getOrElse[T](param: Param[T], default: T): T = { + get(param).getOrElse(default) + } + /** * Gets the value of the input param or its default value if it does not exist. * Raises a NoSuchElementException if there is no value associated with the input param. */ def apply[T](param: Param[T]): T = { - val value = get(param) - if (value.isDefined) { - value.get - } else { + get(param).getOrElse { throw new NoSuchElementException(s"Cannot find param ${param.name}.") } } @@ -326,7 +331,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Make a copy of this param map. + * Creates a copy of this param map. */ def copy: ParamMap = new ParamMap(map.clone()) @@ -364,7 +369,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten } /** - * Number of param pairs in this set. + * Number of param pairs in this map. 
*/ def size: Int = map.size } From 4ac6348c3211c1221d21030e13ac0f925dd81abc Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 8 Apr 2015 16:54:42 -0700 Subject: [PATCH 09/21] move schema utils to SchemaUtils add a few methods to Params --- .../examples/ml/JavaDeveloperApiExample.java | 2 +- .../examples/ml/DeveloperApiExample.scala | 2 +- .../scala/org/apache/spark/ml/Pipeline.scala | 10 +- .../org/apache/spark/ml/Transformer.scala | 4 +- .../spark/ml/classification/Classifier.scala | 14 +-- .../classification/LogisticRegression.scala | 4 +- .../ProbabilisticClassifier.scala | 8 +- .../BinaryClassificationEvaluator.scala | 10 +- .../apache/spark/ml/feature/HashingTF.scala | 6 +- .../apache/spark/ml/feature/Normalizer.scala | 6 +- .../spark/ml/feature/StandardScaler.scala | 8 +- .../apache/spark/ml/feature/Tokenizer.scala | 10 +- .../spark/ml/impl/estimator/Predictor.scala | 15 +-- .../org/apache/spark/ml/param/params.scala | 112 ++++++++++-------- .../ml/param/shared/SharedParamCodeGen.scala | 2 +- .../spark/ml/param/shared/sharedParams.scala | 22 ++-- .../apache/spark/ml/recommendation/ALS.scala | 35 +++--- .../ml/regression/LinearRegression.scala | 2 +- .../spark/ml/tuning/CrossValidator.scala | 16 +-- .../apache/spark/ml/util/SchemaUtils.scala | 60 ++++++++++ .../apache/spark/ml/param/ParamsSuite.scala | 4 +- .../apache/spark/ml/param/TestParams.scala | 10 +- 22 files changed, 217 insertions(+), 145 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 19d0eb216848..1695ca25064d 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -116,7 +116,7 @@ class MyJavaLogisticRegression */ IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations"); - int getMaxIter() { return (Integer) get(maxIter); } + int getMaxIter() { return (Integer) getOrDefault(maxIter); } public MyJavaLogisticRegression() { setMaxIter(100); diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index df26798e41b7..c8f6d1cd0944 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -99,7 +99,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams { * class since the maxIter parameter is only used during training (not in the Model). 
*/ val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") - def getMaxIter: Int = get(maxIter) + def getMaxIter: Int = getOrDefault(maxIter) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 87fee34b5e22..83fd5085a526 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -81,7 +81,7 @@ class Pipeline extends Estimator[PipelineModel] { /** param for pipeline stages */ val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } - def getStages: Array[PipelineStage] = get(stages) + def getStages: Array[PipelineStage] = getOrDefault(stages) /** * Fits the pipeline to the input dataset with additional parameters. If a stage is an @@ -98,7 +98,7 @@ class Pipeline extends Estimator[PipelineModel] { */ override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) val theStages = map(stages) // Search for the last estimator. var indexOfLastEstimator = -1 @@ -135,7 +135,7 @@ class Pipeline extends Estimator[PipelineModel] { } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) val theStages = map(stages) require(theStages.toSet.size == theStages.size, "Cannot have duplicate components in a pipeline.") @@ -174,14 +174,14 @@ class PipelineModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = fittingParamMap ++ extractValues(paramMap) + val map = fittingParamMap ++ extractParamMap(paramMap) transformSchema(dataset.schema, map, logging = true) stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map)) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap - val map = fittingParamMap ++ extractValues(paramMap) + val map = fittingParamMap ++ extractParamMap(paramMap) stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index e2ae1b93897e..7fb87fe452ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -87,7 +87,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O protected def validateInputType(inputType: DataType): Unit = {} override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType validateInputType(inputType) if (schema.fieldNames.contains(map(outputCol))) { @@ -100,7 +100,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) 
dataset.withColumn(map(outputCol), callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol)))) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 3e89f9276ac0..7a3f769253fc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -17,16 +17,16 @@ package org.apache.spark.ml.classification -import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} -import org.apache.spark.ml.param.{Params, ParamMap} +import org.apache.spark.ml.param.{ParamMap, Params} import org.apache.spark.ml.param.shared.HasRawPredictionCol +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.functions._ import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} - /** * :: DeveloperApi :: * Params for classification. @@ -43,8 +43,8 @@ private[spark] trait ClassifierParams extends PredictorParams fitting: Boolean, featuresDataType: DataType): StructType = { val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) - val map = extractValues(paramMap) - addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT) + val map = extractParamMap(paramMap) + SchemaUtils.appendColumn(parentSchema, map(rawPredictionCol), new VectorUDT) } } @@ -109,7 +109,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index a87c007af0ad..8299155117ef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -117,7 +117,7 @@ class LogisticRegressionModel private[ml] ( // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) // Output selected columns only. // This is a bit complicated since it tries to avoid repeated computation. 
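Alongside the extractParamMap rename, the examples and getters in this patch switch from get to getOrDefault, which resolves a param in two steps: the user-supplied value if one was set, otherwise the registered default, otherwise a NoSuchElementException (from calling get on an empty Option). A toy sketch of that lookup chain, with plain strings standing in for Param objects:

    // Two-level lookup in the spirit of getOrDefault: explicit value first,
    // then the default map; Option.get throws NoSuchElementException if
    // neither map contains the key. Names here are illustrative only.
    object GetOrDefaultDemo extends App {
      val paramMap        = scala.collection.mutable.Map.empty[String, Double]
      val defaultParamMap = scala.collection.mutable.Map("threshold" -> 0.5)

      def getOrDefault(name: String): Double =
        paramMap.get(name).orElse(defaultParamMap.get(name)).get

      assert(getOrDefault("threshold") == 0.5) // falls back to the default
      paramMap("threshold") = 0.8
      assert(getOrDefault("threshold") == 0.8) // explicit value shadows the default
      // getOrDefault("maxIter") would throw NoSuchElementException
    }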
@@ -193,7 +193,7 @@ class LogisticRegressionModel private[ml] ( override protected def copy(): LogisticRegressionModel = { val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept) - Params.inheritValues(this.extractValues(), this, m) + Params.inheritValues(this.extractParamMap(), this, m) m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 941e0297c7f2..10404548ccfd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -20,12 +20,12 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.param.{ParamMap, Params} import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} - /** * Params for probabilistic classification. */ @@ -38,8 +38,8 @@ private[classification] trait ProbabilisticClassifierParams fitting: Boolean, featuresDataType: DataType): StructType = { val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) - val map = extractValues(paramMap) - addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT) + val map = extractParamMap(paramMap) + SchemaUtils.appendColumn(parentSchema, map(probabilityCol), new VectorUDT) } } @@ -103,7 +103,7 @@ private[spark] abstract class ProbabilisticClassificationModel[ // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index b71368b6f420..c865eb9fe092 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -21,12 +21,12 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.Evaluator import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType - /** * :: AlphaComponent :: * @@ -44,7 +44,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params "metric name in evaluation (areaUnderROC|areaUnderPR)") /** @group getParam */ - def getMetricName: String = get(metricName) + def getMetricName: String = getOrDefault(metricName) /** @group setParam */ def setMetricName(value: String): this.type = set(metricName, value) @@ -58,11 +58,11 @@ class BinaryClassificationEvaluator extends Evaluator with Params setDefault(metricName -> "areaUnderROC") override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) val schema = dataset.schema - 
checkInputColumn(schema, map(rawPredictionCol), new VectorUDT) - checkInputColumn(schema, map(labelCol), DoubleType) + SchemaUtils.checkColumnType(schema, map(rawPredictionCol), new VectorUDT) + SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType) // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 6c02eb674ad3..b20f2fc49a8f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -31,8 +31,6 @@ import org.apache.spark.sql.types.DataType @AlphaComponent class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { - setDefault(numFeatures -> (1 << 18)) - /** * number of features * @group param @@ -40,11 +38,13 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { val numFeatures = new IntParam(this, "numFeatures", "number of features") /** @group getParam */ - def getNumFeatures: Int = get(numFeatures) + def getNumFeatures: Int = getOrDefault(numFeatures) /** @group setParam */ def setNumFeatures(value: Int): this.type = set(numFeatures, value) + setDefault(numFeatures -> (1 << 18)) + override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { val hashingTF = new feature.HashingTF(paramMap(numFeatures)) hashingTF.transform diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 5a94fa488ab6..decaeb0da624 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -31,8 +31,6 @@ import org.apache.spark.sql.types.DataType @AlphaComponent class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { - setDefault(p -> 2.0) - /** * Normalization in L^p^ space, p = 2 by default. 
* @group param @@ -40,11 +38,13 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] { val p = new DoubleParam(this, "p", "the p norm value") /** @group getParam */ - def getP: Double = get(p) + def getP: Double = getOrDefault(p) /** @group setParam */ def setP(value: Double): this.type = set(p, value) + setDefault(p -> 2.0) + override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = { val normalizer = new feature.Normalizer(paramMap(p)) normalizer.transform diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 29ad3b3dfdb4..1b102619b352 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -48,7 +48,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler().fit(input) val model = new StandardScalerModel(this, map, scaler) @@ -57,7 +57,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${map(inputCol)} must be a vector column") @@ -87,13 +87,13 @@ class StandardScalerModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) val scale = udf((v: Vector) => { scaler.transform(v) } : Vector) dataset.withColumn(map(outputCol), scale(col(map(inputCol)))) } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) val inputType = schema(map(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${map(inputCol)} must be a vector column") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 76a71f30b262..376a004858b4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -52,8 +52,6 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { @AlphaComponent class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] { - setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+") - /** * param for minimum token length, default is one to avoid returning empty strings * @group param @@ -64,7 +62,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize def setMinTokenLength(value: Int): this.type = set(minTokenLength, value) /** @group getParam */ - def getMinTokenLength: Int = get(minTokenLength) + def getMinTokenLength: Int = getOrDefault(minTokenLength) /** * param sets regex as splitting on gaps (true) or matching tokens 
(false) @@ -76,7 +74,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize def setGaps(value: Boolean): this.type = set(gaps, value) /** @group getParam */ - def getGaps: Boolean = get(gaps) + def getGaps: Boolean = getOrDefault(gaps) /** * param sets regex pattern used by tokenizer @@ -88,7 +86,9 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize def setPattern(value: String): this.type = set(pattern, value) /** @group getParam */ - def getPattern: String = get(pattern) + def getPattern: String = getOrDefault(pattern) + + setDefault(minTokenLength -> 1, gaps -> false, pattern -> "\\p{L}+|[^\\p{L}\\s]+") override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { str => val re = paramMap(pattern).r diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala index 82dfb6cdec5e..195333a5cc47 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.impl.estimator import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -54,14 +55,14 @@ private[spark] trait PredictorParams extends Params paramMap: ParamMap, fitting: Boolean, featuresDataType: DataType): StructType = { - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector - checkInputColumn(schema, map(featuresCol), featuresDataType) + SchemaUtils.checkColumnType(schema, map(featuresCol), featuresDataType) if (fitting) { // TODO: Allow other numeric types - checkInputColumn(schema, map(labelCol), DoubleType) + SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType) } - addOutputColumn(schema, map(predictionCol), DoubleType) + SchemaUtils.appendColumn(schema, map(predictionCol), DoubleType) } } @@ -99,7 +100,7 @@ private[spark] abstract class Predictor[ // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, paramMap, logging = true) - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) val model = train(dataset, map) Params.inheritValues(map, this, model) // copy params to model model @@ -142,7 +143,7 @@ private[spark] abstract class Predictor[ * and put it in an RDD with strong types. 
*/ protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = { - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) dataset.select(map(labelCol), map(featuresCol)) .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) @@ -202,7 +203,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel // Check schema transformSchema(dataset.schema, paramMap, logging = true) - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) // Prepare model val tmpModel = if (paramMap.size != 0) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 6e15919383eb..63a3cac680c4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -17,11 +17,12 @@ package org.apache.spark.ml.param +import java.lang.reflect.Modifier +import java.util.NoSuchElementException + import scala.annotation.varargs import scala.collection.mutable -import java.lang.reflect.Modifier - import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.Identifiable import org.apache.spark.sql.types.{DataType, StructField, StructType} @@ -103,10 +104,10 @@ case class ParamPair[T](param: Param[T], value: T) trait Params extends Identifiable with Serializable { /** - * Returns all params. The default implementation uses Java reflection to list all public methods - * that return [[Param]] and have no arguments. + * Returns all params sorted by their names. The default implementation uses Java reflection to + * list all public methods that have no arguments and return [[Param]]. */ - def params: Array[Param[_]] = { + lazy val params: Array[Param[_]] = { val methods = this.getClass.getMethods methods.filter { m => Modifier.isPublic(m.getModifiers) && @@ -135,25 +136,29 @@ trait Params extends Identifiable with Serializable { /** Checks whether a param is explicitly set. */ def isSet(param: Param[_]): Boolean = { - require(param.parent.eq(this)) - values.contains(param) + shouldOwn(param) + paramMap.contains(param) + } + + /** Checks whether a param is explicitly set or has a default value. */ + def isDefined(param: Param[_]): Boolean = { + shouldOwn(param) + defaultParamMap.contains(param) || paramMap.contains(param) } /** Gets a param by its name. */ def getParam(paramName: String): Param[Any] = { - val m = this.getClass.getMethod(paramName) - assert(Modifier.isPublic(m.getModifiers) && - classOf[Param[_]].isAssignableFrom(m.getReturnType) && - m.getParameterTypes.isEmpty) - m.invoke(this).asInstanceOf[Param[Any]] + params.find(_.name == paramName).getOrElse { + throw new NoSuchElementException(s"Param $paramName does not exist.") + }.asInstanceOf[Param[Any]] } /** * Sets a parameter in the embedded param map. */ protected final def set[T](param: Param[T], value: T): this.type = { - require(param.parent.eq(this)) - values.put(param.asInstanceOf[Param[Any]], value) + shouldOwn(param) + paramMap.put(param.asInstanceOf[Param[Any]], value) this } @@ -165,32 +170,33 @@ trait Params extends Identifiable with Serializable { } /** - * Gets the value of a parameter in the embedded param map. + * Optionally returns the user-supplied value of a param. 
*/ - protected final def get[T](param: Param[T]): T = { - require(param.parent.eq(this)) - values(param) + protected final def get[T](param: Param[T]): Option[T] = { + paramMap.get(param) } /** - * Internal param map. + * Gets the value of a param in the embedded param map or its default value. Throws an exception + * if neither is set. */ - private val values: ParamMap = ParamMap.empty - - /** - * Internal param map for default values. - */ - private val defaultValues: ParamMap = ParamMap.empty + protected final def getOrDefault[T](param: Param[T]): T = { + shouldOwn(param) + get(param).orElse(getDefault(param)).get + } /** - * Sets a default value. + * Sets a default value. Make sure that the input param is initialized before this gets called. */ protected final def setDefault[T](param: Param[T], value: T): this.type = { - require(param.parent.eq(this)) - defaultValues.put(param, value) + shouldOwn(param) + defaultParamMap.put(param, value) this } + /** + * Sets default values. Make sure that the input params are initialized before this gets called. + */ protected final def setDefault(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => setDefault(p.param.asInstanceOf[Param[Any]], p.value) @@ -198,36 +204,42 @@ trait Params extends Identifiable with Serializable { this } + /** + * Gets the default value of a parameter. + */ protected final def getDefault[T](param: Param[T]): Option[T] = { - require(param.parent.eq(this)) - defaultValues.get(param) + shouldOwn(param) + defaultParamMap.get(param) } - protected final def extractValues(extraValues: ParamMap = ParamMap.empty): ParamMap = { - defaultValues ++ values ++ extraValues + /** + * Tests whether the input param has a default value set. + */ + protected final def hasDefault[T](param: Param[T]): Boolean = { + shouldOwn(param) + defaultParamMap.contains(param) } /** - * Check whether the given schema contains an input column. - * @param colName Parameter name for the input column. - * @param dataType SQL DataType of the input column. + * Extracts the embedded default param values and user-supplied values, and then merges them with + * extra values from input into a flat param map, where the latter value is used if there exist + * conflicts. */ - protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = { - val actualDataType = schema(colName).dataType - require(actualDataType.equals(dataType), - s"Input column $colName must be of type $dataType" + - s" but was actually $actualDataType. Column param description: ${getParam(colName)}") + protected final def extractParamMap(extraParamMap: ParamMap = ParamMap.empty): ParamMap = { + defaultParamMap ++ paramMap ++ extraParamMap } - protected def addOutputColumn( - schema: StructType, - colName: String, - dataType: DataType): StructType = { - if (colName.length == 0) return schema - val fieldNames = schema.fieldNames - require(!fieldNames.contains(colName), s"Prediction column $colName already exists.") - val outputFields = schema.fields ++ Seq(StructField(colName, dataType, nullable = false)) - StructType(outputFields) + + + /** Internal param map for user-supplied values. */ + private val paramMap: ParamMap = ParamMap.empty + + /** Internal param map for default values. */ + private val defaultParamMap: ParamMap = ParamMap.empty + + /** Validates that the input param belongs to this instance. 
*/ + private def shouldOwn(param: Param[_]): Unit = { + require(param.parent.eq(this), s"Param $param does not belong to $this.") } } @@ -343,7 +355,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) /** * Returns a new param map that contains parameters in this map and the given map, - * where the latter overwrites this if there exists conflicts. + * where the latter overwrites this if there exist conflicts. */ def ++(other: ParamMap): ParamMap = { // TODO: Provide a better method name for Java users. diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala index 279b51d662f5..fb881160bf18 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala @@ -123,7 +123,7 @@ private[shared] object SharedParamCodeGen { | final val $name: $Param = new $Param(this, "$name", "$doc") |$setDefault | /** @group getParam */ - | final def get$Name: $T = get($name) + | final def get$Name: $T = getOrDefault($name) |} |""".stripMargin } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 5fbf49af507f..33a499103f1e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -38,7 +38,7 @@ trait HasRegParam extends Params { final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") /** @group getParam */ - final def getRegParam: Double = get(regParam) + final def getRegParam: Double = getOrDefault(regParam) } /** @@ -55,7 +55,7 @@ trait HasMaxIter extends Params { final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") /** @group getParam */ - final def getMaxIter: Int = get(maxIter) + final def getMaxIter: Int = getOrDefault(maxIter) } /** @@ -74,7 +74,7 @@ trait HasFeaturesCol extends Params { setDefault(featuresCol, "features") /** @group getParam */ - final def getFeaturesCol: String = get(featuresCol) + final def getFeaturesCol: String = getOrDefault(featuresCol) } /** @@ -93,7 +93,7 @@ trait HasLabelCol extends Params { setDefault(labelCol, "label") /** @group getParam */ - final def getLabelCol: String = get(labelCol) + final def getLabelCol: String = getOrDefault(labelCol) } /** @@ -112,7 +112,7 @@ trait HasPredictionCol extends Params { setDefault(predictionCol, "prediction") /** @group getParam */ - final def getPredictionCol: String = get(predictionCol) + final def getPredictionCol: String = getOrDefault(predictionCol) } /** @@ -131,7 +131,7 @@ trait HasRawPredictionCol extends Params { setDefault(rawPredictionCol, "rawPrediction") /** @group getParam */ - final def getRawPredictionCol: String = get(rawPredictionCol) + final def getRawPredictionCol: String = getOrDefault(rawPredictionCol) } /** @@ -150,7 +150,7 @@ trait HasProbabilityCol extends Params { setDefault(probabilityCol, "probability") /** @group getParam */ - final def getProbabilityCol: String = get(probabilityCol) + final def getProbabilityCol: String = getOrDefault(probabilityCol) } /** @@ -167,7 +167,7 @@ trait HasThreshold extends Params { final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") /** @group getParam */ - final def getThreshold: 
Double = get(threshold) + final def getThreshold: Double = getOrDefault(threshold) } /** @@ -184,7 +184,7 @@ trait HasInputCol extends Params { final val inputCol: Param[String] = new Param[String](this, "inputCol", "input column name") /** @group getParam */ - final def getInputCol: String = get(inputCol) + final def getInputCol: String = getOrDefault(inputCol) } /** @@ -201,7 +201,7 @@ trait HasOutputCol extends Params { final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name") /** @group getParam */ - final def getOutputCol: String = get(outputCol) + final def getOutputCol: String = getOrDefault(outputCol) } /** @@ -218,6 +218,6 @@ trait HasCheckpointInterval extends Params { final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval") /** @group getParam */ - final def getCheckpointInterval: Int = get(checkpointInterval) + final def getCheckpointInterval: Int = getOrDefault(checkpointInterval) } // scalastyle:on diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index e3b8c92ee4e8..bd793beba35b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -51,10 +51,6 @@ import org.apache.spark.util.random.XORShiftRandom private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam with HasPredictionCol with HasCheckpointInterval { - setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, - implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", - ratingCol -> "rating", nonnegative -> false) - /** * Param for rank of the matrix factorization. * @group param @@ -62,7 +58,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR val rank = new IntParam(this, "rank", "rank of the factorization") /** @group getParam */ - def getRank: Int = get(rank) + def getRank: Int = getOrDefault(rank) /** * Param for number of user blocks. @@ -71,7 +67,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks") /** @group getParam */ - def getNumUserBlocks: Int = get(numUserBlocks) + def getNumUserBlocks: Int = getOrDefault(numUserBlocks) /** * Param for number of item blocks. @@ -81,17 +77,16 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR new IntParam(this, "numItemBlocks", "number of item blocks") /** @group getParam */ - def getNumItemBlocks: Int = get(numItemBlocks) + def getNumItemBlocks: Int = getOrDefault(numItemBlocks) /** * Param to decide whether to use implicit preference. * @group param */ - val implicitPrefs = - new BooleanParam(this, "implicitPrefs", "whether to use implicit preference") + val implicitPrefs = new BooleanParam(this, "implicitPrefs", "whether to use implicit preference") /** @group getParam */ - def getImplicitPrefs: Boolean = get(implicitPrefs) + def getImplicitPrefs: Boolean = getOrDefault(implicitPrefs) /** * Param for the alpha parameter in the implicit preference formulation. 
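   * In the implicit-feedback formulation (Hu, Koren, and Volinsky, 2008), alpha scales
   * observed ratings into confidence weights, roughly:
   * {{{
   *   // confidence that user u prefers item i, given the raw observation r(u, i)
   *   c(u, i) = 1 + alpha * r(u, i)
   * }}}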
@@ -100,7 +95,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference") /** @group getParam */ - def getAlpha: Double = get(alpha) + def getAlpha: Double = getOrDefault(alpha) /** * Param for the column name for user ids. @@ -109,7 +104,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR val userCol = new Param[String](this, "userCol", "column name for user ids") /** @group getParam */ - def getUserCol: String = get(userCol) + def getUserCol: String = getOrDefault(userCol) /** * Param for the column name for item ids. @@ -118,7 +113,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR val itemCol = new Param[String](this, "itemCol", "column name for item ids") /** @group getParam */ - def getItemCol: String = get(itemCol) + def getItemCol: String = getOrDefault(itemCol) /** * Param for the column name for ratings. @@ -127,7 +122,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR val ratingCol = new Param[String](this, "ratingCol", "column name for ratings") /** @group getParam */ - def getRatingCol: String = get(ratingCol) + def getRatingCol: String = getOrDefault(ratingCol) /** * Param for whether to apply nonnegativity constraints. @@ -137,7 +132,11 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR this, "nonnegative", "whether to use nonnegative constraint for least squares") /** @group getParam */ - val getNonnegative: Boolean = get(nonnegative) + def getNonnegative: Boolean = getOrDefault(nonnegative) + + setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, + implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", + ratingCol -> "rating", nonnegative -> false) /** * Validates and transforms the input schema. 
@@ -146,7 +145,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR * @return output schema */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) assert(schema(map(userCol)).dataType == IntegerType) assert(schema(map(itemCol)).dataType== IntegerType) val ratingType = schema(map(ratingCol)).dataType @@ -175,7 +174,7 @@ class ALSModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { import dataset.sqlContext.implicits._ - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) val users = userFactors.toDF("id", "features") val items = itemFactors.toDF("id", "features") @@ -287,7 +286,7 @@ class ALS extends Estimator[ALSModel] with ALSParams { setCheckpointInterval(10) override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) val ratings = dataset .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType)) .map { row => diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 8c7e61c60dd1..26ca7459c4fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -93,7 +93,7 @@ class LinearRegressionModel private[ml] ( override protected def copy(): LinearRegressionModel = { val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept) - Params.inheritValues(extractValues(), this, m) + Params.inheritValues(extractParamMap(), this, m) m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 440e5be23181..4bb4ed813c00 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -32,8 +32,6 @@ import org.apache.spark.sql.types.StructType */ private[ml] trait CrossValidatorParams extends Params { - setDefault(numFolds -> 3) - /** * param for the estimator to be cross-validated * @group param @@ -41,7 +39,7 @@ private[ml] trait CrossValidatorParams extends Params { val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") /** @group getParam */ - def getEstimator: Estimator[_] = get(estimator) + def getEstimator: Estimator[_] = getOrDefault(estimator) /** * param for estimator param maps @@ -51,7 +49,7 @@ private[ml] trait CrossValidatorParams extends Params { new Param(this, "estimatorParamMaps", "param maps for the estimator") /** @group getParam */ - def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps) + def getEstimatorParamMaps: Array[ParamMap] = getOrDefault(estimatorParamMaps) /** * param for the evaluator for selection @@ -60,7 +58,7 @@ private[ml] trait CrossValidatorParams extends Params { val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") /** @group getParam */ - def getEvaluator: Evaluator = get(evaluator) + def getEvaluator: Evaluator = getOrDefault(evaluator) /** * param for number of folds for cross validation @@ -69,7 +67,9 @@ private[ml] trait CrossValidatorParams extends Params { val numFolds: IntParam = new IntParam(this, "numFolds", 
"number of folds for cross validation") /** @group getParam */ - def getNumFolds: Int = get(numFolds) + def getNumFolds: Int = getOrDefault(numFolds) + + setDefault(numFolds -> 3) } /** @@ -94,7 +94,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP def setNumFolds(value: Int): this.type = set(numFolds, value) override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) val schema = dataset.schema transformSchema(dataset.schema, paramMap, logging = true) val sqlCtx = dataset.sqlContext @@ -132,7 +132,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = extractValues(paramMap) + val map = extractParamMap(paramMap) map(estimator).transformSchema(schema, paramMap) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala new file mode 100644 index 000000000000..f5e9b2069d19 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.types.{DataType, StructField, StructType} + +/** + * :: DeveloperApi :: + * Utils for handling schemas. + */ +@DeveloperApi +object SchemaUtils { + + // TODO: Move the utility methods to SQL. + + /** + * Check whether the given schema contains an column of the required data type. + * @param colName Parameter name for the input column. + * @param dataType SQL DataType of the input column. + */ + def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = { + val actualDataType = schema(colName).dataType + require(actualDataType.equals(dataType), + s"Input column $colName must be of type $dataType but was actually $actualDataType.") + } + + /** + * Appends a new column to the input schema. 
+ * @param schema input schema + * @param colName new column name + * @param dataType new column type + * @return schema with the input column appended + */ + def appendColumn( + schema: StructType, + colName: String, + dataType: DataType): StructType = { + require(colName.nonEmpty) + val fieldNames = schema.fieldNames + require(!fieldNames.contains(colName), s"Column $colName already exists.") + val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false) + StructType(outputFields) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 856cb3ce8f4a..380c6476bd96 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -78,13 +78,13 @@ class ParamsSuite extends FunSuite { test("params") { val params = solver.params - assert(params.size === 2) + assert(params.length === 2) assert(params(0).eq(inputCol), "params must be ordered by name") assert(params(1).eq(maxIter)) assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) - intercept[NoSuchMethodException] { + intercept[NoSuchElementException] { solver.getParam("abc") } assert(!solver.isSet(inputCol)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 2b99247bd349..84ed5d23322e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -20,18 +20,18 @@ package org.apache.spark.ml.param /** A subclass of Params for testing. 
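 * Declares params before calling setDefault: Scala initializes vals in declaration
 * order, so registering a default for a not-yet-initialized param would dereference
 * null. A minimal sketch of the pattern:
 * {{{
 *   // setDefault(maxIter -> 10)  // wrong here: maxIter is still null
 *   val maxIter = new IntParam(this, "maxIter", "max number of iterations")
 *   setDefault(maxIter -> 10)     // right: after the val is initialized
 * }}}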
*/ class TestParams extends Params { - setDefault(maxIter -> 10) - val maxIter = new IntParam(this, "maxIter", "max number of iterations") def setMaxIter(value: Int): this.type = { set(maxIter, value); this } - def getMaxIter: Int = get(maxIter) + def getMaxIter: Int = getOrDefault(maxIter) val inputCol = new Param[String](this, "inputCol", "input column name") def setInputCol(value: String): this.type = { set(inputCol, value); this } - def getInputCol: String = get(inputCol) + def getInputCol: String = getOrDefault(inputCol) + + setDefault(maxIter -> 10) override def validate(paramMap: ParamMap) = { - val m = this.paramMap ++ paramMap + val m = extractParamMap(paramMap) require(m(maxIter) >= 0) require(m.contains(inputCol)) } From 48d0e843d8e1bb4d6db275316435829c81d628e3 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 8 Apr 2015 17:13:56 -0700 Subject: [PATCH 10/21] add remove and update explainParams --- .../org/apache/spark/ml/param/params.scala | 37 +++++++++++++++++-- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 63a3cac680c4..0ec848069e2e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -25,7 +25,6 @@ import scala.collection.mutable import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.Identifiable -import org.apache.spark.sql.types.{DataType, StructField, StructType} /** * :: AlphaComponent :: @@ -129,10 +128,25 @@ trait Params extends Identifiable with Serializable { */ def validate(): Unit = validate(ParamMap.empty) + /** + * Explain a param and optionally its default value and the user-supplied value. + */ + def explain(param: Param[_]): String = { + shouldOwn(param) + val valueStr = if (isDefined(param)) { + val defaultValueStr = getDefault(param).map("default: " + _) + val currentValueStr = get(param).map("current: " + _) + (defaultValueStr ++ currentValueStr).flatten.mkString("(", ", ", ")") + } else { + "(undefined)" + } + s"${param.name}: ${param.doc} $valueStr" + } + /** * Returns the documentation of all params. */ - def explainParams(): String = params.mkString("\n") + def explainParams(): String = params.map(explain).mkString("\n") /** Checks whether a param is explicitly set. */ def isSet(param: Param[_]): Boolean = { @@ -334,6 +356,13 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) map.contains(param.asInstanceOf[Param[Any]]) } + /** + * Removes a key from this map and returns the value previously associated with it as an option.
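+   * {{{
+   *   val m = ParamMap(maxIter -> 5) // maxIter from some Params instance (hypothetical)
+   *   m.remove(maxIter)              // Some(5)
+   *   m.remove(maxIter)              // None: the key is already gone
+   * }}}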
+ */ + def remove[T](param: Param[T]): Option[T] = { + map.remove(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] + } + /** * Filters this param map for the given parent. */ From 94fd98e94ebf3c4460e941ad36aab711791d107d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 8 Apr 2015 17:21:19 -0700 Subject: [PATCH 11/21] fix explain params --- mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 0ec848069e2e..edb5d28bafdc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -136,7 +136,7 @@ trait Params extends Identifiable with Serializable { val valueStr = if (isDefined(param)) { val defaultValueStr = getDefault(param).map("default: " + _) val currentValueStr = get(param).map("current: " + _) - (defaultValueStr ++ currentValueStr).flatten.mkString("(", ", ", ")") + (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")") } else { "(undefined)" } From 29b004ca130305969f7c48285c9c17fae8f662e1 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 8 Apr 2015 17:32:12 -0700 Subject: [PATCH 12/21] update ParamsSuite --- .../apache/spark/ml/param/ParamsSuite.scala | 37 +++++++++++++++---- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 380c6476bd96..da37162eccb9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -21,19 +21,20 @@ import org.scalatest.FunSuite class ParamsSuite extends FunSuite { - val solver = new TestParams() - import solver.{inputCol, maxIter} - test("param") { + val solver = new TestParams() + import solver.maxIter + assert(maxIter.name === "maxIter") assert(maxIter.doc === "max number of iterations") assert(maxIter.parent.eq(solver)) assert(maxIter.toString === "maxIter: max number of iterations") - assert(solver.getMaxIter === 10) - assert(!solver.isSet(inputCol)) } test("param pair") { + val solver = new TestParams() + import solver.maxIter + val pair0 = maxIter -> 5 val pair1 = maxIter.w(5) val pair2 = ParamPair(maxIter, 5) @@ -44,6 +45,9 @@ class ParamsSuite extends FunSuite { } test("param map") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + val map0 = ParamMap.empty assert(!map0.contains(maxIter)) @@ -77,23 +81,42 @@ class ParamsSuite extends FunSuite { } test("params") { + val solver = new TestParams() + import solver.{maxIter, inputCol} + val params = solver.params assert(params.length === 2) assert(params(0).eq(inputCol), "params must be ordered by name") assert(params(1).eq(maxIter)) - assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) + + assert(!solver.isSet(maxIter)) + assert(solver.isDefined(maxIter)) + assert(solver.getMaxIter === 10) + solver.setMaxIter(100) + assert(solver.isSet(maxIter)) + assert(solver.getMaxIter === 100) + assert(!solver.isSet(inputCol)) + assert(!solver.isDefined(inputCol)) + intercept[NoSuchElementException](solver.getInputCol) + + assert( + solver.explain(maxIter) === "maxIter: max number of iterations (default: 10, current: 100)") + assert(solver.explain(inputCol) === "inputCol: input column name (undefined)") + 
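+    // explain(param) renders "name: doc (default: ..., current: ...)" or "(undefined)";
+    // explainParams() joins one such line per param, ordered by name: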
assert(solver.explainParams() === Seq(inputCol, maxIter).map(solver.explain).mkString("\n")) + assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) intercept[NoSuchElementException] { solver.getParam("abc") } - assert(!solver.isSet(inputCol)) + intercept[IllegalArgumentException] { solver.validate() } solver.validate(ParamMap(inputCol -> "input")) solver.setInputCol("input") assert(solver.isSet(inputCol)) + assert(solver.isDefined(inputCol)) assert(solver.getInputCol === "input") solver.validate() intercept[IllegalArgumentException] { From d63b5cca9aa334ab36cbd9403fd6c5e8dddd4aa6 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 8 Apr 2015 18:00:07 -0700 Subject: [PATCH 13/21] fix examples --- .../spark/examples/ml/JavaDeveloperApiExample.java | 2 +- .../apache/spark/examples/ml/DeveloperApiExample.scala | 4 ++-- .../main/scala/org/apache/spark/ml/param/params.scala | 9 ++++++++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 1695ca25064d..eaf00d09f550 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -211,7 +211,7 @@ public Vector predictRaw(Vector features) { public MyJavaLogisticRegressionModel copy() { MyJavaLogisticRegressionModel m = new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_); - Params$.MODULE$.inheritValues(this.paramMap(), this, m); + Params$.MODULE$.inheritValues(this.extractParamMap(), this, m); return m; } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index c8f6d1cd0944..2245fa429fda 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -174,11 +174,11 @@ private class MyLogisticRegressionModel( * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. * - * This is used for the defaul implementation of [[transform()]]. + * This is used for the default implementation of [[transform()]]. */ override protected def copy(): MyLogisticRegressionModel = { val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights) - Params.inheritValues(this.paramMap, this, m) + Params.inheritValues(extractParamMap(), this, m) m } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index edb5d28bafdc..fe09cadf52d7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -249,10 +249,17 @@ trait Params extends Identifiable with Serializable { * extra values from input into a flat param map, where the latter value is used if there exist * conflicts. */ - protected final def extractParamMap(extraParamMap: ParamMap = ParamMap.empty): ParamMap = { + protected final def extractParamMap(extraParamMap: ParamMap): ParamMap = { defaultParamMap ++ paramMap ++ extraParamMap } + /** + * [[extractParamMap]] with no extra values. 
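+   * (An explicit no-arg overload rather than a default argument; unlike a Scala
+   * default argument, it is also directly usable from Java code.)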
+ */ + protected final def extractParamMap(): ParamMap = { + extractParamMap(ParamMap.empty) + } + /** Internal param map for user-supplied values. */ private val paramMap: ParamMap = ParamMap.empty From e938f8112aae4aa82c4df7c78e1829f9d21483e5 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 12 Apr 2015 23:01:12 -0700 Subject: [PATCH 14/21] update code to set default parameter values --- .../spark/ml/feature/StringIndexer.scala | 10 +++++--- .../spark/ml/feature/VectorAssembler.scala | 7 +++--- .../spark/ml/feature/VectorIndexer.scala | 25 +++++++++++-------- .../org/apache/spark/ml/param/params.scala | 1 - .../ml/param/shared/SharedParamCodeGen.scala | 1 + .../spark/ml/param/shared/sharedParams.scala | 17 +++++++++++++ 6 files changed, 42 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 61e6742e880d..4d960df357fe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -22,6 +22,8 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{StringType, StructType} @@ -34,8 +36,8 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap - checkInputColumn(schema, map(inputCol), StringType) + val map = extractParamMap(paramMap) + SchemaUtils.checkColumnType(schema, map(inputCol), StringType) val inputFields = schema.fields val outputColName = map(outputCol) require(inputFields.forall(_.name != outputColName), @@ -64,7 +66,7 @@ class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase // TODO: handle unseen labels override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue() val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray val model = new StringIndexerModel(this, map, labels) @@ -105,7 +107,7 @@ class StringIndexerModel private[ml] ( def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val indexer = udf { label: String => if (labelToIndex.contains(label)) { labelToIndex(label) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index d1b8f7e6e929..e567e069e7c0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -22,7 +22,8 @@ import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.Transformer -import 
org.apache.spark.ml.param.{HasInputCols, HasOutputCol, ParamMap} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute @@ -44,7 +45,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val assembleFunc = udf { r: Row => VectorAssembler.assemble(r.toSeq: _*) } @@ -61,7 +62,7 @@ class VectorAssembler extends Transformer with HasInputCols with HasOutputCol { } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val inputColNames = map(inputCols) val outputColName = map(outputCol) val inputDataTypes = inputColNames.map(name => schema(name).dataType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 8760960e1927..452faa06e202 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -18,10 +18,12 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute, Attribute, AttributeGroup} -import org.apache.spark.ml.param.{HasInputCol, HasOutputCol, IntParam, ParamMap, Params} +import org.apache.spark.ml.param.{IntParam, ParamMap, Params} +import org.apache.spark.ml.param.shared._ import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT} import org.apache.spark.sql.{Row, DataFrame} import org.apache.spark.sql.functions.callUDF @@ -40,11 +42,12 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu */ val maxCategories = new IntParam(this, "maxCategories", "Threshold for the number of values a categorical feature can take." 
+ - " If a feature is found to have > maxCategories values, then it is declared continuous.", - Some(20)) + " If a feature is found to have > maxCategories values, then it is declared continuous.") /** @group getParam */ - def getMaxCategories: Int = get(maxCategories) + def getMaxCategories: Int = getOrDefault(maxCategories) + + setDefault(maxCategories -> 20) } /** @@ -101,7 +104,7 @@ class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerPara override def fit(dataset: DataFrame, paramMap: ParamMap): VectorIndexerModel = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val firstRow = dataset.select(map(inputCol)).take(1) require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") val numFeatures = firstRow(0).getAs[Vector](0).size @@ -120,12 +123,12 @@ class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerPara override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { // We do not transfer feature metadata since we do not know what types of features we will // produce in transform(). - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val dataType = new VectorUDT require(map.contains(inputCol), s"VectorIndexer requires input column parameter: $inputCol") require(map.contains(outputCol), s"VectorIndexer requires output column parameter: $outputCol") - checkInputColumn(schema, map(inputCol), dataType) - addOutputColumn(schema, map(outputCol), dataType) + SchemaUtils.checkColumnType(schema, map(inputCol), dataType) + SchemaUtils.appendColumn(schema, map(outputCol), dataType) } } @@ -320,7 +323,7 @@ class VectorIndexerModel private[ml] ( override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val newField = prepOutputField(dataset.schema, map) val newCol = callUDF(transformFunc, new VectorUDT, dataset(map(inputCol))) // For now, just check the first row of inputCol for vector length. @@ -334,13 +337,13 @@ class VectorIndexerModel private[ml] ( } override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - val map = this.paramMap ++ paramMap + val map = extractParamMap(paramMap) val dataType = new VectorUDT require(map.contains(inputCol), s"VectorIndexerModel requires input column parameter: $inputCol") require(map.contains(outputCol), s"VectorIndexerModel requires output column parameter: $outputCol") - checkInputColumn(schema, map(inputCol), dataType) + SchemaUtils.checkColumnType(schema, map(inputCol), dataType) val origAttrGroup = AttributeGroup.fromStructField(schema(map(inputCol))) val origNumFeatures: Option[Int] = if (origAttrGroup.attributes.nonEmpty) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 1c1d7219f375..fe09cadf52d7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -201,7 +201,6 @@ trait Params extends Identifiable with Serializable { } /** -<<<<<<< HEAD * Gets the value of a param in the embedded param map or its default value. Throws an exception * if neither is set. 
*/ diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala index 3f3519c35ceb..f99be87a2451 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala @@ -42,6 +42,7 @@ private[shared] object SharedParamCodeGen { "column name for predicted class conditional probabilities", Some("\"probability\"")), ParamDesc[Double]("threshold", "threshold in prediction"), ParamDesc[String]("inputCol", "input column name"), + ParamDesc[Array[String]]("labelCols", "label column names"), ParamDesc[String]("outputCol", "output column name"), ParamDesc[Int]("checkpointInterval", "checkpoint interval"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true"))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 42af80058f64..cc05af2d7ebc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -187,6 +187,23 @@ trait HasInputCol extends Params { final def getInputCol: String = getOrDefault(inputCol) } +/** + * :: DeveloperApi :: + * Trait for shared param inputCol. + */ +@DeveloperApi +trait HasInputCols extends Params { + + /** + * Param for input column names. + * @group param + */ + final val inputCols: Param[Array[String]] = new Param(this, "inputCols", "input column names") + + /** @group getParam */ + final def getInputCols: Array[String] = getOrDefault(inputCols) +} + /** * :: DeveloperApi :: * Trait for shared param outputCol. From 2737c2d6afa6221a71975e159ea8a5a16b5f846a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 12 Apr 2015 23:03:58 -0700 Subject: [PATCH 15/21] rename SharedParamCodeGen to SharedParamsCodeGen --- ...SharedParamCodeGen.scala => SharedParamsCodeGen.scala} | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) rename mllib/src/main/scala/org/apache/spark/ml/param/shared/{SharedParamCodeGen.scala => SharedParamsCodeGen.scala} (97%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala similarity index 97% rename from mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala rename to mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index f99be87a2451..960990ffdd4e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -24,10 +24,10 @@ import scala.reflect.ClassTag /** * Code generator for shared params (sharedParams.scala). 
Run under the Spark folder with * {{{ - * build/sbt "mllib/runMain org.apache.spark.ml.param.shared.SharedParamCodeGen" + * build/sbt "mllib/runMain org.apache.spark.ml.param.shared.SharedParamsCodeGen" * }}} */ -private[shared] object SharedParamCodeGen { +private[shared] object SharedParamsCodeGen { def main(args: Array[String]): Unit = { val params = Seq( @@ -42,7 +42,7 @@ private[shared] object SharedParamCodeGen { "column name for predicted class conditional probabilities", Some("\"probability\"")), ParamDesc[Double]("threshold", "threshold in prediction"), ParamDesc[String]("inputCol", "input column name"), - ParamDesc[Array[String]]("labelCols", "label column names"), + ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name"), ParamDesc[Int]("checkpointInterval", "checkpoint interval"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true"))) @@ -155,7 +155,7 @@ private[shared] object SharedParamCodeGen { |import org.apache.spark.annotation.DeveloperApi |import org.apache.spark.ml.param._ | - |// DO NOT MODIFY THIS FILE! It was generated by SharedParamCodeGen. + |// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. | |// scalastyle:off |""".stripMargin From 4fee9e78d6d75bd919fc0aaa647798ddd3e1c105 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 12 Apr 2015 23:07:54 -0700 Subject: [PATCH 16/21] re-gen shared params --- .../org/apache/spark/ml/param/shared/sharedParams.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index cc05af2d7ebc..ca77a14d433e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.param.shared import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param._ -// DO NOT MODIFY THIS FILE! It was generated by SharedParamCodeGen. +// DO NOT MODIFY THIS FILE! It was generated by SharedParamsCodeGen. // scalastyle:off @@ -189,7 +189,7 @@ trait HasInputCol extends Params { /** * :: DeveloperApi :: - * Trait for shared param inputCol. + * Trait for shared param inputCols. */ @DeveloperApi trait HasInputCols extends Params { @@ -198,7 +198,7 @@ trait HasInputCols extends Params { * Param for input column names. 
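 * (e.g. Array("hour", "temperature") when several raw columns feed one transformer)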
* @group param */ - final val inputCols: Param[Array[String]] = new Param(this, "inputCols", "input column names") + final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names") /** @group getParam */ final def getInputCols: Array[String] = getOrDefault(inputCols) From eec2264ef8e49a0be5f2b7539c8bba6e4279c01a Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 12 Apr 2015 23:15:46 -0700 Subject: [PATCH 17/21] make get* public in Params --- .../scala/org/apache/spark/ml/param/params.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index fe09cadf52d7..c3a36ecbdfba 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -149,13 +149,13 @@ trait Params extends Identifiable with Serializable { def explainParams(): String = params.map(explain).mkString("\n") /** Checks whether a param is explicitly set. */ - def isSet(param: Param[_]): Boolean = { + final def isSet(param: Param[_]): Boolean = { shouldOwn(param) paramMap.contains(param) } /** Checks whether a param is explicitly set or has a default value. */ - def isDefined(param: Param[_]): Boolean = { + final def isDefined(param: Param[_]): Boolean = { shouldOwn(param) defaultParamMap.contains(param) || paramMap.contains(param) } @@ -186,7 +186,7 @@ trait Params extends Identifiable with Serializable { /** * Optionally returns the user-supplied value of a param. */ - protected final def get[T](param: Param[T]): Option[T] = { + final def get[T](param: Param[T]): Option[T] = { shouldOwn(param) paramMap.get(param) } @@ -194,7 +194,7 @@ trait Params extends Identifiable with Serializable { /** * Clears the user-supplied value for the input param. */ - protected final def clear(param: Param[_]): this.type = { + final def clear(param: Param[_]): this.type = { shouldOwn(param) paramMap.remove(param) this @@ -204,7 +204,7 @@ trait Params extends Identifiable with Serializable { * Gets the value of a param in the embedded param map or its default value. Throws an exception * if neither is set. */ - protected final def getOrDefault[T](param: Param[T]): T = { + final def getOrDefault[T](param: Param[T]): T = { shouldOwn(param) get(param).orElse(getDefault(param)).get } @@ -231,7 +231,7 @@ trait Params extends Identifiable with Serializable { /** * Gets the default value of a parameter. */ - protected final def getDefault[T](param: Param[T]): Option[T] = { + final def getDefault[T](param: Param[T]): Option[T] = { shouldOwn(param) defaultParamMap.get(param) } @@ -239,7 +239,7 @@ trait Params extends Identifiable with Serializable { /** * Tests whether the input param has a default value set. 
*/ - protected final def hasDefault[T](param: Param[T]): Boolean = { + final def hasDefault[T](param: Param[T]): Boolean = { shouldOwn(param) defaultParamMap.contains(param) } From 409e2d58b159c43f772288e4c0ffa99c91c19071 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 13 Apr 2015 16:59:55 -0700 Subject: [PATCH 18/21] address comments --- .../apache/spark/ml/classification/Classifier.scala | 1 + .../scala/org/apache/spark/ml/param/params.scala | 12 +++++++++--- .../spark/ml/param/shared/SharedParamsCodeGen.scala | 2 +- .../scala/org/apache/spark/ml/util/SchemaUtils.scala | 4 ++-- .../org/apache/spark/ml/param/ParamsSuite.scala | 3 +++ 5 files changed, 16 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 7a3f769253fc..29339c98f51c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} + /** * :: DeveloperApi :: * Params for classification. diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index c3a36ecbdfba..1b3c3b5d87eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -210,7 +210,10 @@ trait Params extends Identifiable with Serializable { } /** - * Sets a default value. Make sure that the input param is initialized before this gets called. + * Sets a default value for a param. + * @param param param to set the default value. Make sure that this param is initialized before + * this method gets called. + * @param value the default value */ protected final def setDefault[T](param: Param[T], value: T): this.type = { shouldOwn(param) @@ -219,7 +222,10 @@ trait Params extends Identifiable with Serializable { } /** - * Sets default values. Make sure that the input params are initialized before this gets called. + * Sets default values for a list of params. + * @param paramPairs a list of param pairs that specify params and their default values to set + * respectively. Make sure that the params are initialized before this method + * gets called. */ protected final def setDefault(paramPairs: ParamPair[_]*): this.type = { paramPairs.foreach { p => @@ -247,7 +253,7 @@ trait Params extends Identifiable with Serializable { /** * Extracts the embedded default param values and user-supplied values, and then merges them with * extra values from input into a flat param map, where the latter value is used if there exist - * conflicts. + * conflicts, i.e., with ordering: default param values < user-supplied values < extraParamMap. 
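+   * A sketch of the precedence, assuming a maxIter param with a registered default:
+   * {{{
+   *   setDefault(maxIter -> 10)
+   *   set(maxIter, 20)
+   *   extractParamMap()(maxIter)                        // 20: user value beats default
+   *   extractParamMap(ParamMap(maxIter -> 30))(maxIter) // 30: the extra map beats both
+   * }}}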
*/ protected final def extractParamMap(extraParamMap: ParamMap): ParamMap = { defaultParamMap ++ paramMap ++ extraParamMap diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 960990ffdd4e..95d7e64790c7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -40,7 +40,7 @@ private[shared] object SharedParamsCodeGen { Some("\"rawPrediction\"")), ParamDesc[String]("probabilityCol", "column name for predicted class conditional probabilities", Some("\"probability\"")), - ParamDesc[Double]("threshold", "threshold in prediction"), + ParamDesc[Double]("threshold", "threshold in binary classification prediction"), ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name"), diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 90bce5977815..0383bf0b382b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -30,7 +30,7 @@ object SchemaUtils { // TODO: Move the utility methods to SQL. /** - * Check whether the given schema contains an column of the required data type. + * Check whether the given schema contains a column of the required data type. * @param colName column name * @param dataType required column data type */ @@ -43,7 +43,7 @@ object SchemaUtils { /** * Appends a new column to the input schema. This fails if the given output column already exists. * @param schema input schema - * @param colName new column name. If this column name is en empty string "", this method returns + * @param colName new column name. If this column name is an empty string "", this method returns * the input schema unchanged. This allows users to disable output columns. 
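 *                (so, per this contract, appendColumn(schema, "", DoubleType) simply
 *                returns schema)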
* @param dataType new column data type * @return new schema with the input column appended diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index da37162eccb9..d77eea559056 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -126,5 +126,8 @@ class ParamsSuite extends FunSuite { intercept[IllegalArgumentException] { solver.validate() } + + solver.clear(maxIter) + assert(!solver.isDefined(maxIter)) } } From 38b78c71cf88a7c3563cd08b710b1fc09316db00 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 13 Apr 2015 17:08:24 -0700 Subject: [PATCH 19/21] update Param.toString and remove Params.explain() --- .../org/apache/spark/ml/param/params.scala | 32 +++++++++---------- .../apache/spark/ml/param/ParamsSuite.scala | 14 ++++---- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 1b3c3b5d87eb..7d807652a702 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -49,7 +49,20 @@ class Param[T] (val parent: Params, val name: String, val doc: String) extends S */ def ->(value: T): ParamPair[T] = ParamPair(this, value) - override def toString: String = s"$name: $doc" + /** + * Converts this param's name, doc, and optionally its default value and the user-supplied + * value in its parent to string. + */ + override def toString: String = { + val valueStr = if (parent.isDefined(this)) { + val defaultValueStr = parent.getDefault(this).map("default: " + _) + val currentValueStr = parent.get(this).map("current: " + _) + (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")") + } else { + "(undefined)" + } + s"$name: $doc $valueStr" + } } // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... @@ -128,25 +141,10 @@ trait Params extends Identifiable with Serializable { */ def validate(): Unit = validate(ParamMap.empty) - /** - * Explain a param and optionally its default value and the user-supplied value. - */ - def explain(param: Param[_]): String = { - shouldOwn(param) - val valueStr = if (isDefined(param)) { - val defaultValueStr = getDefault(param).map("default: " + _) - val currentValueStr = get(param).map("current: " + _) - (defaultValueStr ++ currentValueStr).mkString("(", ", ", ")") - } else { - "(undefined)" - } - s"${param.name}: ${param.doc} $valueStr" - } - /** * Returns the documentation of all params. */ - def explainParams(): String = params.map(explain).mkString("\n") + def explainParams(): String = params.mkString("\n") /** Checks whether a param is explicitly set. 
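 * A default alone makes a param defined but not set:
 * {{{
 *   setDefault(maxIter -> 10)
 *   isSet(maxIter)     // false
 *   isDefined(maxIter) // true
 *   set(maxIter, 5)
 *   isSet(maxIter)     // true
 * }}}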
*/ final def isSet(param: Param[_]): Boolean = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index d77eea559056..9bba214f8ad7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -23,12 +23,17 @@ class ParamsSuite extends FunSuite { test("param") { val solver = new TestParams() - import solver.maxIter + import solver.{maxIter, inputCol} assert(maxIter.name === "maxIter") assert(maxIter.doc === "max number of iterations") assert(maxIter.parent.eq(solver)) - assert(maxIter.toString === "maxIter: max number of iterations") + assert(maxIter.toString === "maxIter: max number of iterations (default: 10)") + + solver.setMaxIter(5) + assert(maxIter.toString === "maxIter: max number of iterations (default: 10, current: 5)") + + assert(inputCol.toString === "inputCol: input column name (undefined)") } test("param pair") { @@ -99,10 +104,7 @@ class ParamsSuite extends FunSuite { assert(!solver.isDefined(inputCol)) intercept[NoSuchElementException](solver.getInputCol) - assert( - solver.explain(maxIter) === "maxIter: max number of iterations (default: 10, current: 100)") - assert(solver.explain(inputCol) === "inputCol: input column name (undefined)") - assert(solver.explainParams() === Seq(inputCol, maxIter).map(solver.explain).mkString("\n")) + assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) assert(solver.getParam("inputCol").eq(inputCol)) assert(solver.getParam("maxIter").eq(maxIter)) From 26ae2d7898fa598e6fef8a91721992f1d7c793e5 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 13 Apr 2015 17:12:31 -0700 Subject: [PATCH 20/21] re-gen code and mark clear protected --- mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 2 +- .../scala/org/apache/spark/ml/param/shared/sharedParams.scala | 4 ++-- .../test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 2 +- .../src/test/scala/org/apache/spark/ml/param/TestParams.scala | 2 ++ 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 7d807652a702..849c60433c77 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -192,7 +192,7 @@ trait Params extends Identifiable with Serializable { /** * Clears the user-supplied value for the input param. */ - final def clear(param: Param[_]): this.type = { + protected final def clear(param: Param[_]): this.type = { shouldOwn(param) paramMap.remove(param) this diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index ca77a14d433e..72b08bf27648 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -161,10 +161,10 @@ trait HasProbabilityCol extends Params { trait HasThreshold extends Params { /** - * Param for threshold in prediction. + * Param for threshold in binary classification prediction. 
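+   * (Typically compared against the predicted positive-class probability: predict the
+   * positive label when the probability exceeds this value.)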
* @group param */ - final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") + final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction") /** @group getParam */ final def getThreshold: Double = getOrDefault(threshold) diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 9bba214f8ad7..4a0f121c6763 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -129,7 +129,7 @@ class ParamsSuite extends FunSuite { solver.validate() } - solver.clear(maxIter) + solver.clearMaxIter() assert(!solver.isDefined(maxIter)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala index 6109ed98323e..8f9ab687c05c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -35,4 +35,6 @@ class TestParams extends Params { require(m(maxIter) >= 0) require(m.contains(inputCol)) } + + def clearMaxIter(): this.type = clear(maxIter) } From d19236dd12d57ced30374e2b1ac22b720b5c42d8 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 13 Apr 2015 17:57:07 -0700 Subject: [PATCH 21/21] fix test --- .../src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 4a0f121c6763..88ea679eeaad 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -130,6 +130,6 @@ class ParamsSuite extends FunSuite { } solver.clearMaxIter() - assert(!solver.isDefined(maxIter)) + assert(!solver.isSet(maxIter)) } }
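Taken together, the series leaves Params with defaults, user-supplied values, and
explain support; a condensed sketch of the resulting API, using the TestParams helper
from the suite above:

{{{
  val solver = new TestParams()
  solver.isSet(solver.maxIter)     // false: only the default (10) is registered
  solver.isDefined(solver.maxIter) // true
  solver.getMaxIter                // 10, via getOrDefault
  solver.setMaxIter(100)
  solver.explainParams()           // maxIter line reads "(default: 10, current: 100)"
  solver.clearMaxIter()
  solver.isSet(solver.maxIter)     // false again; the default value survives
}}}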