Skip to content

Commit 8652fc0

Browse files
Takahashi Hiroshimengxr
authored andcommitted
[SPARK-10259][ML] Add @SInCE annotation to ml.classification
Add since annotation to ml.classification Author: Takahashi Hiroshi <[email protected]> Closes #8534 from taishi-oss/issue10259. (cherry picked from commit 7d05a62) Signed-off-by: Xiangrui Meng <[email protected]>
1 parent 3c683ed commit 8652fc0

File tree

7 files changed

+185
-44
lines changed

7 files changed

+185
-44
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.ml.classification
1919

20-
import org.apache.spark.annotation.Experimental
20+
import org.apache.spark.annotation.{Experimental, Since}
2121
import org.apache.spark.ml.param.ParamMap
2222
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
2323
import org.apache.spark.ml.tree.impl.RandomForest
@@ -36,32 +36,44 @@ import org.apache.spark.sql.DataFrame
3636
* It supports both binary and multiclass labels, as well as both continuous and categorical
3737
* features.
3838
*/
39+
@Since("1.4.0")
3940
@Experimental
40-
final class DecisionTreeClassifier(override val uid: String)
41+
final class DecisionTreeClassifier @Since("1.4.0") (
42+
@Since("1.4.0") override val uid: String)
4143
extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
4244
with DecisionTreeParams with TreeClassifierParams {
4345

46+
@Since("1.4.0")
4447
def this() = this(Identifiable.randomUID("dtc"))
4548

4649
// Override parameter setters from parent trait for Java API compatibility.
4750

51+
@Since("1.4.0")
4852
override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
4953

54+
@Since("1.4.0")
5055
override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
5156

57+
@Since("1.4.0")
5258
override def setMinInstancesPerNode(value: Int): this.type =
5359
super.setMinInstancesPerNode(value)
5460

61+
@Since("1.4.0")
5562
override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
5663

64+
@Since("1.4.0")
5765
override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
5866

67+
@Since("1.4.0")
5968
override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
6069

70+
@Since("1.4.0")
6171
override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
6272

73+
@Since("1.4.0")
6374
override def setImpurity(value: String): this.type = super.setImpurity(value)
6475

76+
@Since("1.6.0")
6577
override def setSeed(value: Long): this.type = super.setSeed(value)
6678

6779
override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
@@ -89,12 +101,15 @@ final class DecisionTreeClassifier(override val uid: String)
89101
subsamplingRate = 1.0)
90102
}
91103

104+
@Since("1.4.1")
92105
override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra)
93106
}
94107

108+
@Since("1.4.0")
95109
@Experimental
96110
object DecisionTreeClassifier {
97111
/** Accessor for supported impurities: entropy, gini */
112+
@Since("1.4.0")
98113
final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
99114
}
100115

@@ -104,12 +119,13 @@ object DecisionTreeClassifier {
104119
* It supports both binary and multiclass labels, as well as both continuous and categorical
105120
* features.
106121
*/
122+
@Since("1.4.0")
107123
@Experimental
108124
final class DecisionTreeClassificationModel private[ml] (
109-
override val uid: String,
110-
override val rootNode: Node,
111-
override val numFeatures: Int,
112-
override val numClasses: Int)
125+
@Since("1.4.0")override val uid: String,
126+
@Since("1.4.0")override val rootNode: Node,
127+
@Since("1.6.0")override val numFeatures: Int,
128+
@Since("1.5.0")override val numClasses: Int)
113129
extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
114130
with DecisionTreeModel with Serializable {
115131

@@ -142,11 +158,13 @@ final class DecisionTreeClassificationModel private[ml] (
142158
}
143159
}
144160

161+
@Since("1.4.0")
145162
override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
146163
copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra)
147164
.setParent(parent)
148165
}
149166

167+
@Since("1.4.0")
150168
override def toString: String = {
151169
s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes"
152170
}

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
2020
import com.github.fommil.netlib.BLAS.{getInstance => blas}
2121

