Skip to content

Commit 830ee13

Browse files
committed
[SPARK-18481][ML] ML 2.1 QA: Remove deprecated methods for ML
## What changes were proposed in this pull request?

Remove deprecated methods for ML.

## How was this patch tested?

Existing tests.

Author: Yanbo Liang <[email protected]>

Closes #15913 from yanboliang/spark-18481.

(cherry picked from commit c4a7eef)
Signed-off-by: Yanbo Liang <[email protected]>
1 parent da66b97 commit 830ee13

File tree

16 files changed

+144
-107
lines changed

16 files changed

+144
-107
lines changed

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ abstract class PipelineStage extends Params with Logging {
4646
*
4747
* Check transform validity and derive the output schema from the input schema.
4848
*
49+
* We check validity for interactions between parameters during `transformSchema` and
50+
* raise an exception if any parameter value is invalid. Parameter value checks which
51+
* do not depend on other parameters are handled by `Param.validate()`.
52+
*
4953
* Typical implementation should first conduct verification on schema change and parameter
5054
* validity, including complex parameter interaction checks.
5155
*/

mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,12 @@ class GBTClassificationModel private[ml](
203203
@Since("1.4.0")
204204
override def trees: Array[DecisionTreeRegressionModel] = _trees
205205

206+
/**
207+
* Number of trees in ensemble
208+
*/
209+
@Since("2.0.0")
210+
val getNumTrees: Int = trees.length
211+
206212
@Since("1.4.0")
207213
override def treeWeights: Array[Double] = _treeWeights
208214

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ import org.apache.spark.mllib.util.MLUtils
4040
import org.apache.spark.rdd.RDD
4141
import org.apache.spark.sql.{DataFrame, Dataset, Row}
4242
import org.apache.spark.sql.functions.{col, lit}
43-
import org.apache.spark.sql.types.DoubleType
43+
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
4444
import org.apache.spark.storage.StorageLevel
4545
import org.apache.spark.util.VersionUtils
4646

@@ -176,8 +176,12 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
176176
}
177177
}
178178

179-
override def validateParams(): Unit = {
179+
override protected def validateAndTransformSchema(
180+
schema: StructType,
181+
fitting: Boolean,
182+
featuresDataType: DataType): StructType = {
180183
checkThresholdConsistency()
184+
super.validateAndTransformSchema(schema, fitting, featuresDataType)
181185
}
182186
}
183187

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ class RandomForestClassificationModel private[ml] (
158158
@Since("1.6.0") override val numFeatures: Int,
159159
@Since("1.5.0") override val numClasses: Int)
160160
extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
161-
with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel]
161+
with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel]
162162
with MLWritable with Serializable {
163163

164164
require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
@@ -221,15 +221,6 @@ class RandomForestClassificationModel private[ml] (
221221
}
222222
}
223223

224-
/**
225-
* Number of trees in ensemble
226-
*
227-
* @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
228-
*/
229-
// TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams
230-
@deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
231-
val numTrees: Int = trees.length
232-
233224
@Since("1.4.0")
234225
override def copy(extra: ParamMap): RandomForestClassificationModel = {
235226
copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)

mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,6 @@ final class ChiSqSelectorModel private[ml] (
216216
@Since("1.6.0")
217217
def setOutputCol(value: String): this.type = set(outputCol, value)
218218

219-
/**
220-
* @group setParam
221-
*/
222-
@Since("1.6.0")
223-
@deprecated("labelCol is not used by ChiSqSelectorModel.", "2.0.0")
224-
def setLabelCol(value: String): this.type = set(labelCol, value)
225-
226219
@Since("2.0.0")
227220
override def transform(dataset: Dataset[_]): DataFrame = {
228221
val transformedSchema = transformSchema(dataset.schema, logging = true)

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -546,21 +546,6 @@ trait Params extends Identifiable with Serializable {
546546
.map(m => m.invoke(this).asInstanceOf[Param[_]])
547547
}
548548

549-
/**
550-
* Validates parameter values stored internally.
551-
* Raise an exception if any parameter value is invalid.
552-
*
553-
* This only needs to check for interactions between parameters.
554-
* Parameter value checks which do not depend on other parameters are handled by
555-
* `Param.validate()`. This method does not handle input/output column parameters;
556-
* those are checked during schema validation.
557-
* @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema
558-
*/
559-
@deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0")
560-
def validateParams(): Unit = {
561-
// Do nothing by default. Override to handle Param interactions.
562-
}
563-
564549
/**
565550
* Explains a param.
566551
* @param param input param, must belong to this instance.

mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,12 @@ class GBTRegressionModel private[ml](
183183
@Since("1.4.0")
184184
override def trees: Array[DecisionTreeRegressionModel] = _trees
185185

186+
/**
187+
* Number of trees in ensemble
188+
*/
189+
@Since("2.0.0")
190+
val getNumTrees: Int = trees.length
191+
186192
@Since("1.4.0")
187193
override def treeWeights: Array[Double] = _treeWeights
188194

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -611,9 +611,6 @@ class LinearRegressionSummary private[regression] (
611611
private val privateModel: LinearRegressionModel,
612612
private val diagInvAtWA: Array[Double]) extends Serializable {
613613

614-
@deprecated("The model field is deprecated and will be removed in 2.1.0.", "2.0.0")
615-
val model: LinearRegressionModel = privateModel
616-
617614
@transient private val metrics = new RegressionMetrics(
618615
predictions
619616
.select(col(predictionCol), col(labelCol).cast(DoubleType))

mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ class RandomForestRegressionModel private[ml] (
145145
private val _trees: Array[DecisionTreeRegressionModel],
146146
override val numFeatures: Int)
147147
extends PredictionModel[Vector, RandomForestRegressionModel]
148-
with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel]
148+
with RandomForestRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel]
149149
with MLWritable with Serializable {
150150

151151
require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
@@ -182,14 +182,6 @@ class RandomForestRegressionModel private[ml] (
182182
_trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees
183183
}
184184

185-
/**
186-
* Number of trees in ensemble
187-
* @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
188-
*/
189-
// TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams
190-
@deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
191-
val numTrees: Int = trees.length
192-
193185
@Since("1.4.0")
194186
override def copy(extra: ParamMap): RandomForestRegressionModel = {
195187
copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)

mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,6 @@ private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
9595
/** Trees in this ensemble. Warning: These have null parent Estimators. */
9696
def trees: Array[M]
9797

98-
/**
99-
* Number of trees in ensemble
100-
*/
101-
val getNumTrees: Int = trees.length
102-
10398
/** Weights for each tree, zippable with [[trees]] */
10499
def treeWeights: Array[Double]
105100

0 commit comments

Comments (0)