Skip to content

Commit fce244e

Browse files
committed
update explainParams with test
1 parent 4d6b07a commit fce244e

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,13 @@
1717

1818
package org.apache.spark.ml.param
1919

20+
import org.apache.spark.ml.param.shared.{HasInputCol, HasMaxIter}
21+
2022
/** A subclass of Params for testing. */
21-
class TestParams extends Params {
23+
class TestParams extends Params with HasMaxIter with HasInputCol {
2224

23-
val maxIter = new IntParam(this, "maxIter", "max number of iterations")
2425
def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
25-
def getMaxIter: Int = getOrDefault(maxIter)
26-
27-
val inputCol = new Param[String](this, "inputCol", "input column name")
2826
def setInputCol(value: String): this.type = { set(inputCol, value); this }
29-
def getInputCol: String = getOrDefault(inputCol)
3027

3128
setDefault(maxIter -> 10)
3229

python/pyspark/ml/param/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,12 @@ def _explain(self, param):
7575
values = []
7676
if self.isDefined(param):
7777
if self.defaultParamMap.has_key(param):
78-
values += "default: %s" % self.defaultParamMap[param]
78+
values.append("default: %s" % self.defaultParamMap[param])
7979
if self.paramMap.has_key(param):
80-
values += "current: %s" % self.paramMap[param]
80+
values.append("current: %s" % self.paramMap[param])
8181
else:
82-
values += "undefined"
83-
valueStr = "(" + ",".join(values) + ")"
82+
values.append("undefined")
83+
valueStr = "(" + ", ".join(values) + ")"
8484
return "%s: %s %s" % (param.name, param.doc, valueStr)
8585

8686
def explainParams(self):

python/pyspark/ml/tests.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,10 @@ def test_params(self):
153153
with self.assertRaises(KeyError):
154154
testParams.getInputCol()
155155

156+
self.assertEquals(testParams.explainParams(),
157+
"\n".join(["inputCol: input column name (undefined)",
158+
"maxIter: max number of iterations (default: 10, current: 100)"]))
159+
156160

157161
if __name__ == "__main__":
158162
unittest.main()

0 commit comments

Comments
 (0)