Skip to content

Commit 5294500

Browse files
committed
update default param handling in python
1 parent 971b95b commit 5294500

File tree

8 files changed

+212
-109
lines changed

8 files changed

+212
-109
lines changed

python/pyspark/ml/classification.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
5959
maxIter=100, regParam=0.1)
6060
"""
6161
super(LogisticRegression, self).__init__()
62+
self._setDefault(maxIter=100, regParam=0.1)
6263
kwargs = self.__init__._input_kwargs
6364
self.setParams(**kwargs)
6465

@@ -71,7 +72,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
7172
Sets params for logistic regression.
7273
"""
7374
kwargs = self.setParams._input_kwargs
74-
return self._set_params(**kwargs)
75+
return self._set(**kwargs)
7576

7677
def _create_model(self, java_model):
7778
return LogisticRegressionModel(java_model)

python/pyspark/ml/feature.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,22 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
5252
_java_class = "org.apache.spark.ml.feature.Tokenizer"
5353

5454
@keyword_only
55-
def __init__(self, inputCol="input", outputCol="output"):
55+
def __init__(self, inputCol=None, outputCol=None):
5656
"""
57-
__init__(self, inputCol="input", outputCol="output")
57+
__init__(self, inputCol=None, outputCol=None)
5858
"""
5959
super(Tokenizer, self).__init__()
6060
kwargs = self.__init__._input_kwargs
6161
self.setParams(**kwargs)
6262

6363
@keyword_only
64-
def setParams(self, inputCol="input", outputCol="output"):
64+
def setParams(self, inputCol=None, outputCol=None):
6565
"""
6666
setParams(self, inputCol="input", outputCol="output")
6767
Sets params for this Tokenizer.
6868
"""
6969
kwargs = self.setParams._input_kwargs
70-
return self._set_params(**kwargs)
70+
return self._set(**kwargs)
7171

7272

7373
@inherit_doc
@@ -91,22 +91,23 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
9191
_java_class = "org.apache.spark.ml.feature.HashingTF"
9292

9393
@keyword_only
94-
def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"):
94+
def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
9595
"""
96-
__init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output")
96+
__init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
9797
"""
9898
super(HashingTF, self).__init__()
99+
self._setDefault(numFeatures=1 << 18)
99100
kwargs = self.__init__._input_kwargs
100101
self.setParams(**kwargs)
101102

102103
@keyword_only
103-
def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"):
104+
def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
104105
"""
105-
setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output")
106+
setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
106107
Sets params for this HashingTF.
107108
"""
108109
kwargs = self.setParams._input_kwargs
109-
return self._set_params(**kwargs)
110+
return self._set(**kwargs)
110111

111112

112113
if __name__ == "__main__":

python/pyspark/ml/param/__init__.py

Lines changed: 125 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,21 @@
2525

2626
class Param(object):
2727
"""
28-
A param with self-contained documentation and optionally default value.
28+
A param with self-contained documentation.
2929
"""
3030

31-
def __init__(self, parent, name, doc, defaultValue=None):
32-
if not isinstance(parent, Identifiable):
33-
raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__)
31+
def __init__(self, parent, name, doc):
32+
if not isinstance(parent, Params):
33+
raise ValueError("Parent must be a Params but got type %s." % type(parent).__name__)
3434
self.parent = parent
3535
self.name = str(name)
3636
self.doc = str(doc)
37-
self.defaultValue = defaultValue
3837

3938
def __str__(self):
40-
return str(self.parent) + "-" + self.name
39+
return str(self.parent) + "__" + self.name
4140

4241
def __repr__(self):
43-
return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \
44-
(self.parent, self.name, self.doc, self.defaultValue)
42+
return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc)
4543

4644

4745
class Params(Identifiable):
@@ -52,10 +50,11 @@ class Params(Identifiable):
5250

5351
__metaclass__ = ABCMeta
5452

55-
def __init__(self):
56-
super(Params, self).__init__()
57-
#: embedded param map
58-
self.paramMap = {}
53+
#: internal param map for user-supplied values param map
54+
paramMap = {}
55+
56+
#: internal param map for default values
57+
defaultParamMap = {}
5958

