Skip to content

Commit 57cd1e8

Browse files
committed
[SPARK-6893][ML] default pipeline parameter handling in python
Same as apache#5431 but for Python. jkbradley Author: Xiangrui Meng <[email protected]> Closes apache#5534 from mengxr/SPARK-6893 and squashes the following commits: d3b519b [Xiangrui Meng] address comments ebaccc6 [Xiangrui Meng] style update fce244e [Xiangrui Meng] update explainParams with test 4d6b07a [Xiangrui Meng] add tests 5294500 [Xiangrui Meng] update default param handling in python
1 parent 52c3439 commit 57cd1e8

File tree

11 files changed

+270
-121
lines changed

11 files changed

+270
-121
lines changed

mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import java.util.UUID
2525
private[ml] trait Identifiable extends Serializable {
2626

2727
/**
28-
* A unique id for the object. The default implementation concatenates the class name, "-", and 8
28+
* A unique id for the object. The default implementation concatenates the class name, "_", and 8
2929
* random hex chars.
3030
*/
3131
private[ml] val uid: String =

mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,13 @@
1717

1818
package org.apache.spark.ml.param
1919

20+
import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter}
21+
2022
/** A subclass of Params for testing. */
21-
class TestParams extends Params {
23+
class TestParams extends Params with HasMaxIter with HasInputCol {
2224

23-
val maxIter = new IntParam(this, "maxIter", "max number of iterations")
2425
def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
25-
def getMaxIter: Int = getOrDefault(maxIter)
26-
27-
val inputCol = new Param[String](this, "inputCol", "input column name")
2826
def setInputCol(value: String): this.type = { set(inputCol, value); this }
29-
def getInputCol: String = getOrDefault(inputCol)
3027

3128
setDefault(maxIter -> 10)
3229

python/pyspark/ml/classification.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
5959
maxIter=100, regParam=0.1)
6060
"""
6161
super(LogisticRegression, self).__init__()
62+
self._setDefault(maxIter=100, regParam=0.1)
6263
kwargs = self.__init__._input_kwargs
6364
self.setParams(**kwargs)
6465

@@ -71,7 +72,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
7172
Sets params for logistic regression.
7273
"""
7374
kwargs = self.setParams._input_kwargs
74-
return self._set_params(**kwargs)
75+
return self._set(**kwargs)
7576

7677
def _create_model(self, java_model):
7778
return LogisticRegressionModel(java_model)

python/pyspark/ml/feature.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,22 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
5252
_java_class = "org.apache.spark.ml.feature.Tokenizer"
5353

5454
@keyword_only
55-
def __init__(self, inputCol="input", outputCol="output"):
55+
def __init__(self, inputCol=None, outputCol=None):
5656
"""
57-
__init__(self, inputCol="input", outputCol="output")
57+
__init__(self, inputCol=None, outputCol=None)
5858
"""
5959
super(Tokenizer, self).__init__()
6060
kwargs = self.__init__._input_kwargs
6161
self.setParams(**kwargs)
6262

6363
@keyword_only
64-
def setParams(self, inputCol="input", outputCol="output"):
64+
def setParams(self, inputCol=None, outputCol=None):
6565
"""
6666
setParams(self, inputCol=None, outputCol=None)
6767
Sets params for this Tokenizer.
6868
"""
6969
kwargs = self.setParams._input_kwargs
70-
return self._set_params(**kwargs)
70+
return self._set(**kwargs)
7171

7272

7373
@inherit_doc
@@ -91,22 +91,23 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
9191
_java_class = "org.apache.spark.ml.feature.HashingTF"
9292

9393
@keyword_only
94-
def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"):
94+
def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
9595
"""
96-
__init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output")
96+
__init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
9797
"""
9898
super(HashingTF, self).__init__()
99+
self._setDefault(numFeatures=1 << 18)
99100
kwargs = self.__init__._input_kwargs
100101
self.setParams(**kwargs)
101102

102103
@keyword_only
103-
def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"):
104+
def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
104105
"""
105-
setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output")
106+
setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
106107
Sets params for this HashingTF.
107108
"""
108109
kwargs = self.setParams._input_kwargs
109-
return self._set_params(**kwargs)
110+
return self._set(**kwargs)
110111

111112

112113
if __name__ == "__main__":

python/pyspark/ml/param/__init__.py

Lines changed: 127 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,21 @@
2525

2626
class Param(object):
2727
"""
28-
A param with self-contained documentation and optionally default value.
28+
A param with self-contained documentation.
2929
"""
3030

31-
def __init__(self, parent, name, doc, defaultValue=None):
32-
if not isinstance(parent, Identifiable):
33-
raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__)
31+
def __init__(self, parent, name, doc):
32+
if not isinstance(parent, Params):
33+
raise ValueError("Parent must be a Params but got type %s." % type(parent).__name__)
3434
self.parent = parent
3535
self.name = str(name)
3636
self.doc = str(doc)
37-
self.defaultValue = defaultValue
3837

