Commit 971b95b

[SPARK-5957][ML] better handling of parameters
The design doc was posted on the JIRA page. Python changes will be in a follow-up PR. jkbradley

1. Use codegen for shared params.
2. Move shared params to package `ml.param.shared`.
3. Set default values in `Params` instead of in `Param`.
4. Add a few methods to `Params` and `ParamMap`.
5. Move schema handling to `SchemaUtils` from `Params`.

- [x] check visibility of the methods added

Author: Xiangrui Meng <[email protected]>

Closes apache#5431 from mengxr/SPARK-5957 and squashes the following commits:

d19236d [Xiangrui Meng] fix test
26ae2d7 [Xiangrui Meng] re-gen code and mark clear protected
38b78c7 [Xiangrui Meng] update Param.toString and remove Params.explain()
409e2d5 [Xiangrui Meng] address comments
2d637bd [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5957
eec2264 [Xiangrui Meng] make get* public in Params
4090d95 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5957
4fee9e7 [Xiangrui Meng] re-gen shared params
2737c2d [Xiangrui Meng] rename SharedParamCodeGen to SharedParamsCodeGen
e938f81 [Xiangrui Meng] update code to set default parameter values
28ed322 [Xiangrui Meng] merge master
55be1f3 [Xiangrui Meng] merge master
d63b5cc [Xiangrui Meng] fix examples
29b004c [Xiangrui Meng] update ParamsSuite
94fd98e [Xiangrui Meng] fix explain params
48d0e84 [Xiangrui Meng] add remove and update explainParams
4ac6348 [Xiangrui Meng] move schema utils to SchemaUtils add a few methods to Params
0d9594e [Xiangrui Meng] add getOrElse to ParamMap
eeeffe8 [Xiangrui Meng] map ++ paramMap => extractValues
0d3fc5b [Xiangrui Meng] setDefault after param
a9dbf59 [Xiangrui Meng] minor updates
d9302b8 [Xiangrui Meng] generate default values
1c72579 [Xiangrui Meng] pass test compile
abb7a3b [Xiangrui Meng] update default values handling
dcab97a [Xiangrui Meng] add codegen for shared params
1 parent 0ba3fdd commit 971b95b
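The central pattern in this commit: a `Param` no longer carries its own default value; instead the owning `Params` object registers defaults with `setDefault`, and getters read through `getOrDefault`. A minimal sketch distilled from the HashingTF and LogisticRegression hunks below (`MyTransformerParams` and `numBuckets` are hypothetical names, not part of this PR):

    import org.apache.spark.ml.param.{IntParam, Params}

    // Hypothetical trait illustrating the new convention: the default lives in
    // Params (registered via setDefault), not inside the Param constructor.
    trait MyTransformerParams extends Params {

      // Before this PR the default rode along in the constructor:
      //   new IntParam(this, "numBuckets", "number of buckets", Some(16))
      val numBuckets = new IntParam(this, "numBuckets", "number of buckets")

      // Returns the user-set value if present, otherwise the registered default.
      def getNumBuckets: Int = getOrDefault(numBuckets)

      // Register the default right after declaring the param.
      setDefault(numBuckets -> 16)
    }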

File tree

27 files changed: +820 additions, -396 deletions


examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java

Lines changed: 2 additions & 2 deletions
@@ -116,7 +116,7 @@ class MyJavaLogisticRegression
    */
   IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations");

-  int getMaxIter() { return (Integer) get(maxIter); }
+  int getMaxIter() { return (Integer) getOrDefault(maxIter); }

   public MyJavaLogisticRegression() {
     setMaxIter(100);
@@ -211,7 +211,7 @@ public Vector predictRaw(Vector features) {
   public MyJavaLogisticRegressionModel copy() {
     MyJavaLogisticRegressionModel m =
         new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_);
-    Params$.MODULE$.inheritValues(this.paramMap(), this, m);
+    Params$.MODULE$.inheritValues(this.extractParamMap(), this, m);
     return m;
   }
 }

examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala

Lines changed: 3 additions & 3 deletions
@@ -99,7 +99,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams {
    * class since the maxIter parameter is only used during training (not in the Model).
    */
   val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
-  def getMaxIter: Int = get(maxIter)
+  def getMaxIter: Int = getOrDefault(maxIter)
 }

 /**
@@ -174,11 +174,11 @@ private class MyLogisticRegressionModel(
    * Create a copy of the model.
    * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
    *
-   * This is used for the defaul implementation of [[transform()]].
+   * This is used for the default implementation of [[transform()]].
    */
   override protected def copy(): MyLogisticRegressionModel = {
     val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights)
-    Params.inheritValues(this.paramMap, this, m)
+    Params.inheritValues(extractParamMap(), this, m)
     m
   }
 }

mllib/src/main/scala/org/apache/spark/ml/Estimator.scala

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
    */
   @varargs
   def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = {
-    val map = new ParamMap().put(paramPairs: _*)
+    val map = ParamMap(paramPairs: _*)
     fit(dataset, map)
   }
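The new `ParamMap(...)` factory shown above makes one-off maps terser. A hedged usage sketch (`trainingData` is an arbitrary DataFrame assumed in scope; the `param -> value` pair syntax is the same one the `setDefault` hunks below rely on):

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.param.ParamMap

    val lr = new LogisticRegression()
    // Before: new ParamMap().put(lr.regParam -> 0.01, lr.maxIter -> 50)
    // After: the companion object accepts the pairs directly.
    val overrides = ParamMap(lr.regParam -> 0.01, lr.maxIter -> 50)
    val model = lr.fit(trainingData, overrides)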

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 5 additions & 5 deletions
@@ -84,7 +84,7 @@ class Pipeline extends Estimator[PipelineModel] {
   /** param for pipeline stages */
   val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")
   def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
-  def getStages: Array[PipelineStage] = get(stages)
+  def getStages: Array[PipelineStage] = getOrDefault(stages)

   /**
    * Fits the pipeline to the input dataset with additional parameters. If a stage is an
@@ -101,7 +101,7 @@ class Pipeline extends Estimator[PipelineModel] {
    */
   override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = {
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
+    val map = extractParamMap(paramMap)
     val theStages = map(stages)
     // Search for the last estimator.
     var indexOfLastEstimator = -1
@@ -138,7 +138,7 @@ class Pipeline extends Estimator[PipelineModel] {
   }

   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    val map = this.paramMap ++ paramMap
+    val map = extractParamMap(paramMap)
     val theStages = map(stages)
     require(theStages.toSet.size == theStages.size,
       "Cannot have duplicate components in a pipeline.")
@@ -177,14 +177,14 @@ class PipelineModel private[ml] (

   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
-    val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+    val map = fittingParamMap ++ extractParamMap(paramMap)
     transformSchema(dataset.schema, map, logging = true)
     stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
   }

   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
-    val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+    val map = fittingParamMap ++ extractParamMap(paramMap)
     stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
   }
 }
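`extractParamMap(paramMap)` replaces the repeated `this.paramMap ++ paramMap` idiom. Per the precedence comment in the hunks above, it merges defaults, values set on the instance, and the caller-supplied map, with later sources winning: per-call values beat instance values, which beat defaults. A small sketch of the observable effect (assuming `trainingData` is in scope and a `setMaxIter` setter as in the LogisticRegression file below):

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.param.ParamMap

    val lr = new LogisticRegression().setMaxIter(10)  // instance-level value
    // The per-call map takes precedence inside fit(), so training runs
    // with maxIter = 5; with no extra map it would run with maxIter = 10.
    val model = lr.fit(trainingData, ParamMap(lr.maxIter -> 5))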

mllib/src/main/scala/org/apache/spark/ml/Transformer.scala

Lines changed: 3 additions & 2 deletions
@@ -22,6 +22,7 @@ import scala.annotation.varargs
 import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
@@ -86,7 +87,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
   protected def validateInputType(inputType: DataType): Unit = {}

   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    val map = this.paramMap ++ paramMap
+    val map = extractParamMap(paramMap)
     val inputType = schema(map(inputCol)).dataType
     validateInputType(inputType)
     if (schema.fieldNames.contains(map(outputCol))) {
@@ -99,7 +100,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O

   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
+    val map = extractParamMap(paramMap)
     dataset.withColumn(map(outputCol),
       callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol))))
   }

mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala

Lines changed: 9 additions & 8 deletions
@@ -17,12 +17,14 @@

 package org.apache.spark.ml.classification

-import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
 import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
-import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol}
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.ml.param.shared.HasRawPredictionCol
+import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.functions._
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}


@@ -42,8 +44,8 @@ private[spark] trait ClassifierParams extends PredictorParams
       fitting: Boolean,
       featuresDataType: DataType): StructType = {
     val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
-    val map = this.paramMap ++ paramMap
-    addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT)
+    val map = extractParamMap(paramMap)
+    SchemaUtils.appendColumn(parentSchema, map(rawPredictionCol), new VectorUDT)
   }
 }

@@ -67,8 +69,7 @@ private[spark] abstract class Classifier[
   with ClassifierParams {

   /** @group setParam */
-  def setRawPredictionCol(value: String): E =
-    set(rawPredictionCol, value).asInstanceOf[E]
+  def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]

   // TODO: defaultEvaluator (follow-up PR)
 }
@@ -109,7 +110,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur

     // Check schema
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
+    val map = extractParamMap(paramMap)

     // Prepare model
     val tmpModel = if (paramMap.size != 0) {
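`SchemaUtils.appendColumn` takes over the schema handling that used to live on `Params` as `addOutputColumn`. A hedged sketch of a custom params trait in the same style as `ClassifierParams` above (`MyScorerParams` and `scoreCol` are hypothetical names, not part of this PR):

    import org.apache.spark.ml.param.{Param, ParamMap, Params}
    import org.apache.spark.ml.util.SchemaUtils
    import org.apache.spark.sql.types.{DoubleType, StructType}

    // Hypothetical trait following the validateAndTransformSchema pattern above.
    trait MyScorerParams extends Params {
      val scoreCol: Param[String] = new Param(this, "scoreCol", "output score column")
      setDefault(scoreCol -> "score")

      protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
        val map = extractParamMap(paramMap)
        // Returns the input schema with the score column appended.
        SchemaUtils.appendColumn(schema, map(scoreCol), DoubleType)
      }
    }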

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 7 additions & 11 deletions
@@ -19,20 +19,22 @@ package org.apache.spark.ml.classification

 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
 import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.storage.StorageLevel


 /**
  * Params for logistic regression.
  */
 private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
-  with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold
+  with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold {

+  setDefault(regParam -> 0.1, maxIter -> 100, threshold -> 0.5)
+}

 /**
  * :: AlphaComponent ::
@@ -45,10 +47,6 @@ class LogisticRegression
   extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
   with LogisticRegressionParams {

-  setRegParam(0.1)
-  setMaxIter(100)
-  setThreshold(0.5)
-
   /** @group setParam */
   def setRegParam(value: Double): this.type = set(regParam, value)

@@ -100,8 +98,6 @@ class LogisticRegressionModel private[ml] (
   extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
   with LogisticRegressionParams {

-  setThreshold(0.5)
-
   /** @group setParam */
   def setThreshold(value: Double): this.type = set(threshold, value)

@@ -123,7 +119,7 @@ class LogisticRegressionModel private[ml] (
     // Check schema
     transformSchema(dataset.schema, paramMap, logging = true)

-    val map = this.paramMap ++ paramMap
+    val map = extractParamMap(paramMap)

     // Output selected columns only.
     // This is a bit complicated since it tries to avoid repeated computation.
@@ -184,7 +180,7 @@ class LogisticRegressionModel private[ml] (
    * The behavior of this can be adjusted using [[threshold]].
    */
   override protected def predict(features: Vector): Double = {
-    if (score(features) > paramMap(threshold)) 1 else 0
+    if (score(features) > getThreshold) 1 else 0
   }

   override protected def predictProbabilities(features: Vector): Vector = {
@@ -199,7 +195,7 @@ class LogisticRegressionModel private[ml] (

   override protected def copy(): LogisticRegressionModel = {
     val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept)
-    Params.inheritValues(this.paramMap, this, m)
+    Params.inheritValues(this.extractParamMap(), this, m)
     m
   }
 }
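Moving the three `set*(...)` calls into a single `setDefault(...)` in the shared trait means the estimator and the model (both mix in `LogisticRegressionParams`) pick up identical defaults, and a default can no longer be confused with an explicitly set value. The same idiom for a hypothetical trait (`MyRegressorParams` is illustrative; `HasRegParam` and `HasMaxIter` come from `ml.param.shared` as in the imports above):

    import org.apache.spark.ml.param.shared.{HasMaxIter, HasRegParam}

    // One setDefault call, shared by every class that mixes the trait in.
    trait MyRegressorParams extends HasRegParam with HasMaxIter {
      setDefault(regParam -> 0.01, maxIter -> 50)
    }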

mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala

Lines changed: 6 additions & 5 deletions
@@ -18,13 +18,14 @@
 package org.apache.spark.ml.classification

 import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params}
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, StructType}

-
 /**
  * Params for probabilistic classification.
  */
@@ -37,8 +38,8 @@ private[classification] trait ProbabilisticClassifierParams
       fitting: Boolean,
       featuresDataType: DataType): StructType = {
     val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
-    val map = this.paramMap ++ paramMap
-    addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT)
+    val map = extractParamMap(paramMap)
+    SchemaUtils.appendColumn(parentSchema, map(probabilityCol), new VectorUDT)
   }
 }

@@ -102,7 +103,7 @@ private[spark] abstract class ProbabilisticClassificationModel[

     // Check schema
     transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
+    val map = extractParamMap(paramMap)

     // Prepare model
     val tmpModel = if (paramMap.size != 0) {

mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala

Lines changed: 9 additions & 6 deletions
@@ -20,12 +20,13 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.Evaluator
 import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.types.DoubleType

-
 /**
  * :: AlphaComponent ::
  *
@@ -40,10 +41,10 @@ class BinaryClassificationEvaluator extends Evaluator with Params
    * @group param
    */
   val metricName: Param[String] = new Param(this, "metricName",
-    "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC"))
+    "metric name in evaluation (areaUnderROC|areaUnderPR)")

   /** @group getParam */
-  def getMetricName: String = get(metricName)
+  def getMetricName: String = getOrDefault(metricName)

   /** @group setParam */
   def setMetricName(value: String): this.type = set(metricName, value)
@@ -54,12 +55,14 @@ class BinaryClassificationEvaluator extends Evaluator with Params
   /** @group setParam */
   def setLabelCol(value: String): this.type = set(labelCol, value)

+  setDefault(metricName -> "areaUnderROC")
+
   override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
-    val map = this.paramMap ++ paramMap
+    val map = extractParamMap(paramMap)

     val schema = dataset.schema
-    checkInputColumn(schema, map(rawPredictionCol), new VectorUDT)
-    checkInputColumn(schema, map(labelCol), DoubleType)
+    SchemaUtils.checkColumnType(schema, map(rawPredictionCol), new VectorUDT)
+    SchemaUtils.checkColumnType(schema, map(labelCol), DoubleType)

     // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
     val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol))
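`SchemaUtils.checkColumnType` replaces the old protected `checkInputColumn` on `Params`, so input validation no longer requires inheriting from `Params`. A hedged usage sketch (`df` is an arbitrary DataFrame assumed in scope, with illustrative column names):

    import org.apache.spark.ml.util.SchemaUtils
    import org.apache.spark.mllib.linalg.VectorUDT
    import org.apache.spark.sql.types.DoubleType

    // Fails fast if a column is missing or has the wrong type;
    // otherwise returns normally.
    SchemaUtils.checkColumnType(df.schema, "rawPrediction", new VectorUDT)
    SchemaUtils.checkColumnType(df.schema, "label", DoubleType)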

mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala

Lines changed: 4 additions & 2 deletions
@@ -35,14 +35,16 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
    * number of features
    * @group param
    */
-  val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18))
+  val numFeatures = new IntParam(this, "numFeatures", "number of features")

   /** @group getParam */
-  def getNumFeatures: Int = get(numFeatures)
+  def getNumFeatures: Int = getOrDefault(numFeatures)

   /** @group setParam */
   def setNumFeatures(value: Int): this.type = set(numFeatures, value)

+  setDefault(numFeatures -> (1 << 18))
+
   override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
     val hashingTF = new feature.HashingTF(paramMap(numFeatures))
     hashingTF.transform