6059
@property
6160
def params(self):
@@ -67,11 +66,112 @@ def params(self):
6766
return filter(lambda attr: isinstance(attr, Param),
6867
[getattr(self, x) for x in dir(self) if x != "params"])
6968

70-
def _merge_params(self, params):
71-
paramMap = self.paramMap.copy()
72-
paramMap.update(params)
69+
def _explain(self, param):
70+
"""
71+
Explains a single param and returns its name, doc, and optional
72+
default value and user-supplied value in a string.
73+
"""
74+
param = self._resolveParam(param)
75+
values = []
76+
if self.isDefined(param):
77+
if self.defaultParamMap.has_key(param):
78+
values += "default: %s" % self.defaultParamMap[param]
79+
if self.paramMap.has_key(param):
80+
values += "current: %s" % self.paramMap[param]
81+
else:
82+
values += "undefined"
83+
valueStr = "(" + ",".join(values) + ")"
84+
return "%s: %s %s" % (param.name, param.doc, valueStr)
85+
86+
def explainParams(self):
87+
"""
88+
Returns the documentation of all params with their optionally
89+
default values and user-supplied values.
90+
"""
91+
return "\n".join([self._explain(param) for param in self.params])
92+
93+
def getParam(self, paramName):
94+
"""
95+
Gets a param by its name.
96+
"""
97+
param = getattr(self, paramName)
98+
if isinstance(param, Param):
99+
return param
100+
else:
101+
raise ValueError("Cannot find param with name %s." % paramName)
102+
103+
def isSet(self, param):
104+
"""
105+
Checks whether a param is explicitly set by user.
106+
"""
107+
param = self._resolveParam(param)
108+
return self.paramMap.has_key(param)
109+
110+
def hasDefault(self, param):
111+
"""
112+
Checks whether a param has a default value.
113+
"""
114+
param = self._resolveParam(param)
115+
return self.defaultParamMap.has_key(param)
116+
117+
def isDefined(self, param):
118+
"""
119+
Checks whether a param is explicitly set by user or has a default value.
120+
"""
121+
return self.isSet(param) or self.hasDefault(param)
122+
123+
def getOrDefault(self, param):
124+
"""
125+
Gets the value of a param in the user-supplied param map or its
126+
default value. Raises an error if either is set.
127+
"""
128+
if isinstance(param, Param):
129+
if self.paramMap.has_key(param):
130+
return self.paramMap[param]
131+
else:
132+
return self.defaultParamMap[param]
133+
elif isinstance(param, str):
134+
return self.getOrDefault(self.getParam(param))
135+
else:
136+
raise KeyError("Cannot recognize %r as a param." % param)
137+
138+
def extractParamMap(self, extraParamMap={}):
139+
"""
140+
Extracts the embedded default param values and user-supplied
141+
values, and then merges them with extra values from input into
142+
a flat param map, where the latter values is used if there
143+
exist conflicts, i.e., with ordering: default param values <
144+
user-supplied values < extraParamMap.
145+
:param extraParamMap: extra param values
146+
:return: merged param map
147+
"""
148+
paramMap = self.defaultParamMap.copy()
149+
paramMap.update(self.paramMap)
150+
paramMap.update(extraParamMap)
73151
return paramMap
74152

153+
def _shouldOwn(self, param):
154+
"""
155+
Validates that the input param belongs to this Params instance.
156+
"""
157+
if param.parent is not self:
158+
raise ValueError("Param %r does not belong to %r." % (param, self))
159+
160+
def _resolveParam(self, param):
161+
"""
162+
Resolves a param and validates the ownership.
163+
:param param: param name or the param instance, which must
164+
belongs to this Params instance
165+
:return: resolved param instance
166+
"""
167+
if isinstance(param, Param):
168+
self._shouldOwn(param)
169+
return param
170+
elif isinstance(param, str):
171+
return self.getParam(param)
172+
else:
173+
raise ValueError("Cannot resolve %r as a param." % param)
174+
75175
@staticmethod
76176
def _dummy():
77177
"""
@@ -81,10 +181,18 @@ def _dummy():
81181
dummy.uid = "undefined"
82182
return dummy
83183

84-
def _set_params(self, **kwargs):
184+
def _set(self, **kwargs):
85185
"""
86-
Sets params.
186+
Sets user-supplied params.
87187
"""
88188
for param, value in kwargs.iteritems():
89189
self.paramMap[getattr(self, param)] = value
90190
return self
191+
192+
def _setDefault(self, **kwargs):
193+
"""
194+
Sets default params.
195+
"""
196+
for param, value in kwargs.iteritems():
197+
self.defaultParamMap[getattr(self, param)] = value
198+
return self

python/pyspark/ml/param/_gen_shared_params.py renamed to python/pyspark/ml/param/_shared_params_code_gen.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,29 +32,33 @@
3232
# limitations under the License.
3333
#"""
3434

