
Commit 901ff92

huaxingao authored and BryanCutler committed
[SPARK-29464][PYTHON][ML] PySpark ML should expose Params.clear() to unset a user supplied Param
### What changes were proposed in this pull request?

Change PySpark ML `Params._clear` to `Params.clear`.

### Why are the changes needed?

PySpark ML currently has a private `_clear()` method that unsets a param. It should be made public to match the Scala API and to give users a way to unset a user-supplied param.

### Does this PR introduce any user-facing change?

Yes. PySpark ML `Params._clear` ---> `Params.clear`.

### How was this patch tested?

Added a test.

Closes #26130 from huaxingao/spark-29464.

Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: Bryan Cutler <[email protected]>
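For a quick sense of the user-facing change, here is an illustrative sketch (not part of the commit) of clearing a user-supplied param so it falls back to its default; the default of 100 for `LogisticRegression.maxIter` is taken from the library's documented default:

```python
from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(maxIter=5)   # user-supplied value
assert lr.isSet(lr.maxIter)
assert lr.getMaxIter() == 5

lr.clear(lr.maxIter)                 # public as of this commit; previously lr._clear(...)
assert not lr.isSet(lr.maxIter)
assert lr.getMaxIter() == 100        # falls back to the param's default
```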
1 parent 00347a3 commit 901ff92

File tree

5 files changed (+30 additions, -6 deletions)


python/pyspark/ml/classification.py

Lines changed: 2 additions & 2 deletions
@@ -446,7 +446,7 @@ def setThreshold(self, value):
         Clears value of :py:attr:`thresholds` if it has been set.
         """
         self._set(threshold=value)
-        self._clear(self.thresholds)
+        self.clear(self.thresholds)
         return self
 
     @since("1.4.0")
@@ -477,7 +477,7 @@ def setThresholds(self, value):
         Clears value of :py:attr:`threshold` if it has been set.
         """
         self._set(thresholds=value)
-        self._clear(self.threshold)
+        self.clear(self.threshold)
         return self
 
     @since("1.5.0")
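`threshold` and `thresholds` are two views of the same setting, so each setter unsets the other via the newly public `clear()`. An illustrative sketch of the resulting behavior (not part of the diff):

```python
from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression()
lr.setThresholds([0.4, 0.6])
assert lr.isSet(lr.thresholds)

lr.setThreshold(0.5)                 # internally runs self.clear(self.thresholds)
assert not lr.isSet(lr.thresholds)
assert lr.isSet(lr.threshold)
```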

python/pyspark/ml/param/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -452,7 +452,7 @@ def _set(self, **kwargs):
             self._paramMap[p] = value
         return self
 
-    def _clear(self, param):
+    def clear(self, param):
         """
         Clears a param from the param map if it has been explicitly set.
         """

python/pyspark/ml/tests/test_param.py

Lines changed: 18 additions & 2 deletions
@@ -27,8 +27,8 @@
 from pyspark.ml.classification import LogisticRegression
 from pyspark.ml.clustering import KMeans
 from pyspark.ml.feature import Binarizer, Bucketizer, ElementwiseProduct, IndexToString, \
-    VectorSlicer, Word2Vec
-from pyspark.ml.linalg import DenseVector, SparseVector
+    MaxAbsScaler, VectorSlicer, Word2Vec
+from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
 from pyspark.ml.param import Param, Params, TypeConverters
 from pyspark.ml.param.shared import HasInputCol, HasMaxIter, HasSeed
 from pyspark.ml.wrapper import JavaParams
@@ -224,6 +224,10 @@ def test_params(self):
         testParams.setMaxIter(100)
         self.assertTrue(testParams.isSet(maxIter))
         self.assertEqual(testParams.getMaxIter(), 100)
+        testParams.clear(maxIter)
+        self.assertFalse(testParams.isSet(maxIter))
+        self.assertEqual(testParams.getMaxIter(), 10)
+        testParams.setMaxIter(100)
 
         self.assertTrue(testParams.hasParam(inputCol.name))
         self.assertFalse(testParams.hasDefault(inputCol))
@@ -248,6 +252,18 @@ def test_params(self):
                    "maxIter: max number of iterations (>= 0). (default: 10, current: 100)",
                    "seed: random seed. (default: 41, current: 43)"]))
 
+    def test_clear_param(self):
+        df = self.spark.createDataFrame([(Vectors.dense([1.0]),), (Vectors.dense([2.0]),)], ["a"])
+        maScaler = MaxAbsScaler(inputCol="a", outputCol="scaled")
+        model = maScaler.fit(df)
+        self.assertTrue(model.isSet(model.outputCol))
+        self.assertEqual(model.getOutputCol(), "scaled")
+        model.clear(model.outputCol)
+        self.assertFalse(model.isSet(model.outputCol))
+        self.assertEqual(model.getOutputCol()[:12], 'MaxAbsScaler')
+        output = model.transform(df)
+        self.assertEqual(model.getOutputCol(), output.schema.names[1])
+
     def test_kmeans_param(self):
         algo = KMeans()
         self.assertEqual(algo.getInitMode(), "k-means||")
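The `[:12]` assertion works because, once `outputCol` is cleared, `getOutputCol()` falls back to the param's default, which for shared params like `outputCol` PySpark derives from the stage's UID, a string that begins with the class name (here `MaxAbsScaler`).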

python/pyspark/ml/wrapper.py

Lines changed: 8 additions & 0 deletions
@@ -280,6 +280,14 @@ def copy(self, extra=None):
         that._transfer_params_to_java()
         return that
 
+    def clear(self, param):
+        """
+        Clears a param from the param map if it has been explicitly set.
+        """
+        super(JavaParams, self).clear(param)
+        java_param = self._java_obj.getParam(param.name)
+        self._java_obj.clear(java_param)
+
 
 @inherit_doc
 class JavaEstimator(JavaParams, Estimator):
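The override is needed because a `JavaParams` stage mirrors its params in a wrapped JVM object: clearing only the Python-side map via `super().clear()` would leave the Java stage still holding the old value when `transform()` runs, so the corresponding Java param is cleared as well.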

python/pyspark/testing/mlutils.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def check_params(test_self, py_stage, check_params_exist=True):
             continue  # Random seeds between Spark and PySpark are different
         java_default = _java2py(test_self.sc,
                                 java_stage.clear(java_param).getOrDefault(java_param))
-        py_stage._clear(p)
+        py_stage.clear(p)
         py_default = py_stage.getOrDefault(p)
         # equality test for NaN is always False
         if isinstance(java_default, float) and np.isnan(java_default):
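`check_params` compares default values between the Scala and Python sides, so both stages are cleared before reading defaults; with `clear()` now public, the test helper no longer needs to reach into the private `_clear()`.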
