--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml.classification
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
 import org.apache.spark.ml.tree.impl.RandomForest
@@ -36,32 +36,44 @@ import org.apache.spark.sql.DataFrame
  * It supports both binary and multiclass labels, as well as both continuous and categorical
  * features.
  */
+@Since("1.4.0")
 @Experimental
-final class DecisionTreeClassifier(override val uid: String)
+final class DecisionTreeClassifier @Since("1.4.0") (
+    @Since("1.4.0") override val uid: String)
   extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
   with DecisionTreeParams with TreeClassifierParams {
 
+  @Since("1.4.0")
   def this() = this(Identifiable.randomUID("dtc"))
 
   // Override parameter setters from parent trait for Java API compatibility.
 
+  @Since("1.4.0")
   override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
 
+  @Since("1.4.0")
   override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
 
+  @Since("1.4.0")
   override def setMinInstancesPerNode(value: Int): this.type =
     super.setMinInstancesPerNode(value)
 
+  @Since("1.4.0")
   override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
 
+  @Since("1.4.0")
   override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
 
+  @Since("1.4.0")
   override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
 
+  @Since("1.4.0")
   override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
 
+  @Since("1.4.0")
   override def setImpurity(value: String): this.type = super.setImpurity(value)
 
+  @Since("1.6.0")
   override def setSeed(value: Long): this.type = super.setSeed(value)
 
   override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
@@ -89,12 +101,15 @@ final class DecisionTreeClassifier(override val uid: String)
       subsamplingRate = 1.0)
   }
 
+  @Since("1.4.1")
   override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra)

  [Inline review comment from a Contributor on the copy method: "Annotate all public methods"]

 }
 
+@Since("1.4.0")
 @Experimental
 object DecisionTreeClassifier {
   /** Accessor for supported impurities: entropy, gini */
+  @Since("1.4.0")
   final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
 }
 
@@ -104,12 +119,13 @@ object DecisionTreeClassifier {
  * It supports both binary and multiclass labels, as well as both continuous and categorical
  * features.
  */
+@Since("1.4.0")
 @Experimental
 final class DecisionTreeClassificationModel private[ml] (
-    override val uid: String,
-    override val rootNode: Node,
-    override val numFeatures: Int,
-    override val numClasses: Int)
+    @Since("1.4.0") override val uid: String,
+    @Since("1.4.0") override val rootNode: Node,
+    @Since("1.6.0") override val numFeatures: Int,
+    @Since("1.5.0") override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
   with DecisionTreeModel with Serializable {
 
@@ -142,11 +158,13 @@ final class DecisionTreeClassificationModel private[ml] (
     }
   }
 
+  @Since("1.4.0")
   override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
     copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra)
       .setParent(parent)
   }
 
+  @Since("1.4.0")
   override def toString: String = {
     s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes"
   }
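One detail of the DecisionTreeClassifier diff above deserves a note, since it explains both the reviewer's "Annotate all public methods" comment and the odd two-line class declaration: in Scala, an annotation targets the primary constructor only when it is written between the class name and the parameter list, and each constructor parameter and public member must be annotated individually because each can be introduced in a different release (setSeed carries 1.6.0 while the rest of the class carries 1.4.0). Below is a minimal, self-contained sketch of that placement rule; the Since stand-in and MyClassifier are illustrative only, since Spark's real org.apache.spark.annotation.Since is private[spark]:

    import scala.annotation.StaticAnnotation

    // Stand-in for org.apache.spark.annotation.Since, declared here only so
    // that the sketch compiles outside the Spark source tree.
    class Since(version: String) extends StaticAnnotation

    @Since("1.4.0")                      // annotates the class itself
    class MyClassifier @Since("1.4.0") ( // annotates the primary constructor
        @Since("1.4.0") val uid: String) // annotates the uid parameter/member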
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 
 import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
@@ -44,54 +44,69 @@ import org.apache.spark.sql.types.DoubleType
  * It supports binary labels, as well as both continuous and categorical features.
  * Note: Multiclass labels are not currently supported.
  */
+@Since("1.4.0")
 @Experimental
-final class GBTClassifier(override val uid: String)
+final class GBTClassifier @Since("1.4.0") (
+    @Since("1.4.0") override val uid: String)
   extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
   with GBTParams with TreeClassifierParams with Logging {
 
+  @Since("1.4.0")
   def this() = this(Identifiable.randomUID("gbtc"))
 
   // Override parameter setters from parent trait for Java API compatibility.
 
   // Parameters from TreeClassifierParams:
 
+  @Since("1.4.0")
   override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
 
+  @Since("1.4.0")
   override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
 
+  @Since("1.4.0")
   override def setMinInstancesPerNode(value: Int): this.type =
     super.setMinInstancesPerNode(value)
 
+  @Since("1.4.0")
   override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
 
+  @Since("1.4.0")
   override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
 
+  @Since("1.4.0")
   override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
 
+  @Since("1.4.0")
   override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
 
   /**
    * The impurity setting is ignored for GBT models.
    * Individual trees are built using impurity "Variance."
    */
+  @Since("1.4.0")
   override def setImpurity(value: String): this.type = {
     logWarning("GBTClassifier.setImpurity should NOT be used")
     this
   }
 
   // Parameters from TreeEnsembleParams:
 
+  @Since("1.4.0")
   override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
 
+  @Since("1.4.0")
   override def setSeed(value: Long): this.type = {
     logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
     super.setSeed(value)
   }
 
   // Parameters from GBTParams:
 
+  @Since("1.4.0")
   override def setMaxIter(value: Int): this.type = super.setMaxIter(value)
 
+  @Since("1.4.0")
   override def setStepSize(value: Double): this.type = super.setStepSize(value)
 
   // Parameters for GBTClassifier:
@@ -102,6 +117,7 @@ final class GBTClassifier(override val uid: String)
    * (default = logistic)
    * @group param
    */
+  @Since("1.4.0")
   val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
     " tries to minimize (case-insensitive). Supported options:" +
     s" ${GBTClassifier.supportedLossTypes.mkString(", ")}",
@@ -110,9 +126,11 @@ final class GBTClassifier(override val uid: String)
   setDefault(lossType -> "logistic")
 
   /** @group setParam */
+  @Since("1.4.0")
   def setLossType(value: String): this.type = set(lossType, value)
 
   /** @group getParam */
+  @Since("1.4.0")
   def getLossType: String = $(lossType).toLowerCase
 
   /** (private[ml]) Convert new loss to old loss. */
@@ -145,13 +163,16 @@ final class GBTClassifier(override val uid: String)
     GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
   }
 
+  @Since("1.4.1")
   override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
 }
 
+@Since("1.4.0")
 @Experimental
 object GBTClassifier {
   // The losses below should be lowercase.
   /** Accessor for supported loss settings: logistic */
+  @Since("1.4.0")
   final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
 }
 
@@ -164,12 +185,13 @@ object GBTClassifier {
  * @param _trees Decision trees in the ensemble.
  * @param _treeWeights Weights for the decision trees in the ensemble.
  */
+@Since("1.6.0")
 @Experimental
 final class GBTClassificationModel private[ml](
-    override val uid: String,
+    @Since("1.6.0") override val uid: String,
     private val _trees: Array[DecisionTreeRegressionModel],
     private val _treeWeights: Array[Double],
-    override val numFeatures: Int)
+    @Since("1.6.0") override val numFeatures: Int)
   extends PredictionModel[Vector, GBTClassificationModel]
   with TreeEnsembleModel with Serializable {
 
@@ -182,11 +204,14 @@ final class GBTClassificationModel private[ml](
    * @param _trees Decision trees in the ensemble.
    * @param _treeWeights Weights for the decision trees in the ensemble.
    */
+  @Since("1.6.0")
   def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
     this(uid, _trees, _treeWeights, -1)
 
+  @Since("1.4.0")
   override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
 
+  @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights
 
   override protected def transformImpl(dataset: DataFrame): DataFrame = {
@@ -205,11 +230,13 @@ final class GBTClassificationModel private[ml](
     if (prediction > 0.0) 1.0 else 0.0
   }
 
+  @Since("1.4.0")
   override def copy(extra: ParamMap): GBTClassificationModel = {
     copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
       extra).setParent(parent)
   }
 
+  @Since("1.4.0")
   override def toString: String = {
     s"GBTClassificationModel (uid=$uid) with $numTrees trees"
   }
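Taken together, the annotations version a user-facing builder API. A short usage sketch under stated assumptions (the training DataFrame with "label" and "features" columns and all parameter values are illustrative, not part of the PR):

    import org.apache.spark.ml.classification.{DecisionTreeClassifier, GBTClassifier}
    import org.apache.spark.sql.DataFrame

    // `training` is an assumed DataFrame with "label" and "features" columns.
    def fitBoth(training: DataFrame): Unit = {
      val dt = new DecisionTreeClassifier()  // class and constructor: @Since("1.4.0")
        .setMaxDepth(5)
        .setImpurity("gini")                 // one of DecisionTreeClassifier.supportedImpurities
        .setSeed(42L)                        // newest member: @Since("1.6.0")

      val gbt = new GBTClassifier()          // binary labels only, per the scaladoc
        .setMaxIter(20)
        .setLossType("logistic")             // the sole entry in GBTClassifier.supportedLossTypes
        .setStepSize(0.1)

      val dtModel = dt.fit(training)
      val gbtModel = gbt.fit(training)
      println(dtModel)  // DecisionTreeClassificationModel (uid=...) of depth ... with ... nodes
      println(gbtModel) // GBTClassificationModel (uid=...) with ... trees
    }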