3938
def __str__(self):
40-
return str(self.parent) + "-" + self.name
39+
return str(self.parent) + "__" + self.name
4140

4241
def __repr__(self):
43-
return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \
44-
(self.parent, self.name, self.doc, self.defaultValue)
42+
return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc)
4543

4644

4745
class Params(Identifiable):
@@ -52,26 +50,128 @@ class Params(Identifiable):
5250

5351
__metaclass__ = ABCMeta
5452

55-
def __init__(self):
56-
super(Params, self).__init__()
57-
#: embedded param map
58-
self.paramMap = {}
53+
#: internal param map for user-supplied values param map
54+
paramMap = {}
55+
56+
#: internal param map for default values
57+
defaultParamMap = {}
5958

6059
@property
6160
def params(self):
6261
"""
63-
Returns all params. The default implementation uses
64-
:py:func:`dir` to get all attributes of type
62+
Returns all params ordered by name. The default implementation
63+
uses :py:func:`dir` to get all attributes of type
6564
:py:class:`Param`.
6665
"""
6766
return filter(lambda attr: isinstance(attr, Param),
6867
[getattr(self, x) for x in dir(self) if x != "params"])
6968

70-
def _merge_params(self, params):
71-
paramMap = self.paramMap.copy()
72-
paramMap.update(params)
69+
def _explain(self, param):
70+
"""
71+
Explains a single param and returns its name, doc, and optional
72+
default value and user-supplied value in a string.
73+
"""
74+
param = self._resolveParam(param)
75+
values = []
76+
if self.isDefined(param):
77+
if param in self.defaultParamMap:
78+
values.append("default: %s" % self.defaultParamMap[param])
79+
if param in self.paramMap:
80+
values.append("current: %s" % self.paramMap[param])
81+
else:
82+
values.append("undefined")
83+
valueStr = "(" + ", ".join(values) + ")"
84+
return "%s: %s %s" % (param.name, param.doc, valueStr)
85+
86+
def explainParams(self):
87+
"""
88+
Returns the documentation of all params with their optional
89+
default values and user-supplied values.
90+
"""
91+
return "\n".join([self._explain(param) for param in self.params])
92+
93+
def getParam(self, paramName):
94+
"""
95+
Gets a param by its name.
96+
"""
97+
param = getattr(self, paramName)
98+
if isinstance(param, Param):
99+
return param
100+
else:
101+
raise ValueError("Cannot find param with name %s." % paramName)
102+
103+
def isSet(self, param):
104+
"""
105+
Checks whether a param is explicitly set by user.
106+
"""
107+
param = self._resolveParam(param)
108+
return param in self.paramMap
109+
110+
def hasDefault(self, param):
111+
"""
112+
Checks whether a param has a default value.
113+
"""
114+
param = self._resolveParam(param)
115+
return param in self.defaultParamMap
116+
117+
def isDefined(self, param):
118+
"""
119+
Checks whether a param is explicitly set by user or has a default value.
120+
"""
121+
return self.isSet(param) or self.hasDefault(param)
122+
123+
def getOrDefault(self, param):
124+
"""
125+
Gets the value of a param in the user-supplied param map or its
126+
default value. Raises an error if neither is set.
127+
"""
128+
if isinstance(param, Param):
129+
if param in self.paramMap:
130+
return self.paramMap[param]
131+
else:
132+
return self.defaultParamMap[param]
133+
elif isinstance(param, str):
134+
return self.getOrDefault(self.getParam(param))
135+
else:
136+
raise KeyError("Cannot recognize %r as a param." % param)
137+
138+
def extractParamMap(self, extraParamMap={}):
139+
"""
140+
Extracts the embedded default param values and user-supplied
141+
values, and then merges them with extra values from input into
142+
a flat param map, where the latter value is used if there exist
143+
conflicts, i.e., with ordering: default param values <
144+
user-supplied values < extraParamMap.
145+
:param extraParamMap: extra param values
146+
:return: merged param map
147+
"""
148+
paramMap = self.defaultParamMap.copy()
149+
paramMap.update(self.paramMap)
150+
paramMap.update(extraParamMap)
73151
return paramMap
74152