2222
import org.apache.spark.Logging
23-
import org.apache.spark.annotation.Experimental
23+
import org.apache.spark.annotation.{Experimental, Since}
2424
import org.apache.spark.ml.{PredictionModel, Predictor}
2525
import org.apache.spark.ml.param.{Param, ParamMap}
2626
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
@@ -44,54 +44,69 @@ import org.apache.spark.sql.types.DoubleType
4444
* It supports binary labels, as well as both continuous and categorical features.
4545
* Note: Multiclass labels are not currently supported.
4646
*/
47+
@Since("1.4.0")
4748
@Experimental
48-
final class GBTClassifier(override val uid: String)
49+
final class GBTClassifier @Since("1.4.0") (
50+
@Since("1.4.0") override val uid: String)
4951
extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
5052
with GBTParams with TreeClassifierParams with Logging {
5153

54+
@Since("1.4.0")
5255
def this() = this(Identifiable.randomUID("gbtc"))
5356

5457
// Override parameter setters from parent trait for Java API compatibility.
5558

5659
// Parameters from TreeClassifierParams:
5760

61+
@Since("1.4.0")
5862
override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
5963

64+
@Since("1.4.0")
6065
override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
6166

67+
@Since("1.4.0")
6268
override def setMinInstancesPerNode(value: Int): this.type =
6369
super.setMinInstancesPerNode(value)
6470

71+
@Since("1.4.0")
6572
override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
6673

74+
@Since("1.4.0")
6775
override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
6876

77+
@Since("1.4.0")
6978
override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
7079

80+
@Since("1.4.0")
7181
override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
7282

7383
/**
7484
* The impurity setting is ignored for GBT models.
7585
* Individual trees are built using impurity "Variance."
7686
*/
87+
@Since("1.4.0")
7788
override def setImpurity(value: String): this.type = {
7889
logWarning("GBTClassifier.setImpurity should NOT be used")
7990
this
8091
}
8192

8293
// Parameters from TreeEnsembleParams:
8394

95+
@Since("1.4.0")
8496
override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
8597

98+
@Since("1.4.0")
8699
override def setSeed(value: Long): this.type = {
87100
logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
88101
super.setSeed(value)
89102
}
90103

91104
// Parameters from GBTParams:
92105

106+
@Since("1.4.0")
93107
override def setMaxIter(value: Int): this.type = super.setMaxIter(value)
94108

109+
@Since("1.4.0")
95110
override def setStepSize(value: Double): this.type = super.setStepSize(value)
96111

97112
// Parameters for GBTClassifier:
@@ -102,6 +117,7 @@ final class GBTClassifier(override val uid: String)
102117
* (default = logistic)
103118
* @group param
104119
*/
120+
@Since("1.4.0")
105121
val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
106122
" tries to minimize (case-insensitive). Supported options:" +
107123
s" ${GBTClassifier.supportedLossTypes.mkString(", ")}",
@@ -110,9 +126,11 @@ final class GBTClassifier(override val uid: String)
110126
setDefault(lossType -> "logistic")
111127

112128
/** @group setParam */
129+
@Since("1.4.0")
113130
def setLossType(value: String): this.type = set(lossType, value)
114131

115132
/** @group getParam */
133+
@Since("1.4.0")
116134
def getLossType: String = $(lossType).toLowerCase
117135

118136
/** (private[ml]) Convert new loss to old loss. */
@@ -145,13 +163,16 @@ final class GBTClassifier(override val uid: String)
145163
GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
146164
}
147165

166+
@Since("1.4.1")
148167
override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
149168
}
150169

170+
@Since("1.4.0")
151171
@Experimental
152172
object GBTClassifier {
153173
// The losses below should be lowercase.
154174
/** Accessor for supported loss settings: logistic */
175+
@Since("1.4.0")
155176
final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
156177
}
157178

@@ -164,12 +185,13 @@ object GBTClassifier {
164185
* @param _trees Decision trees in the ensemble.
165186
* @param _treeWeights Weights for the decision trees in the ensemble.
166187
*/
188+
@Since("1.6.0")
167189
@Experimental
168190
final class GBTClassificationModel private[ml](
169-
override val uid: String,
191+
@Since("1.6.0") override val uid: String,
170192
private val _trees: Array[DecisionTreeRegressionModel],
171193
private val _treeWeights: Array[Double],
172-
override val numFeatures: Int)
194+
@Since("1.6.0") override val numFeatures: Int)
173195
extends PredictionModel[Vector, GBTClassificationModel]
174196
with TreeEnsembleModel with Serializable {
175197

@@ -182,11 +204,14 @@ final class GBTClassificationModel private[ml](
182204
* @param _trees Decision trees in the ensemble.
183205
* @param _treeWeights Weights for the decision trees in the ensemble.
184206
*/
207+
@Since("1.6.0")
185208
def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
186209
this(uid, _trees, _treeWeights, -1)
187210

211+
@Since("1.4.0")
188212
override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
189213

214+
@Since("1.4.0")
190215
override def treeWeights: Array[Double] = _treeWeights
191216

192217
override protected def transformImpl(dataset: DataFrame): DataFrame = {
@@ -205,11 +230,13 @@ final class GBTClassificationModel private[ml](
205230
if (prediction > 0.0) 1.0 else 0.0
206231
}
207232

233+
@Since("1.4.0")
208234
override def copy(extra: ParamMap): GBTClassificationModel = {
209235
copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
210236
extra).setParent(parent)
211237
}
212238

239+
@Since("1.4.0")
213240
override def toString: String = {
214241
s"GBTClassificationModel (uid=$uid) with $numTrees trees"
215242
}

0 commit comments

Comments
 (0)