Skip to content

Commit ecddf15

Browse files
committed
Remove deprecated methods for ML.
1 parent 4ac9759 commit ecddf15

File tree

8 files changed

+40
-53
lines changed

8 files changed

+40
-53
lines changed

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ import org.apache.spark.mllib.util.MLUtils
4040
import org.apache.spark.rdd.RDD
4141
import org.apache.spark.sql.{DataFrame, Dataset, Row}
4242
import org.apache.spark.sql.functions.{col, lit}
43-
import org.apache.spark.sql.types.DoubleType
43+
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
4444
import org.apache.spark.storage.StorageLevel
4545
import org.apache.spark.util.VersionUtils
4646

@@ -176,8 +176,12 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
176176
}
177177
}
178178

179-
override def validateParams(): Unit = {
179+
/**
 * Validates the schema after first checking the threshold-related params.
 *
 * `checkThresholdConsistency()` runs before delegating to the parent implementation, so the
 * threshold check is performed every time the schema is validated.
 *
 * @param schema input schema, passed through to the parent implementation
 * @param fitting passed through unchanged to the parent implementation
 * @param featuresDataType passed through unchanged to the parent implementation
 * @return the result of the parent `validateAndTransformSchema`
 */
override protected def validateAndTransformSchema(
180+
schema: StructType,
181+
fitting: Boolean,
182+
featuresDataType: DataType): StructType = {
180183
// Param-interaction check comes first; it does not use the schema arguments.
checkThresholdConsistency()
184+
super.validateAndTransformSchema(schema, fitting, featuresDataType)
181185
}
182186
}
183187

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -221,15 +221,6 @@ class RandomForestClassificationModel private[ml] (
221221
}
222222
}
223223

224-
/**
225-
* Number of trees in ensemble
226-
*
227-
* @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
228-
*/
229-
// TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams
230-
@deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
231-
val numTrees: Int = trees.length
232-
233224
@Since("1.4.0")
234225
override def copy(extra: ParamMap): RandomForestClassificationModel = {
235226
copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)

mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,6 @@ final class ChiSqSelectorModel private[ml] (
216216
@Since("1.6.0")
217217
def setOutputCol(value: String): this.type = set(outputCol, value)
218218

219-
/**
220-
* @group setParam
221-
*/
222-
@Since("1.6.0")
223-
@deprecated("labelCol is not used by ChiSqSelectorModel.", "2.0.0")
224-
def setLabelCol(value: String): this.type = set(labelCol, value)
225-
226219
@Since("2.0.0")
227220
override def transform(dataset: Dataset[_]): DataFrame = {
228221
val transformedSchema = transformSchema(dataset.schema, logging = true)

mllib/src/main/scala/org/apache/spark/ml/param/params.scala

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -546,21 +546,6 @@ trait Params extends Identifiable with Serializable {
546546
.map(m => m.invoke(this).asInstanceOf[Param[_]])
547547
}
548548

549-
/**
550-
* Validates parameter values stored internally.
551-
* Raise an exception if any parameter value is invalid.
552-
*
553-
* This only needs to check for interactions between parameters.
554-
* Parameter value checks which do not depend on other parameters are handled by
555-
* `Param.validate()`. This method does not handle input/output column parameters;
556-
* those are checked during schema validation.
557-
* @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema
558-
*/
559-
@deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0")
560-
def validateParams(): Unit = {
561-
// Do nothing by default. Override to handle Param interactions.
562-
}
563-
564549
/**
565550
* Explains a param.
566551
* @param param input param, must belong to this instance.

mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -605,9 +605,6 @@ class LinearRegressionSummary private[regression] (
605605
private val privateModel: LinearRegressionModel,
606606
private val diagInvAtWA: Array[Double]) extends Serializable {
607607

608-
@deprecated("The model field is deprecated and will be removed in 2.1.0.", "2.0.0")
609-
val model: LinearRegressionModel = privateModel
610-
611608
@transient private val metrics = new RegressionMetrics(
612609
predictions
613610
.select(col(predictionCol), col(labelCol).cast(DoubleType))

mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,6 @@ class RandomForestRegressionModel private[ml] (
181181
_trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees
182182
}
183183

184-
/**
185-
* Number of trees in ensemble
186-
* @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0
187-
*/
188-
// TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams
189-
@deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0")
190-
val numTrees: Int = trees.length
191-
192184
@Since("1.4.0")
193185
override def copy(extra: ParamMap): RandomForestRegressionModel = {
194186
copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)

mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,12 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
435435
setDefault(maxIter -> 20, stepSize -> 0.1)
436436

437437
/** @group setParam */
438-
def setMaxIter(value: Int): this.type = set(maxIter, value)
438+
def setMaxIter(value: Int): this.type = {
  // No (0, 1] range check here: that interval describes stepSize, not an iteration count.
  // Applying it to maxIter rejected every value other than 1, including the default of 20
  // set via setDefault(maxIter -> 20, ...), and reported the failure as a stepSize error.
  set(maxIter, value)
}
439444

440445
/**
441446
* Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
@@ -445,12 +450,6 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS
445450
*/
446451
def setStepSize(value: Double): this.type = {
  // Enforce the documented (0, 1] interval when the value is set, so an out-of-range
  // step size fails fast with IllegalArgumentException instead of being stored.
  require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)(
    value), "GBT parameter stepSize should be in interval (0, 1], " +
    s"but it given invalid value $value.")
  set(stepSize, value)
}
447452

