@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
2020import com .github .fommil .netlib .BLAS .{getInstance => blas }
2121
2222import org .apache .spark .Logging
23- import org .apache .spark .annotation .Experimental
23+ import org .apache .spark .annotation .{ Experimental , Since }
2424import org .apache .spark .ml .{PredictionModel , Predictor }
2525import org .apache .spark .ml .param .{Param , ParamMap }
2626import org .apache .spark .ml .regression .DecisionTreeRegressionModel
@@ -44,54 +44,69 @@ import org.apache.spark.sql.types.DoubleType
4444 * It supports binary labels, as well as both continuous and categorical features.
4545 * Note: Multiclass labels are not currently supported.
4646 */
47+ @ Since (" 1.4.0" )
4748@ Experimental
48- final class GBTClassifier (override val uid : String )
49+ final class GBTClassifier @ Since (" 1.4.0" ) (
50+ @ Since (" 1.4.0" ) override val uid : String )
4951 extends Predictor [Vector , GBTClassifier , GBTClassificationModel ]
5052 with GBTParams with TreeClassifierParams with Logging {
5153
54+ @ Since (" 1.4.0" )
5255 def this () = this (Identifiable .randomUID(" gbtc" ))
5356
5457 // Override parameter setters from parent trait for Java API compatibility.
5558
5659 // Parameters from TreeClassifierParams:
5760
61+ @ Since (" 1.4.0" )
5862 override def setMaxDepth (value : Int ): this .type = super .setMaxDepth(value)
5963
64+ @ Since (" 1.4.0" )
6065 override def setMaxBins (value : Int ): this .type = super .setMaxBins(value)
6166
67+ @ Since (" 1.4.0" )
6268 override def setMinInstancesPerNode (value : Int ): this .type =
6369 super .setMinInstancesPerNode(value)
6470
71+ @ Since (" 1.4.0" )
6572 override def setMinInfoGain (value : Double ): this .type = super .setMinInfoGain(value)
6673
74+ @ Since (" 1.4.0" )
6775 override def setMaxMemoryInMB (value : Int ): this .type = super .setMaxMemoryInMB(value)
6876
77+ @ Since (" 1.4.0" )
6978 override def setCacheNodeIds (value : Boolean ): this .type = super .setCacheNodeIds(value)
7079
80+ @ Since (" 1.4.0" )
7181 override def setCheckpointInterval (value : Int ): this .type = super .setCheckpointInterval(value)
7282
7383 /**
7484 * The impurity setting is ignored for GBT models.
7585 * Individual trees are built using impurity "Variance."
7686 */
87+ @ Since (" 1.4.0" )
7788 override def setImpurity (value : String ): this .type = {
7889 logWarning(" GBTClassifier.setImpurity should NOT be used" )
7990 this
8091 }
8192
8293 // Parameters from TreeEnsembleParams:
8394
95+ @ Since (" 1.4.0" )
8496 override def setSubsamplingRate (value : Double ): this .type = super .setSubsamplingRate(value)
8597
98+ @ Since (" 1.4.0" )
8699 override def setSeed (value : Long ): this .type = {
87100 logWarning(" The 'seed' parameter is currently ignored by Gradient Boosting." )
88101 super .setSeed(value)
89102 }
90103
91104 // Parameters from GBTParams:
92105
106+ @ Since (" 1.4.0" )
93107 override def setMaxIter (value : Int ): this .type = super .setMaxIter(value)
94108
109+ @ Since (" 1.4.0" )
95110 override def setStepSize (value : Double ): this .type = super .setStepSize(value)
96111
97112 // Parameters for GBTClassifier:
@@ -102,6 +117,7 @@ final class GBTClassifier(override val uid: String)
102117 * (default = logistic)
103118 * @group param
104119 */
120+ @ Since (" 1.4.0" )
105121 val lossType : Param [String ] = new Param [String ](this , " lossType" , " Loss function which GBT" +
106122 " tries to minimize (case-insensitive). Supported options:" +
107123 s " ${GBTClassifier .supportedLossTypes.mkString(" , " )}" ,
@@ -110,9 +126,11 @@ final class GBTClassifier(override val uid: String)
110126 setDefault(lossType -> " logistic" )
111127
112128 /** @group setParam */
129+ @ Since (" 1.4.0" )
113130 def setLossType (value : String ): this .type = set(lossType, value)
114131
115132 /** @group getParam */
133+ @ Since (" 1.4.0" )
116134 def getLossType : String = $(lossType).toLowerCase
117135
118136 /** (private[ml]) Convert new loss to old loss. */
@@ -145,13 +163,16 @@ final class GBTClassifier(override val uid: String)
145163 GBTClassificationModel .fromOld(oldModel, this , categoricalFeatures, numFeatures)
146164 }
147165
166+ @ Since (" 1.4.1" )
148167 override def copy (extra : ParamMap ): GBTClassifier = defaultCopy(extra)
149168}
150169
170+ @ Since (" 1.4.0" )
151171@ Experimental
152172object GBTClassifier {
153173 // The losses below should be lowercase.
154174 /** Accessor for supported loss settings: logistic */
175+ @ Since (" 1.4.0" )
155176 final val supportedLossTypes : Array [String ] = Array (" logistic" ).map(_.toLowerCase)
156177}
157178
@@ -164,12 +185,13 @@ object GBTClassifier {
164185 * @param _trees Decision trees in the ensemble.
165186 * @param _treeWeights Weights for the decision trees in the ensemble.
166187 */
188+ @ Since (" 1.6.0" )
167189@ Experimental
168190final class GBTClassificationModel private [ml](
169- override val uid : String ,
191+ @ Since ( " 1.6.0 " ) override val uid : String ,
170192 private val _trees : Array [DecisionTreeRegressionModel ],
171193 private val _treeWeights : Array [Double ],
172- override val numFeatures : Int )
194+ @ Since ( " 1.6.0 " ) override val numFeatures : Int )
173195 extends PredictionModel [Vector , GBTClassificationModel ]
174196 with TreeEnsembleModel with Serializable {
175197
@@ -182,11 +204,14 @@ final class GBTClassificationModel private[ml](
182204 * @param _trees Decision trees in the ensemble.
183205 * @param _treeWeights Weights for the decision trees in the ensemble.
184206 */
207+ @ Since (" 1.6.0" )
185208 def this (uid : String , _trees : Array [DecisionTreeRegressionModel ], _treeWeights : Array [Double ]) =
186209 this (uid, _trees, _treeWeights, - 1 )
187210
211+ @ Since (" 1.4.0" )
188212 override def trees : Array [DecisionTreeModel ] = _trees.asInstanceOf [Array [DecisionTreeModel ]]
189213
214+ @ Since (" 1.4.0" )
190215 override def treeWeights : Array [Double ] = _treeWeights
191216
192217 override protected def transformImpl (dataset : DataFrame ): DataFrame = {
@@ -205,11 +230,13 @@ final class GBTClassificationModel private[ml](
205230 if (prediction > 0.0 ) 1.0 else 0.0
206231 }
207232
233+ @ Since (" 1.4.0" )
208234 override def copy (extra : ParamMap ): GBTClassificationModel = {
209235 copyValues(new GBTClassificationModel (uid, _trees, _treeWeights, numFeatures),
210236 extra).setParent(parent)
211237 }
212238
239+ @ Since (" 1.4.0" )
213240 override def toString : String = {
214241 s " GBTClassificationModel (uid= $uid) with $numTrees trees "
215242 }
0 commit comments