Skip to content

Commit 8659084

Browse files
committed
Merge branch 'master' of https://github.com/apache/spark into SPARK-6694
2 parents 0301eb9 + 57cd1e8 commit 8659084

File tree

12 files changed

+280
-126
lines changed

12 files changed

+280
-126
lines changed

mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import java.util.UUID
2525
private[ml] trait Identifiable extends Serializable {
2626

2727
/**
28-
* A unique id for the object. The default implementation concatenates the class name, "-", and 8
28+
* A unique id for the object. The default implementation concatenates the class name, "_", and 8
2929
* random hex chars.
3030
*/
3131
private[ml] val uid: String =

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -227,15 +227,16 @@ object Vectors {
227227
* @param elements vector elements in (index, value) pairs.
228228
*/
229229
def sparse(size: Int, elements: Seq[(Int, Double)]): Vector = {
230-
require(size > 0)
230+
require(size > 0, "The size of the requested sparse vector must be greater than 0.")
231231

232232
val (indices, values) = elements.sortBy(_._1).unzip
233233
var prev = -1
234234
indices.foreach { i =>
235235
require(prev < i, s"Found duplicate indices: $i.")
236236
prev = i
237237
}
238-
require(prev < size)
238+
require(prev < size, s"You may not write an element to index $prev because the declared " +
239+
s"size of your vector is $size")
239240

240241
new SparseVector(size, indices.toArray, values.toArray)
241242
}
@@ -309,7 +310,8 @@ object Vectors {
309310
* @return norm in L^p^ space.
310311
*/
311312
def norm(vector: Vector, p: Double): Double = {
312-
require(p >= 1.0)
313+
require(p >= 1.0, "To compute the p-norm of the vector, we require that you specify a p>=1. " +
314+
s"You specified p=$p.")
313315
val values = vector match {
314316
case DenseVector(vs) => vs
315317
case SparseVector(n, ids, vs) => vs
@@ -360,7 +362,8 @@ object Vectors {
360362
* @return squared distance between two Vectors.
361363
*/
362364
def sqdist(v1: Vector, v2: Vector): Double = {
363-
require(v1.size == v2.size, "vector dimension mismatch")
365+
require(v1.size == v2.size, s"Vector dimensions do not match: Dim(v1)=${v1.size} and Dim(v2)" +
366+
s"=${v2.size}.")
364367
var squaredDistance = 0.0
365368
(v1, v2) match {
366369
case (v1: SparseVector, v2: SparseVector) =>
@@ -518,7 +521,9 @@ class SparseVector(
518521
val indices: Array[Int],
519522
val values: Array[Double]) extends Vector {
520523

521-
require(indices.length == values.length)
524+
require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
525+
s" indices match the dimension of the values. You provided ${indices.size} indices and " +
526+
s" ${values.size} values.")
522527

523528
override def toString: String =
524529
"(%s,%s,%s)".format(size, indices.mkString("[", ",", "]"), values.mkString("[", ",", "]"))

mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,13 @@
1717

1818
package org.apache.spark.ml.param
1919

20+
import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter}
21+
2022
/** A subclass of Params for testing. */
21-
class TestParams extends Params {
23+
class TestParams extends Params with HasMaxIter with HasInputCol {
2224

23-
val maxIter = new IntParam(this, "maxIter", "max number of iterations")
2425
def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
25-
def getMaxIter: Int = getOrDefault(maxIter)
26-
27-
val inputCol = new Param[String](this, "inputCol", "input column name")
2826
def setInputCol(value: String): this.type = { set(inputCol, value); this }
29-
def getInputCol: String = getOrDefault(inputCol)
3027

3128
setDefault(maxIter -> 10)
3229

python/pyspark/ml/classification.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
5959
maxIter=100, regParam=0.1)
6060
"""
6161
super(LogisticRegression, self).__init__()
62+
self._setDefault(maxIter=100, regParam=0.1)
6263
kwargs = self.__init__._input_kwargs
6364
self.setParams(**kwargs)
6465

@@ -71,7 +72,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
7172
Sets params for logistic regression.
7273
"""
7374
kwargs = self.setParams._input_kwargs
74-
return self._set_params(**kwargs)
75+
return self._set(**kwargs)
7576

7677
def _create_model(self, java_model):
7778
return LogisticRegressionModel(java_model)

python/pyspark/ml/feature.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,22 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
5252
_java_class = "org.apache.spark.ml.feature.Tokenizer"
5353

5454
@keyword_only
55-
def __init__(self, inputCol="input", outputCol="output"):
55+
def __init__(self, inputCol=None, outputCol=None):
5656
"""
57-
__init__(self, inputCol="input", outputCol="output")
57+
__init__(self, inputCol=None, outputCol=None)
5858
"""
5959
super(Tokenizer, self).__init__()
6060
kwargs = self.__init__._input_kwargs
6161
self.setParams(**kwargs)
6262

6363
@keyword_only
64-
def setParams(self, inputCol="input", outputCol="output"):
64+
def setParams(self, inputCol=None, outputCol=None):
6565
"""
6666
setParams(self, inputCol="input", outputCol="output")
6767
Sets params for this Tokenizer.
6868
"""
6969
kwargs = self.setParams._input_kwargs
70-
return self._set_params(**kwargs)
70+
return self._set(**kwargs)
7171

7272

7373
@inherit_doc
@@ -91,22 +91,23 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
9191
_java_class = "org.apache.spark.ml.feature.HashingTF"
9292

9393
@keyword_only
94-
def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"):
94+
def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
9595
"""
96-
__init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output")
96+
__init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
9797
"""
9898
super(HashingTF, self).__init__()
99+
self._setDefault(numFeatures=1 << 18)
99100
kwargs = self.__init__._input_kwargs
100101
self.setParams(**kwargs)
101102

102103
@keyword_only
103-
def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"):
104+
def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
104105
"""
105-
setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output")
106+
setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
106107
Sets params for this HashingTF.
107108
"""
108109
kwargs = self.setParams._input_kwargs
109-
return self._set_params(**kwargs)
110+
return self._set(**kwargs)
110111

111112

112113
if __name__ == "__main__":

python/pyspark/ml/param/__init__.py

Lines changed: 127 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,21 @@
2525

2626
class Param(object):
2727
"""
28-
A param with self-contained documentation and optionally default value.
28+
A param with self-contained documentation.
2929
"""
3030

31-
def __init__(self, parent, name, doc, defaultValue=None):
32-
if not isinstance(parent, Identifiable):
33-
raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__)
31+
def __init__(self, parent, name, doc):
32+
if not isinstance(parent, Params):
33+
raise ValueError("Parent must be a Params but got type %s." % type(parent).__name__)
3434
self.parent = parent
3535
self.name = str(name)
3636
self.doc = str(doc)
37-
self.defaultValue = defaultValue
3837

3938
def __str__(self):
40-
return str(self.parent) + "-" + self.name
39+
return str(self.parent) + "__" + self.name
4140

4241
def __repr__(self):
43-
return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \
44-
(self.parent, self.name, self.doc, self.defaultValue)
42+
return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc)
4543

4644

4745
class Params(Identifiable):
@@ -52,26 +50,128 @@ class Params(Identifiable):
5250

5351
__metaclass__ = ABCMeta
5452

55-
def __init__(self):
56-
super(Params, self).__init__()
57-
#: embedded param map
58-
self.paramMap = {}
53+
#: internal param map for user-supplied values
54+
paramMap = {}
55+
56+
#: internal param map for default values
57+
defaultParamMap = {}
5958

6059
@property
6160
def params(self):
6261
"""
63-
Returns all params. The default implementation uses
64-
:py:func:`dir` to get all attributes of type
62+
Returns all params ordered by name. The default implementation
63+
uses :py:func:`dir` to get all attributes of type
6564
:py:class:`Param`.
6665
"""
6766
return filter(lambda attr: isinstance(attr, Param),
6867
[getattr(self, x) for x in dir(self) if x != "params"])
6968

70-
def _merge_params(self, params):
71-
paramMap = self.paramMap.copy()
72-
paramMap.update(params)
69+
def _explain(self, param):
70+
"""
71+
Explains a single param and returns its name, doc, and optional
72+
default value and user-supplied value in a string.
73+
"""
74+
param = self._resolveParam(param)
75+
values = []
76+
if self.isDefined(param):
77+
if param in self.defaultParamMap:
78+
values.append("default: %s" % self.defaultParamMap[param])
79+
if param in self.paramMap:
80+
values.append("current: %s" % self.paramMap[param])
81+
else:
82+
values.append("undefined")
83+
valueStr = "(" + ", ".join(values) + ")"
84+
return "%s: %s %s" % (param.name, param.doc, valueStr)
85+
86+
def explainParams(self):
87+
"""
88+
Returns the documentation of all params with their optionally
89+
default values and user-supplied values.
90+
"""
91+
return "\n".join([self._explain(param) for param in self.params])
92+
93+
def getParam(self, paramName):
94+
"""
95+
Gets a param by its name.
96+
"""
97+
param = getattr(self, paramName)
98+
if isinstance(param, Param):
99+
return param
100+
else:
101+
raise ValueError("Cannot find param with name %s." % paramName)
102+
103+
def isSet(self, param):
104+
"""
105+
Checks whether a param is explicitly set by user.
106+
"""
107+
param = self._resolveParam(param)
108+
return param in self.paramMap
109+
110+
def hasDefault(self, param):
111+
"""
112+
Checks whether a param has a default value.
113+
"""
114+
param = self._resolveParam(param)
115+
return param in self.defaultParamMap
116+
117+
def isDefined(self, param):
118+
"""
119+
Checks whether a param is explicitly set by user or has a default value.
120+
"""
121+
return self.isSet(param) or self.hasDefault(param)
122+
123+
def getOrDefault(self, param):
124+
"""
125+
Gets the value of a param in the user-supplied param map or its
126+
default value. Raises an error if either is set.
127+
"""
128+
if isinstance(param, Param):
129+
if param in self.paramMap:
130+
return self.paramMap[param]
131+
else:
132+
return self.defaultParamMap[param]
133+
elif isinstance(param, str):
134+
return self.getOrDefault(self.getParam(param))
135+
else:
136+
raise KeyError("Cannot recognize %r as a param." % param)
137+
138+
def extractParamMap(self, extraParamMap={}):
139+
"""
140+
Extracts the embedded default param values and user-supplied
141+
values, and then merges them with extra values from input into
142+
a flat param map, where the latter value is used if there exist
143+
conflicts, i.e., with ordering: default param values <
144+
user-supplied values < extraParamMap.
145+
:param extraParamMap: extra param values
146+
:return: merged param map
147+
"""
148+
paramMap = self.defaultParamMap.copy()
149+
paramMap.update(self.paramMap)
150+
paramMap.update(extraParamMap)
73151
return paramMap
74152

153+
def _shouldOwn(self, param):
154+
"""
155+
Validates that the input param belongs to this Params instance.
156+
"""
157+
if param.parent is not self:
158+
raise ValueError("Param %r does not belong to %r." % (param, self))
159+
160+
def _resolveParam(self, param):
161+
"""
162+
Resolves a param and validates the ownership.
163+
:param param: param name or the param instance, which must
164+
belong to this Params instance
165+
:return: resolved param instance
166+
"""
167+
if isinstance(param, Param):
168+
self._shouldOwn(param)
169+
return param
170+
elif isinstance(param, str):
171+
return self.getParam(param)
172+
else:
173+
raise ValueError("Cannot resolve %r as a param." % param)
174+
75175
@staticmethod
76176
def _dummy():
77177
"""
@@ -81,10 +181,18 @@ def _dummy():
81181
dummy.uid = "undefined"
82182
return dummy
83183

84-
def _set_params(self, **kwargs):
184+
def _set(self, **kwargs):
85185
"""
86-
Sets params.
186+
Sets user-supplied params.
87187
"""
88188
for param, value in kwargs.iteritems():
89189
self.paramMap[getattr(self, param)] = value
90190
return self
191+
192+
def _setDefault(self, **kwargs):
193+
"""
194+
Sets default params.
195+
"""
196+
for param, value in kwargs.iteritems():
197+
self.defaultParamMap[getattr(self, param)] = value
198+
return self

0 commit comments

Comments
 (0)