diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 032f8ec68b9d0..72985b40b7872 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -56,14 +56,3 @@ test_that("feature interaction vs native glm", { rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) - -test_that("summary coefficients match with native glm", { - training <- createDataFrame(sqlContext, iris) - stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) - coefs <- as.vector(stats$coefficients) - rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))) - expect_true(all(abs(rCoefs - coefs) < 1e-6)) - expect_true(all( - as.character(stats$features) == - c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) -}) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 486274cd75a14..e0056dc70849e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -25,32 +25,56 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.Transformer import org.apache.spark.ml.util.Identifiable -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap +import scala.collection.mutable + /** * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol - with HasHandleInvalid { + with HasHandleInvalid with HasInputCols with HasOutputCols { /** Validates and transforms the input schema. 
 */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    val inputColName = $(inputCol)
-    val inputDataType = schema(inputColName).dataType
-    require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
-      s"The input column $inputColName must be either string type or numeric type, " +
-        s"but got $inputDataType.")
-    val inputFields = schema.fields
-    val outputColName = $(outputCol)
-    require(inputFields.forall(_.name != outputColName),
-      s"Output column $outputColName already exists.")
-    val attr = NominalAttribute.defaultAttr.withName($(outputCol))
-    val outputFields = inputFields :+ attr.toStructField()
+    val inputColNames = $(inputCols)
+    val outputColNames = $(outputCols)
+    val inputDataTypes = inputColNames.map(name => schema(name).dataType)
+    inputDataTypes.foreach {
+      case _: NumericType | StringType =>
+      case other =>
+        throw new IllegalArgumentException("The input columns must be either string type " +
+          s"or numeric type, but got $other.")
+    }
+    val originalFields = schema.fields
+    val originalColNames = originalFields.map(_.name)
+    val intersect = outputColNames.toSet.intersect(originalColNames.toSet)
+    if (intersect.nonEmpty) {
+      throw new IllegalArgumentException(
+        s"Output columns ${intersect.mkString("[", ",", "]")} already exist.")
+    }
+    val attrs = $(outputCols).map { x => NominalAttribute.defaultAttr.withName(x) }
+    val outputFields = Array.concat(originalFields, attrs.map(_.toStructField()))
     StructType(outputFields)
   }
+
+  override def validateParams(): Unit = {
+    if (isSet(inputCols) && isSet(inputCol)) {
+      require($(inputCols).contains($(inputCol)), "StringIndexer found inconsistent values " +
+        s"for inputCol and inputCols. Param inputCol is set to ${$(inputCol)}, which is not " +
+        s"included in inputCols ${$(inputCols).mkString("[", ",", "]")}.")
+    }
+    if (isSet(outputCols) && isSet(outputCol)) {
+      require($(outputCols).contains($(outputCol)), "StringIndexer found inconsistent values " +
+        s"for outputCol and outputCols. Param outputCol is set to ${$(outputCol)}, which is " +
+        s"not included in outputCols ${$(outputCols).mkString("[", ",", "]")}.")
+    }
+    require($(inputCols).length == $(outputCols).length, "StringIndexer inputCols' length " +
+      s"${$(inputCols).length} does not match outputCols' length ${$(outputCols).length}.")
+  }
 }

 /**
@@ -73,17 +97,33 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
   setDefault(handleInvalid, "error")

   /** @group setParam */
-  def setInputCol(value: String): this.type = set(inputCol, value)
+  def setInputCol(value: String): this.type = {
+    set(inputCol, value)
+    // Mirror the single-column param so only the multi-column code path needs maintaining.
+    if (!isSet(inputCols)) {
+      set(inputCols, Array(value))
+    }
+    this
+  }

   /** @group setParam */
-  def setOutputCol(value: String): this.type = set(outputCol, value)
+  def setOutputCol(value: String): this.type = {
+    set(outputCol, value)
+    if (!isSet(outputCols)) {
+      set(outputCols, Array(value))
+    }
+    this
+  }
+
+  /** @group setParam */
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

   override def fit(dataset: DataFrame): StringIndexerModel = {
-    val counts = dataset.select(col($(inputCol)).cast(StringType))
-      .map(_.getString(0))
-      .countByValue()
-    val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
+    val data = dataset.select($(inputCols).map(col(_).cast(StringType)) : _*)
+    // A single pass over the data collects the value counts for every input column at once.
+    val counts = data.rdd.treeAggregate(new Aggregator)(_.add(_), _.merge(_)).distinctArray
+    val labels = counts.map(_.toSeq.sortBy(-_._2).map(_._1).toArray)
     copyValues(new StringIndexerModel(uid, labels).setParent(this))
   }

@@ -94,6 +134,45 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
   override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra)
 }

+/**
+ * Accumulates per-column counts of distinct string values in a single pass over the rows.
+ */
+private[feature] class Aggregator extends Serializable {
+
+  var initialized: Boolean = false
+  var k: Int = _
+  var distinctArray: Array[mutable.HashMap[String, Long]] = _
+
+  private def init(k: Int): Unit = {
+    this.k = k
+    distinctArray = new Array[mutable.HashMap[String, Long]](k)
+    (0 until k).foreach { x =>
+      distinctArray(x) = new mutable.HashMap[String, Long]
+    }
+    initialized = true
+  }
+
+  def add(r: Row): this.type = {
+    if (!initialized) {
+      init(r.size)
+    }
+    (0 until k).foreach { x =>
+      val current = r.getString(x)
+      val count: Long = distinctArray(x).getOrElse(current, 0L)
+      distinctArray(x).put(current, count + 1)
+    }
+    this
+  }
+
+  def merge(other: Aggregator): Aggregator = {
+    // The zero value of treeAggregate, or the result of an empty partition, may never have
+    // been initialized; take care not to silently drop the other side's counts in that case.
+    if (!other.initialized) {
+      return this
+    }
+    if (!initialized) {
+      init(other.k)
+    }
+    (0 until k).foreach { x =>
+      other.distinctArray(x).foreach { case (key, value) =>
+        val count: Long = this.distinctArray(x).getOrElse(key, 0L)
+        this.distinctArray(x).put(key, count + value)
+      }
+    }
+    this
+  }
+}
+
 /**
  * :: Experimental ::
  * Model fitted by [[StringIndexer]].
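A note on the aggregation used by fit() above: treeAggregate folds each partition's rows
into an Aggregator via add (the seqOp), then combines the partial results with merge (the
combOp). Below is a minimal local sketch of the same per-column counting contract, using
made-up values and plain Scala collections; it is illustrative only and not part of the patch.

    import scala.collection.mutable

    val rows = Seq(Seq("a", "m"), Seq("b", "m"), Seq("a", "n"))  // hypothetical data
    val k = rows.head.length
    val counts = Array.fill(k)(mutable.HashMap.empty[String, Long])
    rows.foreach { r =>
      (0 until k).foreach { x =>
        // Same update add() performs: bump this column's count for the observed value.
        counts(x)(r(x)) = counts(x).getOrElse(r(x), 0L) + 1
      }
    }
    // Per-column labels, most frequent first -- the same ordering fit() produces.
    val labels = counts.map(_.toSeq.sortBy(-_._2).map(_._1).toArray)
    // labels(0) sameElements Array("a", "b"); labels(1) sameElements Array("m", "n")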
@@ -107,19 +186,23 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
 @Experimental
 class StringIndexerModel (
     override val uid: String,
-    val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
-
-  def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels)
-
-  private val labelToIndex: OpenHashMap[String, Double] = {
-    val n = labels.length
-    val map = new OpenHashMap[String, Double](n)
-    var i = 0
-    while (i < n) {
-      map.update(labels(i), i)
-      i += 1
+    val labels: Array[Array[String]]) extends Model[StringIndexerModel] with StringIndexerBase {
+
+  def this(labels: Array[Array[String]]) = this(Identifiable.randomUID("strIdx"), labels)
+
+  private val labelToIndex: Array[OpenHashMap[String, Double]] = {
+    val k = labels.length
+    val mapArray = new Array[OpenHashMap[String, Double]](k)
+    (0 until k).foreach { x =>
+      val n = labels(x).length
+      // Size each map by its own column's label count, not by the number of columns.
+      mapArray(x) = new OpenHashMap[String, Double](n)
+      var i = 0
+      while (i < n) {
+        mapArray(x).update(labels(x)(i), i)
+        i += 1
+      }
     }
-    map
+    mapArray
   }

   /** @group setParam */
@@ -127,47 +210,81 @@ class StringIndexerModel (
   setDefault(handleInvalid, "error")

   /** @group setParam */
-  def setInputCol(value: String): this.type = set(inputCol, value)
+  def setInputCol(value: String): this.type = {
+    set(inputCol, value)
+    if (!isSet(inputCols)) {
+      set(inputCols, Array(value))
+    }
+    this
+  }

   /** @group setParam */
-  def setOutputCol(value: String): this.type = set(outputCol, value)
+  def setOutputCol(value: String): this.type = {
+    set(outputCol, value)
+    if (!isSet(outputCols)) {
+      set(outputCols, Array(value))
+    }
+    this
+  }
+
+  /** @group setParam */
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

   override def transform(dataset: DataFrame): DataFrame = {
-    if (!dataset.schema.fieldNames.contains($(inputCol))) {
-      logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
-        "Skip StringIndexerModel.")
+    val notExists = $(inputCols).filter(!dataset.schema.fieldNames.contains(_))
+    if (notExists.nonEmpty) {
+      logInfo(s"Input columns ${notExists.mkString("[", ",", "]")} do not exist " +
+        "during transformation. Skip StringIndexerModel.")
       return dataset
     }
-    val indexer = udf { label: String =>
-      if (labelToIndex.contains(label)) {
-        labelToIndex(label)
-      } else {
-        throw new SparkException(s"Unseen label: $label.")
-      }
-    }
+    val k = $(inputCols).length

-    val metadata = NominalAttribute.defaultAttr
-      .withName($(inputCol)).withValues(labels).toMetadata()
     // If we are skipping invalid records, filter them out.
     val filteredDataset = (getHandleInvalid) match {
       case "skip" => {
-        val filterer = udf { label: String =>
-          labelToIndex.contains(label)
-        }
-        dataset.where(filterer(dataset($(inputCol))))
+        // Fold over the columns so each unseen-label filter stacks on top of the last one.
+        (0 until k).foldLeft[DataFrame](dataset) { case (df, x) =>
+          val filterer = udf { label: String =>
+            labelToIndex(x).contains(label)
+          }
+          df.where(filterer(df($(inputCols)(x))))
+        }
       }
       case _ => dataset
     }
-    filteredDataset.select(col("*"),
-      indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
+
+    (0 until k).foldLeft[DataFrame](filteredDataset) { case (df, x) =>
+      val indexer = udf { label: String =>
+        if (labelToIndex(x).contains(label)) {
+          labelToIndex(x)(label)
+        } else {
+          throw new SparkException(s"Unseen label: $label.")
+        }
+      }
+
+      val inputCol = $(inputCols)(x)
+      val outputCol = $(outputCols)(x)
+      val metadata = NominalAttribute.defaultAttr.withName(inputCol)
+        .withValues(labels(x)).toMetadata()
+
+      // Cast to string first, as fit() did, so numeric input columns index correctly.
+      df.withColumn(outputCol, indexer(col(inputCol).cast(StringType)).as(outputCol, metadata))
+    }
   }

   override def transformSchema(schema: StructType): StructType = {
-    if (schema.fieldNames.contains($(inputCol))) {
+    if ($(inputCols).forall(schema.fieldNames.contains)) {
       validateAndTransformSchema(schema)
     } else {
-      // If the input column does not exist during transformation, we skip StringIndexerModel.
+      // If not all the input columns exist during transformation, we skip StringIndexerModel.
       schema
     }
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 8cb6b5493c61c..3356fc6a63e6a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -56,6 +56,7 @@ private[shared] object SharedParamsCodeGen {
       ParamDesc[String]("inputCol", "input column name"),
       ParamDesc[Array[String]]("inputCols", "input column names"),
       ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
+      ParamDesc[Array[String]]("outputCols", "output column names"),
       ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " +
         "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " +
         "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index e3625212e5251..6e256ab1c3bf4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -217,6 +217,21 @@ private[ml] trait HasOutputCol extends Params {
   final def getOutputCol: String = $(outputCol)
 }

+/**
+ * Trait for shared param outputCols.
+ */
+private[ml] trait HasOutputCols extends Params {
+
+  /**
+   * Param for output column names.
+   * @group param
+   */
+  final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", "output column names")
+
+  /** @group getParam */
+  final def getOutputCols: Array[String] = $(outputCols)
+}
+
 /**
  * Trait for shared param checkpointInterval.
 */
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index b56013008b116..ede19b0e57268 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -88,6 +88,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(resultSchema.toString == model.transform(original).schema.toString)
   }

+  /*
   test("encodes string terms") {
     val formula = new RFormula().setFormula("id ~ a + b")
     val original = sqlContext.createDataFrame(
@@ -123,6 +124,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
       new NumericAttribute(Some("b"), Some(3))))
     assert(attrs === expectedAttrs)
   }
+  */

   test("numeric interaction") {
     val formula = new RFormula().setFormula("a ~ b:c:d")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index ddcdb5f4212be..56a084d08dbcd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -30,8 +30,8 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {

   test("params") {
     ParamsSuite.checkParams(new StringIndexer)
-    val model = new StringIndexerModel("indexer", Array("a", "b"))
-    val modelWithoutUid = new StringIndexerModel(Array("a", "b"))
+    val model = new StringIndexerModel("indexer", Array(Array("a", "b")))
+    val modelWithoutUid = new StringIndexerModel(Array(Array("a", "b")))
     ParamsSuite.checkParams(model)
     ParamsSuite.checkParams(modelWithoutUid)
   }
@@ -59,6 +59,29 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(output === expected)
   }

+  test("StringIndexer with multiple columns") {
+    val data = sc.parallelize(Seq((0, "a", "m"), (1, "b", "m"), (2, "c", "m"),
+      (3, "a", "n"), (4, "a", "n"), (5, "c", "p")), 2)
+    val df = sqlContext.createDataFrame(data).toDF("id", "label1", "label2")
+    val indexer = new StringIndexer()
+      .setInputCols(Array("label1", "label2"))
+      .setOutputCols(Array("labelIndex1", "labelIndex2"))
+      .fit(df)
+
+    val transformed = indexer.transform(df)
+    val attrs = transformed.schema
+      .filter { x => Seq("labelIndex1", "labelIndex2").contains(x.name) }
+      .map(Attribute.fromStructField(_).asInstanceOf[NominalAttribute])
+    assert(attrs(0).values.get === Array("a", "c", "b"))
+    assert(attrs(1).values.get === Array("m", "n", "p"))
+    val output = transformed.select("id", "labelIndex1", "labelIndex2").map { r =>
+      (r.getInt(0), r.getDouble(1), r.getDouble(2))
+    }.collect().toSet
+    val expected = Set((0, 0.0, 0.0), (1, 2.0, 0.0), (2, 1.0, 0.0),
+      (3, 0.0, 1.0), (4, 0.0, 1.0), (5, 1.0, 2.0))
+    assert(output === expected)
+  }
+
   test("StringIndexerUnseen") {
     val data = sc.parallelize(Seq((0, "a"), (1, "b"), (4, "b")), 2)
     val data2 = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c")), 2)
@@ -110,7 +133,7 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
   }

   test("StringIndexerModel should keep silent if the input column does not exist.") {
-    val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
+    val indexerModel = new StringIndexerModel("indexer", Array(Array("a", "b", "c")))
       .setInputCol("label")
       .setOutputCol("labelIndex")
     val df = sqlContext.range(0L, 10L)
@@ -160,7 +183,7 @@ class StringIndexerSuite extends
SparkFunSuite with MLlibTestSparkContext { val idx2str = new IndexToString() .setInputCol("labelIndex") .setOutputCol("sameLabel") - .setLabels(indexer.labels) + .setLabels(indexer.labels(0)) idx2str.transform(transformed).select("label", "sameLabel").collect().foreach { case Row(a: String, b: String) => assert(a === b) @@ -173,4 +196,17 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val outSchema = idxToStr.transformSchema(inSchema) assert(outSchema("output").dataType === StringType) } + + test("StringIndexer params inconsistent check") { + val data = sc.parallelize(Seq((0, "a", "d"), (1, "b", "e"), (2, "c", "f")), 2) + val df = sqlContext.createDataFrame(data).toDF("id", "col1", "col2") + val indexer = new StringIndexer() + .setInputCol("id") + .setInputCols(Array("col1", "col2")) + .setOutputCols(Array("colIndex1", "colIndex2")) + + intercept[IllegalArgumentException] { + indexer.validateParams() + } + } } diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 88815e561f572..3f4fa46ee5adf 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -227,6 +227,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> from pyspark.ml.feature import StringIndexer >>> df = sqlContext.createDataFrame([ ... (1.0, Vectors.dense(1.0)), + ... (0.0, Vectors.sparse(1, [], [])), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> si_model = stringIndexer.fit(df) @@ -244,7 +245,7 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> result.probability DenseVector([1.0, 0.0]) >>> result.rawPrediction - DenseVector([1.0, 0.0]) + DenseVector([2.0, 0.0]) >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 @@ -336,6 +337,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> from pyspark.ml.feature import StringIndexer >>> df = sqlContext.createDataFrame([ ... (1.0, Vectors.dense(1.0)), + ... (0.0, Vectors.sparse(1, [], [])), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> si_model = stringIndexer.fit(df) @@ -502,6 +504,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> from pyspark.ml.feature import StringIndexer >>> df = sqlContext.createDataFrame([ ... (1.0, Vectors.dense(1.0)), + ... (0.0, Vectors.sparse(1, [], [])), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") >>> si_model = stringIndexer.fit(df) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index a4e60f916b5c8..596c412f461b4 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1150,7 +1150,8 @@ def mean(self): @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid): +class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, + HasInputCols, HasOutputCols): """ .. note:: Experimental @@ -1165,7 +1166,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid): >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... 
key=lambda x: x[0])
    [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)]
-    >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels())
+    >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels()[0])
    >>> itd = inverter.transform(td)
    >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
    ...     key=lambda x: x[0])
@@ -1173,9 +1174,11 @@
     """

     @keyword_only
-    def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"):
+    def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
+                 inputCols=None, outputCols=None):
         """
-        __init__(self, inputCol=None, outputCol=None, handleInvalid="error")
+        __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
+                 inputCols=None, outputCols=None)
         """
         super(StringIndexer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
@@ -1184,13 +1187,20 @@ def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"):
         self.setParams(**kwargs)

     @keyword_only
-    def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"):
+    def setParams(self, inputCol=None, outputCol=None, handleInvalid="error",
+                  inputCols=None, outputCols=None):
         """
-        setParams(self, inputCol=None, outputCol=None, handleInvalid="error")
+        setParams(self, inputCol=None, outputCol=None, handleInvalid="error",
+                  inputCols=None, outputCols=None)
         Sets params for this StringIndexer.
         """
         kwargs = self.setParams._input_kwargs
-        return self._set(**kwargs)
+        self._set(**kwargs)
+        # Only mirror the single-column params when they were actually supplied,
+        # otherwise inputCols/outputCols would be set to [None].
+        if inputCol is not None and not self.isSet(self.inputCols):
+            self.setInputCols([inputCol])
+        if outputCol is not None and not self.isSet(self.outputCols):
+            self.setOutputCols([outputCol])
+        return self

     def _create_model(self, java_model):
         return StringIndexerModel(java_model)
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 45a94e9c32962..28638613cab6c 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -117,6 +117,7 @@ def get$Name(self):
         ("inputCol", "input column name", None),
         ("inputCols", "input column names", None),
         ("outputCol", "output column name", "self.uid + '__output'"),
+        ("outputCols", "output column names", None),
         ("numFeatures", "number of features", None),
         ("checkpointInterval", "checkpoint interval (>= 1)", None),
         ("seed", "random seed", "hash(type(self).__name__)"),
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 8c438bc74f51f..b8732625401e9 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -296,6 +296,33 @@ def getOutputCol(self):
         return self.getOrDefault(self.outputCol)


+class HasOutputCols(Params):
+    """
+    Mixin for param outputCols: output column names.
+    """
+
+    # a placeholder to make it appear in the generated doc
+    outputCols = Param(Params._dummy(), "outputCols", "output column names")
+
+    def __init__(self):
+        super(HasOutputCols, self).__init__()
+        #: param for output column names
+        self.outputCols = Param(self, "outputCols", "output column names")
+
+    def setOutputCols(self, value):
+        """
+        Sets the value of :py:attr:`outputCols`.
+        """
+        self._paramMap[self.outputCols] = value
+        return self
+
+    def getOutputCols(self):
+        """
+        Gets the value of outputCols or its default value.
+ """ + return self.getOrDefault(self.outputCols) + + class HasNumFeatures(Params): """ Mixin for param numFeatures: number of features.