35+
# Code generator for shared params (shared.py). Run under this folder with:
36+
# python _shared_params_code_gen.py > shared.py
3537

36-
def _gen_param_code(name, doc, defaultValue):
38+
def _gen_param_code(name, doc, defaultValueStr):
3739
"""
3840
Generates Python code for a shared param class.
3941
4042
:param name: param name
4143
:param doc: param doc
42-
:param defaultValue: string representation of the param
44+
:param defaultValueStr: string representation of the default value
4345
:return: code string
4446
"""
4547
# TODO: How to correctly inherit instance attributes?
4648
template = '''class Has$Name(Params):
4749
"""
48-
Params with $name.
50+
Mixin for param $name: $doc.
4951
"""
5052
5153
# a placeholder to make it appear in the generated doc
52-
$name = Param(Params._dummy(), "$name", "$doc", $defaultValue)
54+
$name = Param(Params._dummy(), "$name", "$doc")
5355
5456
def __init__(self):
5557
super(Has$Name, self).__init__()
5658
#: param for $doc
57-
self.$name = Param(self, "$name", "$doc", $defaultValue)
59+
self.$name = Param(self, "$name", "$doc")
60+
if $defaultValueStr is not None:
61+
self._setDefault($name=$defaultValueStr)
5862
5963
def set$Name(self, value):
6064
"""
@@ -67,32 +71,29 @@ def get$Name(self):
6771
"""
6872
Gets the value of $name or its default value.
6973
"""
70-
if self.$name in self.paramMap:
71-
return self.paramMap[self.$name]
72-
else:
73-
return self.$name.defaultValue'''
74+
return self.getOrDefault(self.$name)'''
7475

75-
upperCamelName = name[0].upper() + name[1:]
76+
Name = name[0].upper() + name[1:]
7677
return template \
7778
.replace("$name", name) \
78-
.replace("$Name", upperCamelName) \
79+
.replace("$Name", Name) \
7980
.replace("$doc", doc) \
80-
.replace("$defaultValue", defaultValue)
81+
.replace("$defaultValueStr", str(defaultValueStr))
8182

8283
if __name__ == "__main__":
8384
print header
84-
print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n"
85+
print "\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n"
8586
print "from pyspark.ml.param import Param, Params\n\n"
8687
shared = [
87-
("maxIter", "max number of iterations", "100"),
88-
("regParam", "regularization constant", "0.1"),
88+
("maxIter", "max number of iterations", None),
89+
("regParam", "regularization constant", None),
8990
("featuresCol", "features column name", "'features'"),
9091
("labelCol", "label column name", "'label'"),
9192
("predictionCol", "prediction column name", "'prediction'"),
92-
("inputCol", "input column name", "'input'"),
93-
("outputCol", "output column name", "'output'"),
94-
("numFeatures", "number of features", "1 << 18")]
93+
("inputCol", "input column name", None),
94+
("outputCol", "output column name", None),
95+
("numFeatures", "number of features", None)]
9596
code = []
96-
for name, doc, defaultValue in shared:
97-
code.append(_gen_param_code(name, doc, defaultValue))
97+
for name, doc, defaultValueStr in shared:
98+
code.append(_gen_param_code(name, doc, defaultValueStr))
9899
print "\n\n\n".join(code)

0 commit comments

Comments
 (0)