Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@ package org.apache.spark.ml.tuning

import com.github.fommil.netlib.F2jBLAS
import org.apache.hadoop.fs.Path
import org.json4s.{JObject, DefaultFormats}
import org.json4s.jackson.JsonMethods._
import org.json4s.{DefaultFormats, JObject}

import org.apache.spark.ml.classification.OneVsRestParams
import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.{SparkContext, Logging}
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.classification.OneVsRestParams
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.ml.util._
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
Expand Down Expand Up @@ -58,26 +58,34 @@ private[ml] trait CrossValidatorParams extends ValidatorParams {
* :: Experimental ::
* K-fold cross validation.
*/
@Since("1.2.0")
@Experimental
class CrossValidator(override val uid: String) extends Estimator[CrossValidatorModel]
class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
extends Estimator[CrossValidatorModel]
with CrossValidatorParams with MLWritable with Logging {

/** Constructs a CrossValidator with a randomly generated uid (prefix "cv"). */
@Since("1.2.0")
def this() = this(Identifiable.randomUID("cv"))

private val f2jBLAS = new F2jBLAS

/**
 * Sets the estimator to be fit over each param map in [[estimatorParamMaps]].
 * @group setParam
 */
@Since("1.2.0")
def setEstimator(value: Estimator[_]): this.type = set(estimator, value)

/**
 * Sets the grid of param maps to evaluate; typically built with [[ParamGridBuilder]].
 * @group setParam
 */
@Since("1.2.0")
def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)

/**
 * Sets the evaluator that computes the metric used to select the best model.
 * @group setParam
 */
@Since("1.2.0")
def setEvaluator(value: Evaluator): this.type = set(evaluator, value)

/**
 * Sets the number of folds (k) used for k-fold cross validation.
 * @group setParam
 */
@Since("1.2.0")
def setNumFolds(value: Int): this.type = set(numFolds, value)

@Since("1.4.0")
override def fit(dataset: DataFrame): CrossValidatorModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
Expand Down Expand Up @@ -116,10 +124,12 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
}

/** Schema transformation is delegated to the configured estimator. */
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType =
  $(estimator).transformSchema(schema)

@Since("1.4.0")
override def validateParams(): Unit = {
super.validateParams()
val est = $(estimator)
Expand All @@ -128,6 +138,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
}
}

