@@ -48,7 +48,7 @@ final class Bucketizer private[ml] (override val parent: Estimator[Bucketizer])
* otherwise, values outside the splits specified will be treated as errors.
* @group param
*/
-val splits: Param[Array[Double]] = new Param[Array[Double]](this, "splits",
+val splits: DoubleArrayParam = new DoubleArrayParam(this, "splits",
"Split points for mapping continuous features into buckets. With n+1 splits, there are n " +
"buckets. A bucket defined by splits x,y holds values in the range [x,y) except the last " +
"bucket, which also includes y. The splits should be strictly increasing. " +
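For review context, here is a minimal, self-contained sketch of the rule the `splits` documentation above describes: a value v lands in bucket i when splits[i] <= v < splits[i+1], with the last bucket also admitting its upper bound. The helper name `find_bucket` is invented for illustration and is not part of this patch.

def find_bucket(v, splits):
    # At least two split points are needed to define one bucket.
    if len(splits) < 2:
        raise ValueError("need at least two split points")
    # The last bucket is closed on the right: it also includes splits[-1].
    if v == splits[-1]:
        return len(splits) - 2
    for i in range(len(splits) - 1):
        if splits[i] <= v < splits[i + 1]:
            return i
    # Mirrors the documented behavior: out-of-range values are errors.
    raise ValueError("value %r is outside the provided splits" % (v,))

splits = [float("-inf"), 0.5, 1.4, float("inf")]
print([find_bucket(v, splits) for v in [0.1, 0.4, 1.2, 1.5]])  # [0, 0, 1, 2]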
15 changes: 14 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -219,7 +219,7 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}

-/** Specialized version of [[Param[Array[T]]]] for Java. */
+/** Specialized version of [[Param[Array[String]]]] for Java. */
class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean)
extends Param[Array[String]](parent, name, doc, isValid) {

@@ -232,6 +232,19 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
}

/** Specialized version of [[Param[Array[Double]]]] for Java. */
class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array[Double] => Boolean)
extends Param[Array[Double]](parent, name, doc, isValid) {

def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)

override def w(value: Array[Double]): ParamPair[Array[Double]] = super.w(value)

/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[Double]): ParamPair[Array[Double]] = w(value.asScala.toArray)
}

/**
* A param and its value.
*/
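The java.util.List overload added above is the bridge the Python API below relies on as well: Py4J ships a Python list of floats to the JVM as a java.util.List of boxed Doubles, which this w method then converts into the Array[Double] the underlying Param expects. A hedged sketch of the Python side of that handoff, assuming a live SparkContext named `sc` (illustration only, not part of this patch):

from py4j.java_collections import ListConverter

splits = [float("-inf"), 0.5, 1.4, float("inf")]
# Py4J converts the Python list into a java.util.List of boxed Doubles;
# DoubleArrayParam.w(java.util.List[Double]) accepts exactly that shape.
java_list = ListConverter().convert(splits, sc._gateway._gateway_client)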
77 changes: 77 additions & 0 deletions python/pyspark/ml/feature.py
@@ -83,6 +83,83 @@ def getThreshold(self):
return self.getOrDefault(self.threshold)


@inherit_doc
class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol):
"""
Maps a column of continuous features to a column of feature buckets.

>>> df = sqlContext.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
>>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
... inputCol="values", outputCol="buckets")
>>> bucketed = bucketizer.transform(df).collect()
>>> bucketed[0].buckets
0.0
>>> bucketed[1].buckets
0.0
>>> bucketed[2].buckets
1.0
>>> bucketed[3].buckets
2.0
>>> bucketizer.setParams(outputCol="b").transform(df).head().b
0.0
"""

_java_class = "org.apache.spark.ml.feature.Bucketizer"
# a placeholder to make it appear in the generated doc
splits = \
Param(Params._dummy(), "splits",
"Split points for mapping continuous features into buckets. With n+1 splits, " +
"there are n buckets. A bucket defined by splits x,y holds values in the " +
"range [x,y) except the last bucket, which also includes y. The splits " +
"should be strictly increasing. Values at -inf, inf must be explicitly " +
"provided to cover all Double values; otherwise, values outside the splits " +
"specified will be treated as errors.")

@keyword_only
def __init__(self, splits=None, inputCol=None, outputCol=None):
"""
__init__(self, splits=None, inputCol=None, outputCol=None)
"""
super(Bucketizer, self).__init__()
#: param for split points for mapping continuous features into buckets. With n+1 splits,
# there are n buckets. A bucket defined by splits x,y holds values in the range [x,y)
# except the last bucket, which also includes y. The splits should be strictly increasing.
# Values at -inf, inf must be explicitly provided to cover all Double values; otherwise,
# values outside the splits specified will be treated as errors.
self.splits = \
Param(self, "splits",
"Split points for mapping continuous features into buckets. With n+1 splits, " +
"there are n buckets. A bucket defined by splits x,y holds values in the " +
"range [x,y) except the last bucket, which also includes y. The splits " +
"should be strictly increasing. Values at -inf, inf must be explicitly " +
"provided to cover all Double values; otherwise, values outside the splits " +
"specified will be treated as errors.")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@keyword_only
def setParams(self, splits=None, inputCol=None, outputCol=None):
"""
setParams(self, splits=None, inputCol=None, outputCol=None)
Sets params for this Bucketizer.
"""
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)

def setSplits(self, value):
"""
Sets the value of :py:attr:`splits`.
"""
self.paramMap[self.splits] = value
return self

def getSplits(self):
"""
Gets the value of splits or its default value.
"""
return self.getOrDefault(self.splits)


@inherit_doc
class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
"""
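A note on the `_input_kwargs` lookups in Bucketizer's __init__ and setParams above: they depend on pyspark's @keyword_only decorator stashing the caller's keyword arguments on the wrapper function. A simplified sketch of that pattern (not pyspark's exact source):

from functools import wraps

def keyword_only(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        # args[0] is self; any further positional argument is rejected.
        if len(args) > 1:
            raise TypeError("Method %s only takes keyword arguments." % func.__name__)
        # Stash the kwargs where the wrapped method can read them back.
        wrapper._input_kwargs = kwargs
        return func(*args, **kwargs)
    return wrapper

class Example(object):
    @keyword_only
    def setParams(self, splits=None, inputCol=None, outputCol=None):
        # Attribute access on the bound method falls through to the wrapper.
        return self.setParams._input_kwargs

print(Example().setParams(inputCol="values"))  # {'inputCol': 'values'}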