diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index ade4864e1d78..5d794d0d3965 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -425,13 +425,13 @@ def _set(self, **kwargs): Sets user-supplied params. """ for param, value in kwargs.items(): - p = getattr(self, param) if value is not None: + p = getattr(self, param) try: value = p.typeConverter(value) except TypeError as e: raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e)) - self._paramMap[p] = value + self._paramMap[p] = value return self def _clear(self, param): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6886ed321ee8..87a06e945651 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -247,14 +247,14 @@ class TestParams(HasMaxIter, HasInputCol, HasSeed): A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed. """ @keyword_only - def __init__(self, seed=None): + def __init__(self, maxIter=None, inputCol=None, seed=None): super(TestParams, self).__init__() self._setDefault(maxIter=10) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only - def setParams(self, seed=None): + def setParams(self, maxIter=None, inputCol=None, seed=None): """ setParams(self, seed=None) Sets params for this test. @@ -389,6 +389,16 @@ def test_word2vec_param(self): # Check windowSize is set properly self.assertEqual(model.getWindowSize(), 6) + def test_param_value_None(self): + tp = TestParams() + self.assertFalse(tp.isSet(tp.inputCol), "inputCol is not set initially") + tp.setParams(inputCol=None) + self.assertFalse(tp.isSet(tp.inputCol), "Value of None should not change param") + tp.setParams(inputCol="input") + self.assertTrue(tp.isSet(tp.inputCol), "inputCol should now be set") + tp.setParams(inputCol=None) + self.assertTrue(tp.isSet(tp.inputCol), "inputCol should still be set") + class FeatureTests(SparkSessionTestCase):