Skip to content

Commit d9302b8

Browse files
committed
generate default values
1 parent 1c72579 commit d9302b8

File tree

2 files changed

+50
-13
lines changed

2 files changed

+50
-13
lines changed

mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamCodeGen.scala

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,13 @@ private[shared] object SharedParamCodeGen {
3333
val params = Seq(
3434
ParamDesc[Double]("regParam", "regularization parameter"),
3535
ParamDesc[Int]("maxIter", "max number of iterations"),
36-
ParamDesc[String]("featuresCol", "features column name"),
37-
ParamDesc[String]("labelCol", "label column name"),
38-
ParamDesc[String]("predictionCol", "prediction column name"),
39-
ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name"),
40-
ParamDesc[String](
41-
"probabilityCol", "column name for predicted class conditional probabilities"),
36+
ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")),
37+
ParamDesc[String]("labelCol", "label column name", Some("\"label\"")),
38+
ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")),
39+
ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name",
40+
Some("\"rawPrediction\"")),
41+
ParamDesc[String]("probabilityCol",
42+
"column name for predicted class conditional probabilities", Some("\"probability\"")),
4243
ParamDesc[Double]("threshold", "threshold in prediction"),
4344
ParamDesc[String]("inputCol", "input column name"),
4445
ParamDesc[String]("outputCol", "output column name"),
@@ -52,7 +53,11 @@ private[shared] object SharedParamCodeGen {
5253
}
5354

5455
/** Description of a param. */
55-
private case class ParamDesc[T: ClassTag](name: String, doc: String) {
56+
private case class ParamDesc[T: ClassTag](
57+
name: String,
58+
doc: String,
59+
defaultValueStr: Option[String] = None) {
60+
5661
require(name.matches("[a-z][a-zA-Z0-9]*"), s"Param name $name is invalid.")
5762
require(doc.nonEmpty) // TODO: more rigorous on doc
5863

@@ -93,14 +98,24 @@ private[shared] object SharedParamCodeGen {
9398
val Param = param.paramTypeName
9499
val T = param.valueTypeName
95100
val doc = param.doc
101+
val defaultValue = param.defaultValueStr
102+
val defaultValueDoc = defaultValue.map { v =>
103+
s" (default: $v)"
104+
}.getOrElse("")
105+
val setDefault = defaultValue.map { v =>
106+
s"""
107+
| setDefault($name, $v)
108+
""".stripMargin
109+
}.getOrElse("")
96110

97111
s"""
98112
|/**
99113
| * :: DeveloperApi ::
100-
| * Trait for shared param $name.
114+
| * Trait for shared param $name$defaultValueDoc.
101115
| */
102116
|@DeveloperApi
103117
|trait Has$Name extends Params {
118+
|$setDefault
104119
| /**
105120
| * Param for $doc.
106121
| * @group param

mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import org.apache.spark.ml.param._
3131
*/
3232
@DeveloperApi
3333
trait HasRegParam extends Params {
34+
3435
/**
3536
* Param for regularization parameter.
3637
* @group param
@@ -47,6 +48,7 @@ trait HasRegParam extends Params {
4748
*/
4849
@DeveloperApi
4950
trait HasMaxIter extends Params {
51+
5052
/**
5153
* Param for max number of iterations.
5254
* @group param
@@ -59,10 +61,13 @@ trait HasMaxIter extends Params {
5961

6062
/**
6163
* :: DeveloperApi ::
62-
* Trait for shared param featuresCol.
64+
* Trait for shared param featuresCol (default: "features").
6365
*/
6466
@DeveloperApi
6567
trait HasFeaturesCol extends Params {
68+
69+
setDefault(featuresCol, "features")
70+
6671
/**
6772
* Param for features column name.
6873
* @group param
@@ -75,10 +80,13 @@ trait HasFeaturesCol extends Params {
7580

7681
/**
7782
* :: DeveloperApi ::
78-
* Trait for shared param labelCol.
83+
* Trait for shared param labelCol (default: "label").
7984
*/
8085
@DeveloperApi
8186
trait HasLabelCol extends Params {
87+
88+
setDefault(labelCol, "label")
89+
8290
/**
8391
* Param for label column name.
8492
* @group param
@@ -91,10 +99,13 @@ trait HasLabelCol extends Params {
9199

92100
/**
93101
* :: DeveloperApi ::
94-
* Trait for shared param predictionCol.
102+
* Trait for shared param predictionCol (default: "prediction").
95103
*/
96104
@DeveloperApi
97105
trait HasPredictionCol extends Params {
106+
107+
setDefault(predictionCol, "prediction")
108+
98109
/**
99110
* Param for prediction column name.
100111
* @group param
@@ -107,10 +118,13 @@ trait HasPredictionCol extends Params {
107118

108119
/**
109120
* :: DeveloperApi ::
110-
* Trait for shared param rawPredictionCol.
121+
* Trait for shared param rawPredictionCol (default: "rawPrediction").
111122
*/
112123
@DeveloperApi
113124
trait HasRawPredictionCol extends Params {
125+
126+
setDefault(rawPredictionCol, "rawPrediction")
127+
114128
/**
115129
* Param for raw prediction (a.k.a. confidence) column name.
116130
* @group param
@@ -123,10 +137,13 @@ trait HasRawPredictionCol extends Params {
123137

124138
/**
125139
* :: DeveloperApi ::
126-
* Trait for shared param probabilityCol.
140+
* Trait for shared param probabilityCol (default: "probability").
127141
*/
128142
@DeveloperApi
129143
trait HasProbabilityCol extends Params {
144+
145+
setDefault(probabilityCol, "probability")
146+
130147
/**
131148
* Param for column name for predicted class conditional probabilities.
132149
* @group param
@@ -143,6 +160,7 @@ trait HasProbabilityCol extends Params {
143160
*/
144161
@DeveloperApi
145162
trait HasThreshold extends Params {
163+
146164
/**
147165
* Param for threshold in prediction.
148166
* @group param
@@ -159,6 +177,7 @@ trait HasThreshold extends Params {
159177
*/
160178
@DeveloperApi
161179
trait HasInputCol extends Params {
180+
162181
/**
163182
* Param for input column name.
164183
* @group param
@@ -175,6 +194,7 @@ trait HasInputCol extends Params {
175194
*/
176195
@DeveloperApi
177196
trait HasOutputCol extends Params {
197+
178198
/**
179199
* Param for output column name.
180200
* @group param
@@ -191,6 +211,7 @@ trait HasOutputCol extends Params {
191211
*/
192212
@DeveloperApi
193213
trait HasCheckpointInterval extends Params {
214+
194215
/**
195216
* Param for checkpoint interval.
196217
* @group param
@@ -202,3 +223,4 @@ trait HasCheckpointInterval extends Params {
202223
}
203224

204225
// scalastyle:on
226+

0 commit comments

Comments
 (0)