
Commit dbc9fb2

merged with master. enforcing Params.validate
1 parent c9d530e commit dbc9fb2


11 files changed (+84, −25 lines)

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 10 additions & 0 deletions
@@ -34,6 +34,16 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
   with HasRegParam with HasMaxIter with HasFitIntercept with HasThreshold {

   setDefault(regParam -> 0.1, maxIter -> 100, threshold -> 0.5)
+
+  override def validate(paramMap: ParamMap): Unit = {
+    require(getOrDefault(regParam) >= 0,
+      s"LogisticRegression regParam must be >= 0, but was ${getOrDefault(regParam)}")
+    require(getOrDefault(maxIter) >= 0,
+      s"LogisticRegression maxIter must be >= 0, but was ${getOrDefault(maxIter)}")
+    val threshold_ = getOrDefault(threshold)
+    require(threshold_ >= 0 && threshold_ <= 1,
+      s"LogisticRegression threshold must be in range [0,1], but was $threshold_")
+  }
 }

 /**

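As a usage sketch (not part of this commit; the setThreshold setter name is assumed from the usual param-setter pattern), the new override means an out-of-range threshold fails fast when validation runs:

import org.apache.spark.ml.classification.LogisticRegression

// Hypothetical usage; the setter name is assumed, not taken from the diff.
val lr = new LogisticRegression().setThreshold(1.5) // outside [0, 1]
// Expected to throw IllegalArgumentException via the require(...) above:
lr.validate()
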
mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@ import org.apache.spark.sql.types.DoubleType
 class BinaryClassificationEvaluator extends Evaluator with Params
   with HasRawPredictionCol with HasLabelCol {

+  override def validate(paramMap: ParamMap): Unit = { }
+
   /**
    * param for metric name in evaluation
    * @group param

mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala

Lines changed: 10 additions & 4 deletions
@@ -31,20 +31,26 @@ import org.apache.spark.sql.types.DataType
 @AlphaComponent
 class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {

+  override def validate(paramMap: ParamMap): Unit = {
+    require(getOrDefault(numFeatures) > 0,
+      s"HashingTF numFeatures must be > 0, but was ${getOrDefault(numFeatures)}")
+  }
+
   /**
-   * number of features
+   * Number of features. Should be > 0.
+   * (default = 2^18^)
    * @group param
    */
-  val numFeatures = new IntParam(this, "numFeatures", "number of features")
+  val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)")
+
+  setDefault(numFeatures -> (1 << 18))

   /** @group getParam */
   def getNumFeatures: Int = getOrDefault(numFeatures)

   /** @group setParam */
   def setNumFeatures(value: Int): this.type = set(numFeatures, value)

-  setDefault(numFeatures -> (1 << 18))
-
   override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
     val hashingTF = new feature.HashingTF(paramMap(numFeatures))
     hashingTF.transform

mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala

Lines changed: 8 additions & 3 deletions
@@ -31,20 +31,25 @@ import org.apache.spark.sql.types.DataType
 @AlphaComponent
 class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {

+  override def validate(paramMap: ParamMap): Unit = {
+    require(getOrDefault(p) >= 0, s"Normalizer p must be >= 0, but was ${getOrDefault(p)}")
+  }
+
   /**
-   * Normalization in L^p^ space, p = 2 by default.
+   * Normalization in L^p^ space. Must be >= 1.
+   * (default: p = 2)
    * @group param
    */
   val p = new DoubleParam(this, "p", "the p norm value")

+  setDefault(p -> 2.0)
+
   /** @group getParam */
   def getP: Double = getOrDefault(p)

   /** @group setParam */
   def setP(value: Double): this.type = set(p, value)

-  setDefault(p -> 2.0)
-
   override protected def createTransformFunc(paramMap: ParamMap): Vector => Vector = {
     val normalizer = new feature.Normalizer(paramMap(p))
     normalizer.transform

mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala

Lines changed: 4 additions & 2 deletions
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types.{StructField, StructType}
  * Params for [[StandardScaler]] and [[StandardScalerModel]].
  */
 private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {
-
+
   /**
    * False by default. Centers the data with mean before scaling.
    * It will build a dense output, so this does not work on sparse input
@@ -45,6 +45,8 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
    * @group param
    */
   val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation")
+
+  override def validate(paramMap: ParamMap): Unit = { }
 }

 /**
@@ -56,7 +58,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
 class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {

   setDefault(withMean -> false, withStd -> true)
-
+
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,8 @@ import org.apache.spark.util.collection.OpenHashMap
  */
 private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {

+  override def validate(paramMap: ParamMap): Unit = { }
+
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = extractParamMap(paramMap)

mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala

Lines changed: 16 additions & 8 deletions
@@ -29,6 +29,8 @@ import org.apache.spark.sql.types.{DataType, StringType, ArrayType}
 @AlphaComponent
 class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {

+  override def validate(paramMap: ParamMap): Unit = { }
+
   override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = {
     _.toLowerCase.split("\\s")
   }
@@ -43,20 +45,24 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
 /**
  * :: AlphaComponent ::
  * A regex based tokenizer that extracts tokens either by repeatedly matching the regex(default)
- * or using it to split the text (set matching to false). Optional parameters also allow to fold
- * the text to lowercase prior to it being tokenized and to filer tokens using a minimal length.
+ * or using it to split the text (set matching to false). Optional parameters also allow filtering
+ * tokens using a minimal length.
  * It returns an array of strings that can be empty.
- * The default parameters are regex = "\\p{L}+|[^\\p{L}\\s]+", matching = true,
- * lowercase = false, minTokenLength = 1
  */
 @AlphaComponent
 class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {

+  override def validate(paramMap: ParamMap): Unit = {
+    require(getOrDefault(minTokenLength) >= 0,
+      s"RegexTokenizer minTokenLength must be >= 0, but was ${getOrDefault(minTokenLength)}")
+  }
+
   /**
-   * param for minimum token length, default is one to avoid returning empty strings
+   * Minimum token length, >= 0.
+   * Default: 1, to avoid returning empty strings
    * @group param
    */
-  val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length")
+  val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length (>= 0)")

   /** @group setParam */
   def setMinTokenLength(value: Int): this.type = set(minTokenLength, value)
@@ -65,7 +71,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
   def getMinTokenLength: Int = getOrDefault(minTokenLength)

   /**
-   * param sets regex as splitting on gaps (true) or matching tokens (false)
+   * Indicates whether regex splits on gaps (true) or matching tokens (false).
+   * Default: false
    * @group param
    */
   val gaps: BooleanParam = new BooleanParam(this, "gaps", "Set regex to match gaps or tokens")
@@ -77,7 +84,8 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenize
   def getGaps: Boolean = getOrDefault(gaps)

   /**
-   * param sets regex pattern used by tokenizer
+   * Regex pattern used by tokenizer.
+   * Default: "\\p{L}+|[^\\p{L}\\s]+"
    * @group param
    */
   val pattern: Param[String] = new Param(this, "pattern", "regex pattern used for tokenizing")

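For reference, a small sketch of what the documented default pattern matches, using a plain Scala regex rather than RegexTokenizer itself:

// Illustrative only: the default pattern keeps runs of letters, or runs of
// non-letter, non-whitespace characters, as separate tokens.
val defaultPattern = "\\p{L}+|[^\\p{L}\\s]+".r
val tokens = defaultPattern.findAllIn("Don't panic!").toList
// tokens == List("Don", "'", "t", "panic", "!")
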
mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 7 additions & 1 deletion
@@ -132,12 +132,18 @@ trait Params extends Identifiable with Serializable {
   /**
    * Validates parameter values stored internally plus the input parameter map.
    * Raises an exception if any parameter is invalid.
+   *
+   * This generally checks parameters which do not specify input/output columns;
+   * input/output columns are checked during schema validation.
    */
-  def validate(paramMap: ParamMap): Unit = {}
+  def validate(paramMap: ParamMap): Unit

   /**
    * Validates parameter values stored internally.
    * Raise an exception if any parameter value is invalid.
+   *
+   * This generally checks parameters which do not specify input/output columns;
+   * input/output columns are checked during schema validation.
    */
   def validate(): Unit = validate(ParamMap.empty)

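The shape this change enforces, in a minimal self-contained sketch (illustrative names only, not Spark's actual classes): validate(paramMap) is now abstract, so every concrete Params implementation must supply its own checks, while the no-argument validate() keeps delegating with an empty map.

// Minimal sketch of the enforced pattern; SimpleParams/MyParams are made-up names.
trait SimpleParams {
  def validate(overrides: Map[String, Any]): Unit   // abstract: must be overridden
  def validate(): Unit = validate(Map.empty)        // mirrors validate(ParamMap.empty)
}

class MyParams(regParam: Double = 0.1, maxIter: Int = 100) extends SimpleParams {
  override def validate(overrides: Map[String, Any]): Unit = {
    val reg = overrides.get("regParam").map(_.asInstanceOf[Double]).getOrElse(regParam)
    val iters = overrides.get("maxIter").map(_.asInstanceOf[Int]).getOrElse(maxIter)
    require(reg >= 0, s"regParam must be >= 0, but was $reg")
    require(iters >= 0, s"maxIter must be >= 0, but was $iters")
  }
}

// new MyParams().validate()                        // passes with defaults
// new MyParams().validate(Map("regParam" -> -1.0)) // throws IllegalArgumentException
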
mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 7 additions & 6 deletions
@@ -33,10 +33,11 @@ import org.apache.spark.util.Utils
 trait HasRegParam extends Params {

   /**
-   * Param for regularization parameter.
+   * Param for regularization parameter. Should be >= 0.
    * @group param
    */
-  final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")
+  final val regParam: DoubleParam =
+    new DoubleParam(this, "regParam", "regularization parameter (>= 0)")

   /** @group getParam */
   final def getRegParam: Double = getOrDefault(regParam)
@@ -50,10 +51,10 @@ trait HasRegParam extends Params {
 trait HasMaxIter extends Params {

   /**
-   * Param for max number of iterations.
+   * Param for max number of iterations. Should be >= 0.
    * @group param
    */
-  final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
+  final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)")

   /** @group getParam */
   final def getMaxIter: Int = getOrDefault(maxIter)
@@ -165,7 +166,7 @@ trait HasThreshold extends Params {
    * Param for threshold in binary classification prediction.
    * @group param
    */
-  final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction")
+  final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]")

   /** @group getParam */
   final def getThreshold: Double = getOrDefault(threshold)
@@ -233,7 +234,7 @@ trait HasCheckpointInterval extends Params {
    * Param for checkpoint interval.
    * @group param
    */
-  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval")
+  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)")

   /** @group getParam */
   final def getCheckpointInterval: Int = getOrDefault(checkpointInterval)

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 9 additions & 0 deletions
@@ -138,6 +138,15 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
     implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
     ratingCol -> "rating", nonnegative -> false)

+  override def validate(paramMap: ParamMap): Unit = {
+    require(getOrDefault(regParam) >= 0,
+      s"ALS regParam must be >= 0, but was ${getOrDefault(regParam)}")
+    require(getOrDefault(maxIter) >= 0,
+      s"ALS maxIter must be >= 0, but was ${getOrDefault(maxIter)}")
+    require(getOrDefault(checkpointInterval) >= 1,
+      s"ALS checkpointInterval must be >= 1, but was ${getOrDefault(checkpointInterval)}")
+  }
+
   /**
    * Validates and transforms the input schema.
    * @param schema input schema