153+
def _shouldOwn(self, param):
154+
"""
155+
Validates that the input param belongs to this Params instance.
156+
"""
157+
if param.parent is not self:
158+
raise ValueError("Param %r does not belong to %r." % (param, self))
159+
160+
def _resolveParam(self, param):
161+
"""
162+
Resolves a param and validates the ownership.
163+
:param param: param name or the param instance, which must
164+
belong to this Params instance
165+
:return: resolved param instance
166+
"""
167+
if isinstance(param, Param):
168+
self._shouldOwn(param)
169+
return param
170+
elif isinstance(param, str):
171+
return self.getParam(param)
172+
else:
173+
raise ValueError("Cannot resolve %r as a param." % param)
174+
75175
@staticmethod
76176
def _dummy():
77177
"""
@@ -81,10 +181,18 @@ def _dummy():
81181
dummy.uid = "undefined"
82182
return dummy
83183

84-
def _set_params(self, **kwargs):
184+
def _set(self, **kwargs):
85185
"""
86-
Sets params.
186+
Sets user-supplied params.
87187
"""
88188
for param, value in kwargs.iteritems():
89189
self.paramMap[getattr(self, param)] = value
90190
return self
191+
192+
def _setDefault(self, **kwargs):
193+
"""
194+
Sets default params.
195+
"""
196+
for param, value in kwargs.iteritems():
197+
self.defaultParamMap[getattr(self, param)] = value
198+
return self

python/pyspark/ml/param/_gen_shared_params.py renamed to python/pyspark/ml/param/_shared_params_code_gen.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,29 +32,34 @@
3232
# limitations under the License.
3333
#"""
3434

35+
# Code generator for shared params (shared.py). Run under this folder with:
36+
# python _shared_params_code_gen.py > shared.py
3537

36-
def _gen_param_code(name, doc, defaultValue):
38+
39+
def _gen_param_code(name, doc, defaultValueStr):
3740
"""
3841
Generates Python code for a shared param class.
3942
4043
:param name: param name
4144
:param doc: param doc
42-
:param defaultValue: string representation of the param
45+
:param defaultValueStr: string representation of the default value
4346
:return: code string
4447
"""
4548
# TODO: How to correctly inherit instance attributes?
4649
template = '''class Has$Name(Params):
4750
"""
48-
Params with $name.
51+
Mixin for param $name: $doc.
4952
"""
5053
5154
# a placeholder to make it appear in the generated doc
52-
$name = Param(Params._dummy(), "$name", "$doc", $defaultValue)
55+
$name = Param(Params._dummy(), "$name", "$doc")
5356
5457
def __init__(self):
5558
super(Has$Name, self).__init__()
5659
#: param for $doc
57-
self.$name = Param(self, "$name", "$doc", $defaultValue)
60+
self.$name = Param(self, "$name", "$doc")
61+
if $defaultValueStr is not None:
62+
self._setDefault($name=$defaultValueStr)
5863
5964
def set$Name(self, value):
6065
"""
@@ -67,32 +72,29 @@ def get$Name(self):
6772
"""
6873
Gets the value of $name or its default value.
6974
"""
70-
if self.$name in self.paramMap:
71-
return self.paramMap[self.$name]
72-
else:
73-
return self.$name.defaultValue'''
75+
return self.getOrDefault(self.$name)'''
7476

75-
upperCamelName = name[0].upper() + name[1:]
77+
Name = name[0].upper() + name[1:]
7678
return template \
7779
.replace("$name", name) \
78-
.replace("$Name", upperCamelName) \
80+
.replace("$Name", Name) \
7981
.replace("$doc", doc) \
80-
.replace("$defaultValue", defaultValue)
82+
.replace("$defaultValueStr", str(defaultValueStr))
8183

8284
if __name__ == "__main__":
8385
print header
84-
print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n"
86+
print "\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n"
8587
print "from pyspark.ml.param import Param, Params\n\n"
8688
shared = [
87-
("maxIter", "max number of iterations", "100"),
88-
("regParam", "regularization constant", "0.1"),
89+
("maxIter", "max number of iterations", None),
90+
("regParam", "regularization constant", None),
8991
("featuresCol", "features column name", "'features'"),
9092
("labelCol", "label column name", "'label'"),
9193
("predictionCol", "prediction column name", "'prediction'"),
92-
("inputCol", "input column name", "'input'"),
93-
("outputCol", "output column name", "'output'"),
94-
("numFeatures", "number of features", "1 << 18")]
94+
("inputCol", "input column name", None),
95+
("outputCol", "output column name", None),
96+
("numFeatures", "number of features", None)]
9597
code = []
96-
for name, doc, defaultValue in shared:
97-
code.append(_gen_param_code(name, doc, defaultValue))
98+
for name, doc, defaultValueStr in shared:
99+
code.append(_gen_param_code(name, doc, defaultValueStr))
98100
print "\n\n\n".join(code)

0 commit comments

Comments
 (0)