@Since("1.4.0")
override def copy(extra: ParamMap): CrossValidator = {
val copied = defaultCopy(extra).asInstanceOf[CrossValidator]
if (copied.isDefined(estimator)) {
Expand Down Expand Up @@ -308,26 +319,31 @@ object CrossValidator extends MLReadable[CrossValidator] {
* @param avgMetrics Average cross-validation metrics for each paramMap in
* [[CrossValidator.estimatorParamMaps]], in the corresponding order.
*/
@Since("1.2.0")
@Experimental
class CrossValidatorModel private[ml] (
override val uid: String,
val bestModel: Model[_],
val avgMetrics: Array[Double])
@Since("1.4.0") override val uid: String,
@Since("1.2.0") val bestModel: Model[_],
@Since("1.5.0") val avgMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable {

/** Param validation is delegated to the best model selected during fitting. */
@Since("1.4.0")
override def validateParams(): Unit = bestModel.validateParams()

/** Validates the input schema, then transforms the dataset with the best model. */
@Since("1.4.0")
override def transform(dataset: DataFrame): DataFrame = {
  // Check schema compatibility (and log) before handing off to the best model.
  transformSchema(dataset.schema, logging = true)
  bestModel.transform(dataset)
}

/** The output schema is whatever the best model produces for this input schema. */
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType =
  bestModel.transformSchema(schema)

@Since("1.4.0")
override def copy(extra: ParamMap): CrossValidatorModel = {
val copied = new CrossValidatorModel(
uid,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,23 @@ package org.apache.spark.ml.tuning
import scala.annotation.varargs
import scala.collection.mutable

import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param._

/**
* :: Experimental ::
* Builder for a param grid used in grid search-based model selection.
*/
@Since("1.2.0")
@Experimental
class ParamGridBuilder {
class ParamGridBuilder @Since("1.2.0") {

private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]]

/**
* Sets the given parameters in this grid to fixed values.
*/
@Since("1.2.0")
def baseOn(paramMap: ParamMap): this.type = {
baseOn(paramMap.toSeq: _*)
this
Expand All @@ -43,6 +45,7 @@ class ParamGridBuilder {
/**
* Sets the given parameters in this grid to fixed values.
*/
@Since("1.2.0")
@varargs
def baseOn(paramPairs: ParamPair[_]*): this.type = {
paramPairs.foreach { p =>
Expand All @@ -54,6 +57,7 @@ class ParamGridBuilder {
/**
* Adds a param with multiple values (overwrites if the input param exists).
*/
@Since("1.2.0")
def addGrid[T](param: Param[T], values: Iterable[T]): this.type = {
paramGrid.put(param, values)
this
Expand All @@ -64,41 +68,47 @@ class ParamGridBuilder {
/**
 * Adds a double param with multiple candidate values (overwrites if the param exists).
 */
@Since("1.2.0")
def addGrid(param: DoubleParam, values: Array[Double]): this.type =
  addGrid[Double](param, values)

/**
 * Adds an int param with multiple candidate values (overwrites if the param exists).
 */
@Since("1.2.0")
def addGrid(param: IntParam, values: Array[Int]): this.type =
  addGrid[Int](param, values)

/**
 * Adds a float param with multiple candidate values (overwrites if the param exists).
 */
@Since("1.2.0")
def addGrid(param: FloatParam, values: Array[Float]): this.type =
  addGrid[Float](param, values)

/**
 * Adds a long param with multiple candidate values (overwrites if the param exists).
 */
@Since("1.2.0")
def addGrid(param: LongParam, values: Array[Long]): this.type =
  addGrid[Long](param, values)

/**
 * Adds a boolean param, expanding the grid over both true and false.
 */
@Since("1.2.0")
def addGrid(param: BooleanParam): this.type =
  addGrid[Boolean](param, Array(true, false))

/**
* Builds and returns all combinations of parameters specified by the param grid.
*/
@Since("1.2.0")
def build(): Array[ParamMap] = {
var paramMaps = Array(new ParamMap)
paramGrid.foreach { case (param, values) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.ml.tuning

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
Expand Down Expand Up @@ -51,24 +51,32 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams {
* and uses evaluation metric on the validation set to select the best model.
* Similar to [[CrossValidator]], but only splits the set once.
*/
@Since("1.5.0")
@Experimental
class TrainValidationSplit(override val uid: String) extends Estimator[TrainValidationSplitModel]
class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String)
extends Estimator[TrainValidationSplitModel]
with TrainValidationSplitParams with Logging {

/** Constructs a TrainValidationSplit with a randomly generated uid (prefix "tvs"). */
@Since("1.5.0")
def this() = this(Identifiable.randomUID("tvs"))

/**
 * Sets the estimator to be fit over each param map in [[estimatorParamMaps]].
 * @group setParam
 */
@Since("1.5.0")
def setEstimator(value: Estimator[_]): this.type = set(estimator, value)

/**
 * Sets the grid of param maps to evaluate; typically built with [[ParamGridBuilder]].
 * @group setParam
 */
@Since("1.5.0")
def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)

/**
 * Sets the evaluator that computes the validation metric used to select the best model.
 * @group setParam
 */
@Since("1.5.0")
def setEvaluator(value: Evaluator): this.type = set(evaluator, value)

/**
 * Sets the fraction of the dataset used for training; the remainder is the validation set.
 * @group setParam
 */
@Since("1.5.0")
def setTrainRatio(value: Double): this.type = set(trainRatio, value)

@Since("1.5.0")
override def fit(dataset: DataFrame): TrainValidationSplitModel = {
val schema = dataset.schema
transformSchema(schema, logging = true)
Expand Down Expand Up @@ -108,10 +116,12 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali
copyValues(new TrainValidationSplitModel(uid, bestModel, metrics).setParent(this))
}

/** Schema transformation is delegated to the configured estimator. */
@Since("1.5.0")
override def transformSchema(schema: StructType): StructType =
  $(estimator).transformSchema(schema)

@Since("1.5.0")
override def validateParams(): Unit = {
super.validateParams()
val est = $(estimator)
Expand All @@ -120,6 +130,7 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali
}
}

@Since("1.5.0")
override def copy(extra: ParamMap): TrainValidationSplit = {
val copied = defaultCopy(extra).asInstanceOf[TrainValidationSplit]
if (copied.isDefined(estimator)) {
Expand All @@ -140,26 +151,31 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali
* @param bestModel Estimator determined best model.
* @param validationMetrics Evaluated validation metrics.
*/
@Since("1.5.0")
@Experimental
class TrainValidationSplitModel private[ml] (
override val uid: String,
val bestModel: Model[_],
val validationMetrics: Array[Double])
@Since("1.5.0") override val uid: String,
@Since("1.5.0") val bestModel: Model[_],
@Since("1.5.0") val validationMetrics: Array[Double])
extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {

/** Param validation is delegated to the best model selected during fitting. */
@Since("1.5.0")
override def validateParams(): Unit = bestModel.validateParams()

/** Validates the input schema, then transforms the dataset with the best model. */
@Since("1.5.0")
override def transform(dataset: DataFrame): DataFrame = {
  // Check schema compatibility (and log) before handing off to the best model.
  transformSchema(dataset.schema, logging = true)
  bestModel.transform(dataset)
}

/** The output schema is whatever the best model produces for this input schema. */
@Since("1.5.0")
override def transformSchema(schema: StructType): StructType =
  bestModel.transformSchema(schema)

@Since("1.5.0")
override def copy(extra: ParamMap): TrainValidationSplitModel = {
val copied = new TrainValidationSplitModel (
uid,
Expand Down