448-
override def validateParams(): Unit = {
449-
require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)(
450-
getStepSize), "GBT parameter stepSize should be in interval (0, 1], " +
451-
s"but it given invalid value $getStepSize.")
452-
}
453-
454453
/** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
455454
private[ml] def getOldBoostingStrategy(
456455
categoricalFeatures: Map[Int, Int],

python/pyspark/ml/util.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ def context(self, sqlContext):
8181
"""Sets the SQL context to use for saving."""
8282
raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
8383

84+
def session(self, sparkSession):
85+
"""Sets the Spark Session to use for saving."""
86+
# Stub: always raises, reporting the concrete type of `self`.
# NOTE(review): presumably the MLWriter base-class hook that JavaMLWriter overrides - confirm.
raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
87+
8488

8589
@inherit_doc
8690
class JavaMLWriter(MLWriter):
@@ -105,10 +109,19 @@ def overwrite(self):
105109
return self
106110

107111
def context(self, sqlContext):
108-
"""Sets the SQL context to use for saving."""
112+
"""
113+
Sets the SQL context to use for saving.
114+
.. note:: Deprecated in 2.1, use session instead.
115+
"""
116+
# No category is passed, so this is emitted as a UserWarning (the warnings.warn default).
# NOTE(review): relies on a module-level `import warnings` that this hunk does not show - confirm.
warnings.warn("Deprecated in 2.1, use session instead.")
109117
# Hands the context's `_ssql_ctx` attribute to `_jwrite.context`, then returns this writer.
self._jwrite.context(sqlContext._ssql_ctx)
110118
return self
111119

120+
def session(self, sparkSession):
121+
"""Sets the Spark Session to use for saving."""
122+
# Hands the session's `_jsparkSession` attribute to `_jwrite.session`, then returns this writer.
self._jwrite.session(sparkSession._jsparkSession)
123+
return self
124+
112125

113126
@inherit_doc
114127
class MLWritable(object):
@@ -158,6 +171,10 @@ def context(self, sqlContext):
158171
"""Sets the SQL context to use for loading."""
159172
raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
160173

174+
def session(self, sparkSession):
175+
"""Sets the Spark Session to use for loading."""
176+
# Stub: always raises, reporting the concrete type of `self`.
# NOTE(review): presumably the MLReader base-class hook that JavaMLReader overrides - confirm.
raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
177+
161178

162179
@inherit_doc
163180
class JavaMLReader(MLReader):
@@ -180,10 +197,19 @@ def load(self, path):
180197
return self._clazz._from_java(java_obj)
181198

182199
def context(self, sqlContext):
183-
"""Sets the SQL context to use for loading."""
200+
"""
201+
Sets the SQL context to use for loading.
202+
.. note:: Deprecated in 2.1, use session instead.
203+
"""
204+
# No category is passed, so this is emitted as a UserWarning (the warnings.warn default).
# NOTE(review): relies on a module-level `import warnings` that this hunk does not show - confirm.
warnings.warn("Deprecated in 2.1, use session instead.")
184205
# Hands the context's `_ssql_ctx` attribute to `_jread.context`, then returns this reader.
self._jread.context(sqlContext._ssql_ctx)
185206
return self
186207

208+
def session(self, sparkSession):
209+
"""Sets the Spark Session to use for loading."""
210+
# Hands the session's `_jsparkSession` attribute to `_jread.session`, then returns this reader.
self._jread.session(sparkSession._jsparkSession)
211+
return self
212+
187213
@classmethod
188214
def _java_loader_class(cls, clazz):
189215
"""

0 commit comments

Comments
 (0)