@@ -30,7 +30,7 @@ import org.apache.spark.sql.types._

/**
* :: AlphaComponent ::
* A feature transformer than merge multiple columns into a vector column.
* A feature transformer that merges multiple columns into a vector column.
*/
@AlphaComponent
class VectorAssembler extends Transformer with HasInputCols with HasOutputCol {
27 changes: 24 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -22,6 +22,7 @@ import java.util.NoSuchElementException

import scala.annotation.varargs
import scala.collection.mutable
import scala.collection.JavaConverters._

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.util.Identifiable
@@ -218,6 +219,19 @@ class BooleanParam(parent: Params, name: String, doc: String) // No need for isV
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}

/** Specialized version of [[Param[Array[T]]]] for Java. */
class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array[String] => Boolean)
extends Param[Array[String]](parent, name, doc, isValid) {

def this(parent: Params, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)

override def w(value: Array[String]): ParamPair[Array[String]] = super.w(value)

/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
}
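
A minimal usage sketch (not part of the patch) of the new param type, using VectorAssembler, whose inputCols becomes a StringArrayParam later in this diff. Scala callers keep passing an Array[String] through the existing setter, while the extra w() overload above lets Java callers, and the Python wrapper over Py4J, build the same ParamPair from a java.util.List:

import java.util.Arrays
import org.apache.spark.ml.feature.VectorAssembler

val assembler = new VectorAssembler()
assembler.setInputCols(Array("a", "b", "c"))                    // Scala callers: Array[String]
val pair = assembler.inputCols.w(Arrays.asList("a", "b", "c"))  // Java-friendly ParamPair[Array[String]]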

/**
* A param and its value.
*/
@@ -310,9 +324,7 @@ trait Params extends Identifiable with Serializable {
* Sets a parameter in the embedded param map.
*/
protected final def set[T](param: Param[T], value: T): this.type = {
shouldOwn(param)
paramMap.put(param.asInstanceOf[Param[Any]], value)
this
set(param -> value)
}

/**
@@ -322,6 +334,15 @@
set(getParam(param), value)
}

/**
* Sets a parameter in the embedded param map.
*/
protected final def set(paramPair: ParamPair[_]): this.type = {
shouldOwn(paramPair.param)
paramMap.put(paramPair)
this
}
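
With this overload in place, the two-argument set(param, value) above is just sugar for set(param -> value): both build a ParamPair and funnel through this method, which runs the shouldOwn check before writing into paramMap. A rough sketch (not from the patch) of the same pair syntax on the public API, with VectorAssembler standing in as a convenient param owner:

import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.feature.VectorAssembler

val assembler = new VectorAssembler()
// `inputCols -> Array(...)` builds the same kind of ParamPair that
// set(param, value) now forwards to the ParamPair overload internally.
val overrides = ParamMap(assembler.inputCols -> Array("a", "b"))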

/**
* Optionally returns the user-supplied value of a param.
*/
@@ -83,6 +83,7 @@ private[shared] object SharedParamsCodeGen {
case _ if c == classOf[Float] => "FloatParam"
case _ if c == classOf[Double] => "DoubleParam"
case _ if c == classOf[Boolean] => "BooleanParam"
case _ if c.isArray && c.getComponentType == classOf[String] => s"StringArrayParam"
case _ => s"Param[${getTypeString(c)}]"
}
}
@@ -178,7 +178,7 @@ private[ml] trait HasInputCols extends Params {
* Param for input column names.
* @group param
*/
final val inputCols: Param[Array[String]] = new Param[Array[String]](this, "inputCols", "input column names")
final val inputCols: StringArrayParam = new StringArrayParam(this, "inputCols", "input column names")

/** @group getParam */
final def getInputCols: Array[String] = $(inputCols)
43 changes: 41 additions & 2 deletions python/pyspark/ml/feature.py
@@ -16,12 +16,12 @@
#

from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
from pyspark.ml.param.shared import HasInputCol, HasInputCols, HasOutputCol, HasNumFeatures
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaTransformer
from pyspark.mllib.common import inherit_doc

__all__ = ['Tokenizer', 'HashingTF']
__all__ = ['Tokenizer', 'HashingTF', 'VectorAssembler']


@inherit_doc
@@ -112,6 +112,45 @@ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
return self._set(**kwargs)


@inherit_doc
class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
"""
A feature transformer that merges multiple columns into a vector column.

>>> from pyspark.sql import Row
>>> df = sc.parallelize([Row(a=1, b=0, c=3)]).toDF()
>>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
>>> vecAssembler.transform(df).head().features
SparseVector(3, {0: 1.0, 2: 3.0})
>>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
SparseVector(3, {0: 1.0, 2: 3.0})
>>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
>>> vecAssembler.transform(df, params).head().vector
SparseVector(2, {1: 1.0})
"""

_java_class = "org.apache.spark.ml.feature.VectorAssembler"

@keyword_only
def __init__(self, inputCols=None, outputCol=None):
"""
__init__(self, inputCols=None, outputCol=None)
"""
super(VectorAssembler, self).__init__()
self._setDefault()
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@keyword_only
def setParams(self, inputCols=None, outputCol=None):
"""
setParams(self, inputCols=None, outputCol=None)
Sets params for this VectorAssembler.
"""
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)


if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext
1 change: 1 addition & 0 deletions python/pyspark/ml/param/_shared_params_code_gen.py
@@ -95,6 +95,7 @@ def get$Name(self):
("predictionCol", "prediction column name", "'prediction'"),
("rawPredictionCol", "raw prediction column name", "'rawPrediction'"),
("inputCol", "input column name", None),
("inputCols", "input column names", None),
("outputCol", "output column name", None),
("numFeatures", "number of features", None)]
code = []
29 changes: 29 additions & 0 deletions python/pyspark/ml/param/shared.py
@@ -223,6 +223,35 @@ def getInputCol(self):
return self.getOrDefault(self.inputCol)


class HasInputCols(Params):
"""
Mixin for param inputCols: input column names.
"""

# a placeholder to make it appear in the generated doc
inputCols = Param(Params._dummy(), "inputCols", "input column names")

def __init__(self):
super(HasInputCols, self).__init__()
#: param for input column names
self.inputCols = Param(self, "inputCols", "input column names")
if None is not None:
self._setDefault(inputCols=None)

def setInputCols(self, value):
"""
Sets the value of :py:attr:`inputCols`.
"""
self.paramMap[self.inputCols] = value
return self

def getInputCols(self):
"""
Gets the value of inputCols or its default value.
"""
return self.getOrDefault(self.inputCols)


class HasOutputCol(Params):
"""
Mixin for param outputCol: output column name.
13 changes: 7 additions & 6 deletions python/pyspark/ml/wrapper.py
@@ -67,7 +67,9 @@ def _transfer_params_to_java(self, params, java_obj):
paramMap = self.extractParamMap(params)
for param in self.params:
if param in paramMap:
java_obj.set(param.name, paramMap[param])
value = paramMap[param]
java_param = java_obj.getParam(param.name)
java_obj.set(java_param.w(value))

def _empty_java_param_map(self):
"""
@@ -79,7 +81,8 @@ def _create_java_param_map(self, params, java_obj):
paramMap = self._empty_java_param_map()
for param, value in params.items():
if param.parent is self:
paramMap.put(java_obj.getParam(param.name), value)
java_param = java_obj.getParam(param.name)
paramMap.put(java_param.w(value))
return paramMap


@@ -126,10 +129,8 @@ class JavaTransformer(Transformer, JavaWrapper):

def transform(self, dataset, params={}):
java_obj = self._java_obj()
self._transfer_params_to_java({}, java_obj)
java_param_map = self._create_java_param_map(params, java_obj)
return DataFrame(java_obj.transform(dataset._jdf, java_param_map),
dataset.sql_ctx)
self._transfer_params_to_java(params, java_obj)
return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx)


@inherit_doc
Expand Down