Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/pyspark/ml/param/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,13 +425,13 @@ def _set(self, **kwargs):
Sets user-supplied params.
"""
for param, value in kwargs.items():
p = getattr(self, param)
if value is not None:
p = getattr(self, param)
try:
value = p.typeConverter(value)
except TypeError as e:
raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e))
self._paramMap[p] = value
self._paramMap[p] = value
return self

def _clear(self, param):
Expand Down
14 changes: 12 additions & 2 deletions python/pyspark/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,14 +247,14 @@ class TestParams(HasMaxIter, HasInputCol, HasSeed):
A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed.
"""
@keyword_only
def __init__(self, seed=None):
def __init__(self, maxIter=None, inputCol=None, seed=None):
super(TestParams, self).__init__()
self._setDefault(maxIter=10)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@keyword_only
def setParams(self, seed=None):
def setParams(self, maxIter=None, inputCol=None, seed=None):
"""
setParams(self, seed=None)
Sets params for this test.
Expand Down Expand Up @@ -389,6 +389,16 @@ def test_word2vec_param(self):
# Check windowSize is set properly
self.assertEqual(model.getWindowSize(), 6)

def test_param_value_None(self):
tp = TestParams()
self.assertFalse(tp.isSet(tp.inputCol), "inputCol is not set initially")
tp.setParams(inputCol=None)
self.assertFalse(tp.isSet(tp.inputCol), "Value of None should not change param")
tp.setParams(inputCol="input")
self.assertTrue(tp.isSet(tp.inputCol), "inputCol should now be set")
tp.setParams(inputCol=None)
self.assertTrue(tp.isSet(tp.inputCol), "inputCol should still be set")


class FeatureTests(SparkSessionTestCase):

Expand Down