
Commit 114bad6

jkbradley authored and mengxr committed
[SPARK-7176] [ML] Add validation functionality to Param
Main change: Added isValid field to Param. Modified all usages to use isValid when relevant. Added helper methods in ParamValidate.

Also overrode Params.validate() in:
* CrossValidator + model
* Pipeline + model

I made a few updates for the elastic net patch:
* I changed "tol" to "convergenceTol"
* I added some documentation

This PR is Scala + Java only. Python will be in a follow-up PR.

CC: mengxr

Author: Joseph K. Bradley <[email protected]>

Closes #5740 from jkbradley/enforce-validate and squashes the following commits:

ad9c6c1 [Joseph K. Bradley] re-generated sharedParams after merging with current master
76415e8 [Joseph K. Bradley] reverted convergenceTol to tol
af62f4b [Joseph K. Bradley] Removed changes to SparkBuild, python linalg. Fixed test failures. Renamed ParamValidate to ParamValidators. Removed explicit type from ParamValidators calls where possible.
bb2665a [Joseph K. Bradley] merged with elastic net pr
ecda302 [Joseph K. Bradley] fix rat tests, plus add a little doc
6895dfc [Joseph K. Bradley] small cleanups
069ac6d [Joseph K. Bradley] many cleanups
928fb84 [Joseph K. Bradley] Maybe done
a910ac7 [Joseph K. Bradley] still workin
6d60e2e [Joseph K. Bradley] Still workin
b987319 [Joseph K. Bradley] Partly done with adding checks, but blocking on adding checking functionality to Param
dbc9fb2 [Joseph K. Bradley] merged with master. enforcing Params.validate
1 parent 1fdfdb4 commit 114bad6
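The pattern this commit applies across the files below can be summarized with a short sketch. This is illustrative only, with hypothetical names, and assumes a type that mixes in org.apache.spark.ml.param.Params (as HashingTF and Normalizer do via UnaryTransformer): the Param constructor gains an isValid argument, and ParamValidators supplies common checks such as gt and gtEq.

// Hypothetical sketch of the validation hooks added by this patch; not code from the commit.
import org.apache.spark.ml.param.{IntParam, ParamValidators, Params}

trait MyBucketizerParams extends Params {

  // The extra constructor argument is the new isValid check;
  // ParamValidators.gt(0) rejects zero and negative values.
  val numBuckets = new IntParam(this, "numBuckets", "number of buckets (> 0)",
    ParamValidators.gt(0))

  setDefault(numBuckets -> 16)

  /** @group getParam */
  def getNumBuckets: Int = getOrDefault(numBuckets)

  /** @group setParam */
  def setNumBuckets(value: Int): this.type = set(numBuckets, value)
}

With the check attached to the Param itself, setters no longer need their own require calls; out-of-range values are reported through the Params.validate() machinery and the Pipeline/CrossValidator overrides added in this patch.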

File tree: 22 files changed (+593, -274 lines)

examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java

Lines changed: 8 additions & 6 deletions

@@ -28,7 +28,6 @@
 import org.apache.spark.ml.classification.ClassificationModel;
 import org.apache.spark.ml.param.IntParam;
 import org.apache.spark.ml.param.ParamMap;
-import org.apache.spark.ml.param.Params;
 import org.apache.spark.ml.param.Params$;
 import org.apache.spark.mllib.linalg.BLAS;
 import org.apache.spark.mllib.linalg.Vector;
@@ -100,11 +99,12 @@ public static void main(String[] args) throws Exception {
 /**
  * Example of defining a type of {@link Classifier}.
  *
- * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ * Note: Some IDEs (e.g., IntelliJ) will complain that this will not compile due to
+ * {@link org.apache.spark.ml.param.Params#set} using incompatible return types.
+ * However, this should still compile and run successfully.
  */
 class MyJavaLogisticRegression
-    extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel>
-    implements Params {
+    extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> {
 
   /**
    * Param for max number of iterations
@@ -145,10 +145,12 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap)
 /**
  * Example of defining a type of {@link ClassificationModel}.
  *
- * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ * Note: Some IDEs (e.g., IntelliJ) will complain that this will not compile due to
+ * {@link org.apache.spark.ml.param.Params#set} using incompatible return types.
+ * However, this should still compile and run successfully.
  */
 class MyJavaLogisticRegressionModel
-    extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> implements Params {
+    extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> {
 
   private MyJavaLogisticRegression parent_;
   public MyJavaLogisticRegression parent() { return parent_; }

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 16 additions & 3 deletions

@@ -21,7 +21,7 @@ import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.param.{Param, ParamMap}
+import org.apache.spark.ml.param.{Params, Param, ParamMap}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
 
@@ -86,6 +86,14 @@ class Pipeline extends Estimator[PipelineModel] {
   def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
   def getStages: Array[PipelineStage] = getOrDefault(stages)
 
+  override def validate(paramMap: ParamMap): Unit = {
+    val map = extractParamMap(paramMap)
+    getStages.foreach {
+      case pStage: Params => pStage.validate(map)
+      case _ =>
+    }
+  }
+
   /**
    * Fits the pipeline to the input dataset with additional parameters. If a stage is an
    * [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model.
@@ -140,7 +148,7 @@ class Pipeline extends Estimator[PipelineModel] {
   override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = extractParamMap(paramMap)
     val theStages = map(stages)
-    require(theStages.toSet.size == theStages.size,
+    require(theStages.toSet.size == theStages.length,
       "Cannot have duplicate components in a pipeline.")
     theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap))
   }
@@ -157,6 +165,11 @@ class PipelineModel private[ml] (
     private[ml] val stages: Array[Transformer])
   extends Model[PipelineModel] with Logging {
 
+  override def validate(paramMap: ParamMap): Unit = {
+    val map = fittingParamMap ++ extractParamMap(paramMap)
+    stages.foreach(_.validate(map))
+  }
+
   /**
    * Gets the model produced by the input estimator. Throws an NoSuchElementException is the input
    * estimator does not exist in the pipeline.
@@ -168,7 +181,7 @@ class PipelineModel private[ml] (
     }
     if (matched.isEmpty) {
       throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.")
-    } else if (matched.size > 1) {
+    } else if (matched.length > 1) {
       throw new IllegalStateException(s"Cannot have duplicate estimators in the sample pipeline.")
     } else {
       matched.head.asInstanceOf[M]

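As a usage sketch of the override above (not part of the patch; column names and sizes are made up), validating a Pipeline cascades to each stage that implements Params, so a stage-level value that fails its isValid check is expected to surface through the pipeline's validate call:

// Hypothetical usage; assumes the new per-Param checks described in the commit message.
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.param.ParamMap

val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF()
  .setInputCol("words")
  .setOutputCol("features")
  .setNumFeatures(1 << 10)
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF))

// Calls validate(map) on every stage that is a Params instance; an invalid setting,
// e.g. a non-positive numFeatures, should be rejected by the new checks.
pipeline.validate(ParamMap.empty)
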
mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 4 additions & 9 deletions

@@ -103,21 +103,16 @@ final class GBTClassifier
    */
   val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
     " tries to minimize (case-insensitive). Supported options:" +
-    s" ${GBTClassifier.supportedLossTypes.mkString(", ")}")
+    s" ${GBTClassifier.supportedLossTypes.mkString(", ")}",
+    (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase))
 
   setDefault(lossType -> "logistic")
 
   /** @group setParam */
-  def setLossType(value: String): this.type = {
-    val lossStr = value.toLowerCase
-    require(GBTClassifier.supportedLossTypes.contains(lossStr), "GBTClassifier was given bad loss" +
-      s" type: $value. Supported options: ${GBTClassifier.supportedLossTypes.mkString(", ")}")
-    set(lossType, lossStr)
-    this
-  }
+  def setLossType(value: String): this.type = set(lossType, value)
 
   /** @group getParam */
-  def getLossType: String = getOrDefault(lossType)
+  def getLossType: String = getOrDefault(lossType).toLowerCase
 
   /** (private[ml]) Convert new loss to old loss. */
   override private[ml] def getOldLossType: OldLoss = {

mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala

Lines changed: 7 additions & 5 deletions

@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap}
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
 import org.apache.spark.sql.types.DataType
@@ -32,19 +32,21 @@ import org.apache.spark.sql.types.DataType
 class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
 
   /**
-   * number of features
+   * Number of features. Should be > 0.
+   * (default = 2^18^)
    * @group param
    */
-  val numFeatures = new IntParam(this, "numFeatures", "number of features")
+  val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)",
+    ParamValidators.gt(0))
+
+  setDefault(numFeatures -> (1 << 18))
 
   /** @group getParam */
   def getNumFeatures: Int = getOrDefault(numFeatures)
 
   /** @group setParam */
   def setNumFeatures(value: Int): this.type = set(numFeatures, value)
 
-  setDefault(numFeatures -> (1 << 18))
-
   override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
     val hashingTF = new feature.HashingTF(paramMap(numFeatures))
     hashingTF.transform

mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala

Lines changed: 6 additions & 5 deletions

@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{DoubleParam, ParamMap}
+import org.apache.spark.ml.param.{ParamValidators, DoubleParam, ParamMap}
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
 import org.apache.spark.sql.types.DataType
@@ -32,19 +32,20 @@ import org.apache.spark.sql.types.DataType
 class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {
 
   /**
-   * Normalization in L^p^ space, p = 2 by default.
+   * Normalization in L^p^ space. Must be >= 1.
+   * (default: p = 2)
    * @group param
    */
-  val p = new DoubleParam(this, "p", "the p norm value")
+  val p = new DoubleParam(this, "p", "the p norm value", ParamValidators.gtEq(1))
+
+  setDefault(p -> 2.0)
 
   /** @group getParam */
   def getP: Double = getOrDefault(p)
 
   /** @group setParam */
   def setP(value: Double): this.type = set(p, value)
 
-  setDefault(p -> 2.0)
-
   override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = {
     val normalizer = new feature.Normalizer(paramMap(p))
     normalizer.transform

mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala

Lines changed: 6 additions & 3 deletions

@@ -21,7 +21,7 @@ import scala.collection.mutable
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap}
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.sql.types.DataType
 
@@ -37,10 +37,13 @@ import org.apache.spark.sql.types.DataType
 class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExpansion] {
 
   /**
-   * The polynomial degree to expand, which should be larger than 1.
+   * The polynomial degree to expand, which should be >= 1. A value of 1 means no expansion.
+   * Default: 2
    * @group param
    */
-  val degree = new IntParam(this, "degree", "the polynomial degree to expand")
+  val degree = new IntParam(this, "degree", "the polynomial degree to expand (>= 1)",
+    ParamValidators.gt(1))
+
   setDefault(degree -> 2)
 
   /** @group getParam */

mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala

Lines changed: 6 additions & 4 deletions

@@ -31,17 +31,19 @@ import org.apache.spark.sql.types.{StructField, StructType}
  * Params for [[StandardScaler]] and [[StandardScalerModel]].
  */
 private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {
-
+
   /**
-   * False by default. Centers the data with mean before scaling.
+   * Centers the data with mean before scaling.
    * It will build a dense output, so this does not work on sparse input
    * and will raise an exception.
+   * Default: false
    * @group param
    */
   val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
 
   /**
-   * True by default. Scales the data to unit standard deviation.
+   * Scales the data to unit standard deviation.
+   * Default: true
    * @group param
   */
   val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation")
@@ -56,7 +58,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
 class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {
 
   setDefault(withMean -> false, withStd -> true)
-
+
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 

mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala

Lines changed: 11 additions & 9 deletions

@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{ParamMap, IntParam, BooleanParam, Param}
+import org.apache.spark.ml.param._
 import org.apache.spark.sql.types.{DataType, StringType, ArrayType}
 
 /**
@@ -43,20 +43,20 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
 /**
  * :: AlphaComponent ::
  * A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default)
- * or using it to split the text (set matching to false). Optional parameters also allow to fold
- * the text to lowercase prior to it being tokenized and to filer tokens using a minimal length.
+ * or using it to split the text (set matching to false). Optional parameters also allow filtering
+ * tokens using a minimal length.
  * It returns an array of strings that can be empty.
- * The default parameters are regex = "\\p{L}+|[^\\p{L}\\s]+", matching = true,
- * lowercase = false, minTokenLength = 1
 */
 @AlphaComponent
 class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
 
   /**
-   * param for minimum token length, default is one to avoid returning empty strings
+   * Minimum token length, >= 0.
+   * Default: 1, to avoid returning empty strings
    * @group param
   */
-  val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length")
+  val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length (>= 0)",
+    ParamValidators.gtEq(0))
 
   /** @group setParam */
   def setMinTokenLength(value: Int): this.type = set(minTokenLength, value)
@@ -65,7 +65,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
   def getMinTokenLength: Int = getOrDefault(minTokenLength)
 
   /**
-   * param sets regex as splitting on gaps (true) or matching tokens (false)
+   * Indicates whether regex splits on gaps (true) or matching tokens (false).
+   * Default: false
    * @group param
   */
   val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens")
@@ -77,7 +78,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
   def getGaps: Boolean = getOrDefault(gaps)
 
   /**
-   * param sets regex pattern used by tokenizer
+   * Regex pattern used by tokenizer.
+   * Default: `"\\p{L}+|[^\\p{L}\\s]+"`
    * @group param
   */
   val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing")

mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala

Lines changed: 8 additions & 10 deletions

@@ -22,7 +22,7 @@ import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute,
   Attribute, AttributeGroup}
-import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
+import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap, Params}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT}
 import org.apache.spark.sql.{Row, DataFrame}
@@ -37,17 +37,19 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
   /**
    * Threshold for the number of values a categorical feature can take.
    * If a feature is found to have > maxCategories values, then it is declared continuous.
+   * Must be >= 2.
    *
    * (default = 20)
    */
   val maxCategories = new IntParam(this, "maxCategories",
-    "Threshold for the number of values a categorical feature can take." +
-    " If a feature is found to have > maxCategories values, then it is declared continuous.")
+    "Threshold for the number of values a categorical feature can take (>= 2)." +
+    " If a feature is found to have > maxCategories values, then it is declared continuous.",
+    ParamValidators.gtEq(2))
+
+  setDefault(maxCategories -> 20)
 
   /** @group getParam */
   def getMaxCategories: Int = getOrDefault(maxCategories)
-
-  setDefault(maxCategories -> 20)
 }
 
 /**
@@ -90,11 +92,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
 class VectorIndexer extends Estimator[VectorIndexerModel] with VectorIndexerParams {
 
   /** @group setParam */
-  def setMaxCategories(value: Int): this.type = {
-    require(value > 1,
-      s"DatasetIndexer given maxCategories = $value, but requires maxCategories > 1.")
-    set(maxCategories, value)
-  }
+  def setMaxCategories(value: Int): this.type = set(maxCategories, value)
 
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)

0 commit comments
