
Commit af62f4b

Removed changes to SparkBuild, python linalg. Fixed test failures. Renamed ParamValidate to ParamValidators. Removed explicit type from ParamValidators calls where possible.
1 parent: bb2665a

17 files changed: +83 / -82 lines changed

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 1 addition & 0 deletions
@@ -90,6 +90,7 @@ class Pipeline extends Estimator[PipelineModel] {
     val map = extractParamMap(paramMap)
     getStages.foreach {
       case pStage: Params => pStage.validate(map)
+      case _ =>
     }
   }

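The added `case _ =>` makes the match in Pipeline.validate total, so a stage that does not mix in Params is simply skipped instead of triggering a scala.MatchError. A minimal standalone sketch of that pattern, with simplified hypothetical names rather than the real Spark ML types:

// Simplified sketch; `Params`, `MyEstimator`, and `OpaqueStage` here are illustrative only.
trait Params { def validate(): Unit }
class MyEstimator extends Params { def validate(): Unit = println("params validated") }
class OpaqueStage // carries no params

val stages: Seq[Any] = Seq(new MyEstimator, new OpaqueStage)
stages.foreach {
  case p: Params => p.validate() // validate stages that expose params
  case _ =>                      // skip everything else instead of throwing scala.MatchError
}
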
mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{ParamValidate, IntParam, ParamMap}
+import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap}
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
 import org.apache.spark.sql.types.DataType
@@ -37,7 +37,7 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
    * @group param
    */
   val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)",
-    ParamValidate.gt[Int](0))
+    ParamValidators.gt(0))
 
   setDefault(numFeatures -> (1 << 18))

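For context, the validators produced by ParamValidators are plain predicates, which is consistent with the commit message's "Removed explicit type from ParamValidators calls where possible": when the constructor expects an Int => Boolean, the type parameter is inferred and the explicit [Int] becomes redundant. A hedged sketch of the intended behavior, assuming gt returns a T => Boolean as the call sites above suggest:

import org.apache.spark.ml.param.ParamValidators

// The expected type Int => Boolean drives inference, so no explicit [Int] is needed.
val positive: Int => Boolean = ParamValidators.gt(0)
positive(1 << 18)  // true  -- the default numFeatures value
positive(0)        // false -- numFeatures must be strictly greater than 0
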
mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala

Lines changed: 2 additions & 2 deletions
@@ -19,7 +19,7 @@ package org.apache.spark.ml.feature
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{ParamValidate, DoubleParam, ParamMap}
+import org.apache.spark.ml.param.{ParamValidators, DoubleParam, ParamMap}
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
 import org.apache.spark.sql.types.DataType
@@ -36,7 +36,7 @@ class Normalizer extends UnaryTransformer[Vector, Vector, Normalizer] {
    * (default: p = 2)
    * @group param
    */
-  val p = new DoubleParam(this, "p", "the p norm value", ParamValidate.gtEq[Double](1))
+  val p = new DoubleParam(this, "p", "the p norm value", ParamValidators.gtEq(1))
 
   setDefault(p -> 2.0)

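Here gtEq is an inclusive lower bound, so p = 1.0 (the L1 norm) is accepted while values below 1 are rejected. A small hedged sketch under the same assumption that the factory returns a plain predicate:

import org.apache.spark.ml.param.ParamValidators

val validP: Double => Boolean = ParamValidators.gtEq(1)
validP(1.0)  // true  -- L1 norm is allowed because the bound is inclusive
validP(2.0)  // true  -- the default (L2 norm)
validP(0.5)  // false -- rejected
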
mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala

Lines changed: 4 additions & 4 deletions
@@ -21,7 +21,7 @@ import scala.collection.mutable
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{ParamValidate, IntParam, ParamMap}
+import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap}
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.sql.types.DataType
 
@@ -37,12 +37,12 @@ import org.apache.spark.sql.types.DataType
 class PolynomialExpansion extends UnaryTransformer[Vector, Vector, PolynomialExpansion] {
 
   /**
-   * The polynomial degree to expand, which should be larger than 1.
+   * The polynomial degree to expand, which should be >= 1. A value of 1 means no expansion.
    * Default: 2
    * @group param
    */
-  val degree = new IntParam(this, "degree", "the polynomial degree to expand",
-    ParamValidate.gt[Int](2))
+  val degree = new IntParam(this, "degree", "the polynomial degree to expand (>= 1)",
+    ParamValidators.gt(1))
 
   setDefault(degree -> 2)

mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ class RegexTokenizer extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
    * @group param
    */
   val minTokenLength: IntParam = new IntParam(this, "minLength", "minimum token length (>= 0)",
-    ParamValidate.gtEq[Int](0))
+    ParamValidators.gtEq(0))
 
   /** @group setParam */
   def setMinTokenLength(value: Int): this.type = set(minTokenLength, value)

mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@ import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.attribute.{BinaryAttribute, NumericAttribute, NominalAttribute,
   Attribute, AttributeGroup}
-import org.apache.spark.ml.param.{ParamValidate, IntParam, ParamMap, Params}
+import org.apache.spark.ml.param.{ParamValidators, IntParam, ParamMap, Params}
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, VectorUDT}
 import org.apache.spark.sql.{Row, DataFrame}
@@ -44,7 +44,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOutputCol {
   val maxCategories = new IntParam(this, "maxCategories",
     "Threshold for the number of values a categorical feature can take (>= 2)." +
     " If a feature is found to have > maxCategories values, then it is declared continuous.",
-    ParamValidate.gtEq[Int](2))
+    ParamValidators.gtEq(2))
 
   setDefault(maxCategories -> 20)

mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala

Lines changed: 8 additions & 8 deletions
@@ -46,7 +46,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
   final val maxDepth: IntParam =
     new IntParam(this, "maxDepth", "Maximum depth of the tree. (>= 0)" +
       " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
-      ParamValidate.gtEq[Int](0))
+      ParamValidators.gtEq(0))
 
   /**
    * Maximum number of bins used for discretizing continuous features and for choosing how to split
@@ -57,7 +57,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
    */
   final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" +
     " discretizing continuous features. Must be >=2 and >= number of categories for any" +
-    " categorical feature.", ParamValidate.gtEq[Int](2))
+    " categorical feature.", ParamValidators.gtEq(2))
 
   /**
    * Minimum number of instances each child must have after split.
@@ -70,7 +70,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
   final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" +
     " number of instances each child must have after split. If a split causes the left or right" +
     " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
-    " Should be >= 1.", ParamValidate.gtEq[Int](1))
+    " Should be >= 1.", ParamValidators.gtEq(1))
 
   /**
    * Minimum information gain for a split to be considered at a tree node.
@@ -87,7 +87,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
    */
   final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB",
     "Maximum memory in MB allocated to histogram aggregation.",
-    ParamValidate.gtEq[Int](0))
+    ParamValidators.gtEq(0))
 
   /**
    * If false, the algorithm will pass trees to executors to match instances with nodes.
@@ -114,7 +114,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
     " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" +
     " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
     " checkpoint directory is set in the SparkContext. Must be >= 1.",
-    ParamValidate.gtEq[Int](1))
+    ParamValidators.gtEq(1))
 
   setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
     maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
@@ -283,7 +283,7 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
    */
   final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate",
     "Fraction of the training data used for learning each decision tree, in range (0, 1].",
-    ParamValidate.inRange[Double](0, 1, lowerInclusive = false, upperInclusive = true))
+    ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
 
   setDefault(subsamplingRate -> 1.0)
 
@@ -326,7 +326,7 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams {
    * @group param
    */
   final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
-    ParamValidate.gtEq[Int](1))
+    ParamValidators.gtEq(1))
 
   /**
    * The number of features to consider for splits at each tree node.
@@ -396,7 +396,7 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
    */
   final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." +
     " learning rate) in interval (0, 1] for shrinking the contribution of each estimator",
-    ParamValidate.inRange[Double](0, 1, lowerInclusive = false, upperInclusive = true))
+    ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
 
   /* TODO: Add this doc when we add this param. SPARK-7132
    * Threshold for stopping early when runWithValidation is used.

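The subsamplingRate and stepSize validators use inRange with explicit inclusiveness flags to express the half-open interval (0, 1]. A hedged sketch of how that reads as a predicate, again assuming the factory returns a T => Boolean as the call sites show:

import org.apache.spark.ml.param.ParamValidators

// (0, 1]: the lower bound 0 is excluded, the upper bound 1 is included.
val inUnitIntervalExclusiveLow: Double => Boolean =
  ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)

inUnitIntervalExclusiveLow(0.0)  // false -- a subsampling rate of 0 would use no data
inUnitIntervalExclusiveLow(0.5)  // true
inUnitIntervalExclusiveLow(1.0)  // true  -- the default subsamplingRate
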
mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 10 additions & 10 deletions
@@ -35,15 +35,15 @@ import org.apache.spark.ml.util.Identifiable
  * @param name param name
  * @param doc documentation
  * @param isValid optional validation method which indicates if a value is valid.
- *                See [[ParamValidate]] for factory methods for common validation functions.
+ *                See [[ParamValidators]] for factory methods for common validation functions.
  * @tparam T param value type
  */
 @AlphaComponent
 class Param[T] (val parent: Params, val name: String, val doc: String, val isValid: T => Boolean)
   extends Serializable {
 
   def this(parent: Params, name: String, doc: String) =
-    this(parent, name, doc, ParamValidate.alwaysTrue[T])
+    this(parent, name, doc, ParamValidators.alwaysTrue[T])
 
   /**
    * Assert that the given value is valid for this parameter.
@@ -94,7 +94,7 @@ class Param[T] (val parent: Params, val name: String, val doc: String, val isValid: T => Boolean)
  * Factory methods for common validation functions for [[Param.isValid]].
  * The numerical methods only support Int, Long, Float, and Double.
  */
-object ParamValidate {
+object ParamValidators {
 
   /** (private[param]) Default validation always return true */
   private[param] def alwaysTrue[T]: T => Boolean = (_: T) => true
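
The diff only shows the object being renamed and its alwaysTrue default. As a rough illustration of the kind of factory it provides, here is a hedged, much-simplified sketch of bound-check factories returning predicates; it is not the actual Spark implementation, which is generic over Int, Long, Float, and Double via a numeric conversion helper:

// Simplified, hypothetical sketch -- not the real ParamValidators code.
object SimplifiedValidators {
  def alwaysTrue[T]: T => Boolean = (_: T) => true

  def gt(lowerBound: Double): Double => Boolean = (value: Double) => value > lowerBound

  def gtEq(lowerBound: Double): Double => Boolean = (value: Double) => value >= lowerBound

  def inRange(lower: Double, upper: Double,
              lowerInclusive: Boolean, upperInclusive: Boolean): Double => Boolean = {
    (value: Double) =>
      (if (lowerInclusive) value >= lower else value > lower) &&
      (if (upperInclusive) value <= upper else value < upper)
  }
}
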
@@ -176,37 +176,37 @@ class DoubleParam(parent: Params, name: String, doc: String, isValid: Double => Boolean)
   extends Param[Double](parent, name, doc, isValid) {
 
   def this(parent: Params, name: String, doc: String) =
-    this(parent, name, doc, ParamValidate.alwaysTrue[Double])
+    this(parent, name, doc, ParamValidators.alwaysTrue)
 
   override def w(value: Double): ParamPair[Double] = super.w(value)
 }
 
 /** Specialized version of [[Param[Int]]] for Java. */
 class IntParam(parent: Params, name: String, doc: String, isValid: Int => Boolean)
-  extends Param[Int](parent, name, doc) {
+  extends Param[Int](parent, name, doc, isValid) {
 
   def this(parent: Params, name: String, doc: String) =
-    this(parent, name, doc, ParamValidate.alwaysTrue[Int])
+    this(parent, name, doc, ParamValidators.alwaysTrue)
 
   override def w(value: Int): ParamPair[Int] = super.w(value)
 }
 
 /** Specialized version of [[Param[Float]]] for Java. */
 class FloatParam(parent: Params, name: String, doc: String, isValid: Float => Boolean)
-  extends Param[Float](parent, name, doc) {
+  extends Param[Float](parent, name, doc, isValid) {
 
   def this(parent: Params, name: String, doc: String) =
-    this(parent, name, doc, ParamValidate.alwaysTrue[Float])
+    this(parent, name, doc, ParamValidators.alwaysTrue)
 
   override def w(value: Float): ParamPair[Float] = super.w(value)
 }
 
 /** Specialized version of [[Param[Long]]] for Java. */
 class LongParam(parent: Params, name: String, doc: String, isValid: Long => Boolean)
-  extends Param[Long](parent, name, doc) {
+  extends Param[Long](parent, name, doc, isValid) {
 
   def this(parent: Params, name: String, doc: String) =
-    this(parent, name, doc, ParamValidate.alwaysTrue[Long])
+    this(parent, name, doc, ParamValidators.alwaysTrue)
 
   override def w(value: Long): ParamPair[Long] = super.w(value)
 }

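Beyond the rename, the change to IntParam, FloatParam, and LongParam above is a real fix: those classes previously called the three-argument Param constructor, so the isValid argument they received was silently dropped and the superclass kept its alwaysTrue default. A minimal hedged sketch of the difference, with simplified hypothetical classes rather than the actual Spark ones:

// Simplified illustration of why forwarding isValid to the superclass matters.
class ParamSketch[T](val name: String, val isValid: T => Boolean) {
  def this(name: String) = this(name, (_: T) => true) // alwaysTrue default
  def check(value: T): Unit = require(isValid(value), s"$name rejected value $value")
}

// Before the fix (drops isValid): every value passes the check.
class IntParamBefore(name: String, isValid: Int => Boolean) extends ParamSketch[Int](name)

// After the fix (forwards isValid): the validator is actually applied.
class IntParamAfter(name: String, isValid: Int => Boolean) extends ParamSketch[Int](name, isValid)

val before = new IntParamBefore("maxIter", _ >= 0)
val after  = new IntParamAfter("maxIter", _ >= 0)
before.check(-1) // passes silently -- the old behavior
after.check(-1)  // throws IllegalArgumentException -- the fixed behavior
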
mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala

Lines changed: 6 additions & 6 deletions
@@ -21,7 +21,7 @@ import java.io.PrintWriter
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.ml.param.ParamValidate
+import org.apache.spark.ml.param.ParamValidators
 
 /**
  * Code generator for shared params (sharedParams.scala). Run under the Spark folder with
@@ -34,9 +34,9 @@ private[shared] object SharedParamsCodeGen {
   def main(args: Array[String]): Unit = {
     val params = Seq(
       ParamDesc[Double]("regParam", "regularization parameter (>= 0)",
-        isValid = "ParamValidate.gtEq[Double](0)"),
+        isValid = "ParamValidators.gtEq(0)"),
       ParamDesc[Int]("maxIter", "max number of iterations (>= 0)",
-        isValid = "ParamValidate.gtEq[Int](0)"),
+        isValid = "ParamValidators.gtEq(0)"),
       ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")),
       ParamDesc[String]("labelCol", "label column name", Some("\"label\"")),
       ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")),
@@ -46,17 +46,17 @@ private[shared] object SharedParamsCodeGen {
         "column name for predicted class conditional probabilities", Some("\"probability\"")),
       ParamDesc[Double]("threshold",
         "threshold in binary classification prediction, in range [0, 1]",
-        isValid = "ParamValidate.inRange[Double](0, 1)"),
+        isValid = "ParamValidators.inRange(0, 1)"),
      ParamDesc[String]("inputCol", "input column name"),
      ParamDesc[Array[String]]("inputCols", "input column names"),
      ParamDesc[String]("outputCol", "output column name"),
      ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
-        isValid = "ParamValidate.gtEq[Int](1)"),
+        isValid = "ParamValidators.gtEq(1)"),
      ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
      ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")),
      ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." +
        " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.",
-        isValid = "ParamValidate.inRange[Double](0, 1)"),
+        isValid = "ParamValidators.inRange(0, 1)"),
      ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"),
      ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."))

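SharedParamsCodeGen emits the traits in sharedParams.scala below, and the isValid strings above are spliced verbatim into the generated param constructors. A hedged, much-simplified sketch of that splicing; the real generator also handles type parameters, defaults, doc groups, and getters, and these names are illustrative only:

// Hypothetical, simplified sketch of string-based code generation.
case class ParamDescSketch(name: String, doc: String, isValid: String = "")

def genParam(d: ParamDescSketch): String = {
  val validator = if (d.isValid.nonEmpty) s", ${d.isValid}" else ""
  s"""final val ${d.name}: DoubleParam = new DoubleParam(this, "${d.name}", "${d.doc}"$validator)"""
}

println(genParam(ParamDescSketch("regParam", "regularization parameter (>= 0)", "ParamValidators.gtEq(0)")))
// final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter (>= 0)", ParamValidators.gtEq(0))
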
mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala

Lines changed: 5 additions & 5 deletions
@@ -34,7 +34,7 @@ private[ml] trait HasRegParam extends Params {
    * Param for regularization parameter (>= 0).
    * @group param
    */
-  final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter (>= 0)", ParamValidate.gtEq[Double](0))
+  final val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter (>= 0)", ParamValidators.gtEq(0))
 
   /** @group getParam */
   final def getRegParam: Double = getOrDefault(regParam)
@@ -49,7 +49,7 @@ private[ml] trait HasMaxIter extends Params {
    * Param for max number of iterations (>= 0).
    * @group param
    */
-  final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidate.gtEq[Int](0))
+  final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0))
 
   /** @group getParam */
   final def getMaxIter: Int = getOrDefault(maxIter)
@@ -149,7 +149,7 @@ private[ml] trait HasThreshold extends Params {
    * Param for threshold in binary classification prediction, in range [0, 1].
    * @group param
    */
-  final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidate.inRange[Double](0, 1))
+  final val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in binary classification prediction, in range [0, 1]", ParamValidators.inRange(0, 1))
 
   /** @group getParam */
   final def getThreshold: Double = getOrDefault(threshold)
@@ -209,7 +209,7 @@ private[ml] trait HasCheckpointInterval extends Params {
    * Param for checkpoint interval (>= 1).
    * @group param
    */
-  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidate.gtEq[Int](1))
+  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidators.gtEq(1))
 
   /** @group getParam */
   final def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
@@ -258,7 +258,7 @@ private[ml] trait HasElasticNetParam extends Params {
    * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty..
    * @group param
    */
-  final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidate.inRange[Double](0, 1))
+  final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", ParamValidators.inRange(0, 1))
 
   /** @group getParam */
   final def getElasticNetParam: Double = getOrDefault(elasticNetParam)
