From 429cb520682dcd446f01e204f178b5c0c932e6cf Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 3 Aug 2015 18:32:54 -0700 Subject: [PATCH 01/27] first pass --- .../apache/spark/ml/feature/Interaction.scala | 112 ++++++++++++++++++ .../apache/spark/ml/feature/RFormula.scala | 20 +++- 2 files changed, 129 insertions(+), 3 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala new file mode 100644 index 000000000000..9dbfcdb4c83f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} + +@Experimental +class Interaction(override val uid: String) extends Estimator[PipelineModel] + with HasInputCols with HasOutputCol { + def this() = this(Identifiable.randomUID("interaction")) + + /** @group setParam */ + def setInputCols(values: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: DataFrame): InteractionModel = { + require(isDefined(inputCols), "Input cols must be defined first.") + require(isDefined(outputCol), "Output col must be defined first.") + require($(inputCols).length > 0, "Input cols must have non-zero length.") + + val encoderStages = ArrayBuffer[PipelineStage]() + val tempColumns = ArrayBuffer[String]() + val (factorCols, nonFactorCols) = $(inputCols) + .partition(input => dataset.schema(input).dataType == StringType) + + val encodedFactors: Option[String] = + if (factorCols.length > 0) { + val indexedCols = factorCols.map { input => + val output = input + "_idx_" + uid + encoderStages += new StringIndexer() + .setInputCol(input) + .setOutputCol(output) + tempColumns += output + output + } + val combinedIndexCol = "combined_idx_" + uid + tempColumns += combinedIndexCol + val encodedCol = if (nonFactorCols.length > 0) { + "factors_" + uid + } else { + $(outputCol) + } + encoderStages += new IndexCombiner(indexedCols, combinedIndexCol) + encoderStages += new OneHotEncoder() + .setInputCol(combinedIndex) + .setOutputCol(encodedCol) + Some(encodedCol) + } else { + None + } + + if (nonFactorCols.length > 0) { + // TODO(ekl) scale encodedFactors if exists by these cols + ??? 
+ } + + encoderStages += new ColumnPruner(tempColumns.toSet) + new Pipeline(uid) + .setStages(encoderStages.toArray) + .fit(dataset) + .setParent(this) + } +} + +private class IndexCombiner(inputCols: Array[String], outputCol: String) extends Transformer { + override val uid = Identifiable.randomUID("indexCombiner") + + override def transform(dataset: DataFrame): DataFrame = { + val cardinalities = inputCols.map(col => + Attribute.fromStructField(dataset.schema(col)).asInstanceOf[NominalAttribute].values.length) + val combiner = udf { cols: Array[Double] => + var offset = 1 + var res = cols(0) + var i = 0 + while (i < cols.length) { + res += cols(i) * offset + offset *= cardinalities(i) + i += 1 + } + res + } + dataset.select("*", combiner(array(inputCols.map(dataset): _*)).as(outputCol)) + } + + override def transformSchema(schema: StructType): StructType = { + StructType(schema.fields :+ StructField(outputCol, DoubleType, false)) + } + + override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index d5360c9217ea..1b9f1833fac5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -86,7 +86,21 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R val encoderStages = ArrayBuffer[PipelineStage]() val tempColumns = ArrayBuffer[String]() val takenNames = mutable.Set(dataset.columns: _*) - val encodedTerms = resolvedFormula.terms.map { term => + val interactionTerms = resolvedFormula.interactions.map { interaction => + val outputCol = { + var tmp = interaction.term + while (takenNames.contains(tmp)) { + tmp += "_" + } + tmp + } + takenNames.add(outputCol) + encoderStages += new Interaction() + .setInputCols(interaction.inputs) + .setOutputCol(outputCol) + outputCol + } + val standaloneTerms 
= resolvedFormula.terms.map { term => dataset.schema(term) match { case column if column.dataType == StringType => val indexCol = term + "_idx_" + uid @@ -109,7 +123,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R } } encoderStages += new VectorAssembler(uid) - .setInputCols(encodedTerms.toArray) + .setInputCols((interactionTerms ++ standaloneTerms).toArray) .setOutputCol($(featuresCol)) encoderStages += new ColumnPruner(tempColumns.toSet) val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset) @@ -203,7 +217,7 @@ class RFormulaModel private[feature]( * Utility transformer for removing temporary columns from a DataFrame. * TODO(ekl) make this a public transformer */ -private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { +private[feature] class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { override val uid = Identifiable.randomUID("columnPruner") override def transform(dataset: DataFrame): DataFrame = { From 0ece16cd2e30403f4fa8c5473e6cb6e7930ae52f Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 3 Aug 2015 20:29:16 -0700 Subject: [PATCH 02/27] compiles now --- .../apache/spark/ml/feature/Interaction.scala | 51 ++++++++++++++----- .../apache/spark/ml/feature/RFormula.scala | 4 +- .../spark/ml/feature/RFormulaParser.scala | 5 +- 3 files changed, 42 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 9dbfcdb4c83f..480f7bb5bedc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -19,27 +19,31 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuffer +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ 
import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} +import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ @Experimental class Interaction(override val uid: String) extends Estimator[PipelineModel] with HasInputCols with HasOutputCol { + def this() = this(Identifiable.randomUID("interaction")) /** @group setParam */ - def setInputCols(values: Array[String]): this.type = set(inputCols, value) + def setInputCols(values: Array[String]): this.type = set(inputCols, values) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame): InteractionModel = { - require(isDefined(inputCols), "Input cols must be defined first.") - require(isDefined(outputCol), "Output col must be defined first.") - require($(inputCols).length > 0, "Input cols must have non-zero length.") - + override def fit(dataset: DataFrame): PipelineModel = { + checkParams() val encoderStages = ArrayBuffer[PipelineStage]() val tempColumns = ArrayBuffer[String]() val (factorCols, nonFactorCols) = $(inputCols) @@ -55,14 +59,14 @@ class Interaction(override val uid: String) extends Estimator[PipelineModel] tempColumns += output output } - val combinedIndexCol = "combined_idx_" + uid - tempColumns += combinedIndexCol + val combinedIndex = "combined_idx_" + uid + tempColumns += combinedIndex val encodedCol = if (nonFactorCols.length > 0) { "factors_" + uid } else { $(outputCol) } - encoderStages += new IndexCombiner(indexedCols, combinedIndexCol) + encoderStages += new IndexCombiner(indexedCols, combinedIndex) encoderStages += new OneHotEncoder() .setInputCol(combinedIndex) .setOutputCol(encodedCol) @@ -82,6 +86,24 @@ class Interaction(override val uid: String) extends Estimator[PipelineModel] 
.fit(dataset) .setParent(this) } + + // optimistic schema; does not contain any ML attributes + override def transformSchema(schema: StructType): StructType = { + checkParams() + if ($(inputCols).exists(col => schema(col).dataType == StringType)) { + StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false)) + } else { + StructType(schema.fields :+ StructField($(outputCol), DoubleType, false)) + } + } + + override def copy(extra: ParamMap): Interaction = defaultCopy(extra) + + private def checkParams(): Unit = { + require(isDefined(inputCols), "Input cols must be defined first.") + require(isDefined(outputCol), "Output col must be defined first.") + require($(inputCols).length > 0, "Input cols must have non-zero length.") + } } private class IndexCombiner(inputCols: Array[String], outputCol: String) extends Transformer { @@ -89,10 +111,11 @@ private class IndexCombiner(inputCols: Array[String], outputCol: String) extends override def transform(dataset: DataFrame): DataFrame = { val cardinalities = inputCols.map(col => - Attribute.fromStructField(dataset.schema(col)).asInstanceOf[NominalAttribute].values.length) - val combiner = udf { cols: Array[Double] => + Attribute.fromStructField(dataset.schema(col)) + .asInstanceOf[NominalAttribute].values.get.length) + val combiner = udf { cols: Seq[Double] => var offset = 1 - var res = cols(0) + var res = 0.0 var i = 0 while (i < cols.length) { res += cols(i) * offset @@ -101,12 +124,12 @@ private class IndexCombiner(inputCols: Array[String], outputCol: String) extends } res } - dataset.select("*", combiner(array(inputCols.map(dataset): _*)).as(outputCol)) + dataset.select(col("*"), combiner(array(inputCols.map(dataset(_)): _*)).as(outputCol)) } override def transformSchema(schema: StructType): StructType = { StructType(schema.fields :+ StructField(outputCol, DoubleType, false)) } - override def copy(extra: ParamMap): ColumnPruner = defaultCopy(extra) + override def copy(extra: ParamMap): IndexCombiner = 
defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 1b9f1833fac5..8a6cce5558b2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -88,7 +88,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R val takenNames = mutable.Set(dataset.columns: _*) val interactionTerms = resolvedFormula.interactions.map { interaction => val outputCol = { - var tmp = interaction.term + var tmp = interaction.mkString(":") while (takenNames.contains(tmp)) { tmp += "_" } @@ -96,7 +96,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R } takenNames.add(outputCol) encoderStages += new Interaction() - .setInputCols(interaction.inputs) + .setInputCols(interaction) .setOutputCol(outputCol) outputCol } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index 1ca3b92a7d92..8261e946dd4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -51,7 +51,7 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { } case _: Intercept => } - ResolvedRFormula(label.value, includedTerms.distinct) + ResolvedRFormula(label.value, includedTerms.distinct, Nil) } /** Whether this formula specifies fitting with an intercept term. */ @@ -79,7 +79,8 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { /** * Represents a fully evaluated and simplified R formula. */ -private[ml] case class ResolvedRFormula(label: String, terms: Seq[String]) +private[ml] case class ResolvedRFormula( + label: String, terms: Seq[String], interactions: Seq[Array[String]]) /** * R formula terms. 
See the R formula docs here for more information: From ab2a3477512aba52f69f3ccd0bfe620b2da0cb39 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 4 Aug 2015 18:14:31 -0700 Subject: [PATCH 03/27] combiner --- .../apache/spark/ml/feature/Interaction.scala | 30 ++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 480f7bb5bedc..f3f7973660c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -110,21 +110,27 @@ private class IndexCombiner(inputCols: Array[String], outputCol: String) extends override val uid = Identifiable.randomUID("indexCombiner") override def transform(dataset: DataFrame): DataFrame = { - val cardinalities = inputCols.map(col => - Attribute.fromStructField(dataset.schema(col)) - .asInstanceOf[NominalAttribute].values.get.length) - val combiner = udf { cols: Seq[Double] => + val inputMetadata = inputCols.map(col => + Attribute.fromStructField(dataset.schema(col)).asInstanceOf[NominalAttribute]) + val cardinalities = inputMetadata.map(_.values.get.length) + val combiner = udf { values: Seq[Double] => var offset = 1 var res = 0.0 var i = 0 - while (i < cols.length) { - res += cols(i) * offset + while (i < values.length) { + res += values(i) * offset offset *= cardinalities(i) i += 1 } res } - dataset.select(col("*"), combiner(array(inputCols.map(dataset(_)): _*)).as(outputCol)) + val metadata = NominalAttribute.defaultAttr + .withName(outputCol) + .withValues(combineLabels(inputMetadata.map(_.values.get))) + .toMetadata() + dataset.select( + col("*"), + combiner(array(inputCols.map(dataset(_)): _*)).as(outputCol, metadata)) } override def transformSchema(schema: StructType): StructType = { @@ -132,4 +138,14 @@ private class IndexCombiner(inputCols: Array[String], outputCol: String) 
extends } override def copy(extra: ParamMap): IndexCombiner = defaultCopy(extra) + + private def combineLabels(labels: Array[Array[String]]): Array[String] = { + if (labels.length <= 1) { + labels.head + } else { + combineLabels(labels.tail).flatMap { rest => + labels.head.map(l => l + ":" + rest) + } + } + } } From a3623aa6cbb2e248b45b489cfb216c9d87bc7c86 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 4 Aug 2015 20:32:39 -0700 Subject: [PATCH 04/27] attribute generation --- .../apache/spark/ml/feature/Interaction.scala | 95 +++++++++++++++---- .../apache/spark/ml/feature/RFormula.scala | 26 ++--- .../spark/ml/feature/VectorAssembler.scala | 10 +- .../spark/ml/feature/RFormulaSuite.scala | 4 +- 4 files changed, 92 insertions(+), 43 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index f3f7973660c6..a7629672f384 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -17,15 +17,15 @@ package org.apache.spark.ml.feature -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, ArrayBuilder} import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} -import org.apache.spark.mllib.linalg.VectorUDT +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -66,7 +66,7 @@ class Interaction(override val uid: String) extends 
Estimator[PipelineModel] } else { $(outputCol) } - encoderStages += new IndexCombiner(indexedCols, combinedIndex) + encoderStages += new IndexCombiner(indexedCols, factorCols, combinedIndex) encoderStages += new OneHotEncoder() .setInputCol(combinedIndex) .setOutputCol(encodedCol) @@ -76,8 +76,10 @@ class Interaction(override val uid: String) extends Estimator[PipelineModel] } if (nonFactorCols.length > 0) { - // TODO(ekl) scale encodedFactors if exists by these cols - ??? + encoderStages += new NumericInteraction(nonFactorCols, encodedFactors, $(outputCol)) + if (encodedFactors.isDefined) { + tempColumns += encodedFactors.get + } } encoderStages += new ColumnPruner(tempColumns.toSet) @@ -91,9 +93,9 @@ class Interaction(override val uid: String) extends Estimator[PipelineModel] override def transformSchema(schema: StructType): StructType = { checkParams() if ($(inputCols).exists(col => schema(col).dataType == StringType)) { - StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false)) + StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, true)) } else { - StructType(schema.fields :+ StructField($(outputCol), DoubleType, false)) + StructType(schema.fields :+ StructField($(outputCol), DoubleType, true)) } } @@ -103,10 +105,14 @@ class Interaction(override val uid: String) extends Estimator[PipelineModel] require(isDefined(inputCols), "Input cols must be defined first.") require(isDefined(outputCol), "Output col must be defined first.") require($(inputCols).length > 0, "Input cols must have non-zero length.") + require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") } } -private class IndexCombiner(inputCols: Array[String], outputCol: String) extends Transformer { +private class IndexCombiner( + inputCols: Array[String], attrNames: Array[String], outputCol: String) + extends Transformer { + override val uid = Identifiable.randomUID("indexCombiner") override def transform(dataset: DataFrame): 
DataFrame = { @@ -126,7 +132,7 @@ private class IndexCombiner(inputCols: Array[String], outputCol: String) extends } val metadata = NominalAttribute.defaultAttr .withName(outputCol) - .withValues(combineLabels(inputMetadata.map(_.values.get))) + .withValues(generateAttrNames(inputMetadata, attrNames)) .toMetadata() dataset.select( col("*"), @@ -134,18 +140,75 @@ private class IndexCombiner(inputCols: Array[String], outputCol: String) extends } override def transformSchema(schema: StructType): StructType = { - StructType(schema.fields :+ StructField(outputCol, DoubleType, false)) + StructType(schema.fields :+ StructField(outputCol, DoubleType, true)) } override def copy(extra: ParamMap): IndexCombiner = defaultCopy(extra) - private def combineLabels(labels: Array[Array[String]]): Array[String] = { - if (labels.length <= 1) { - labels.head + private def generateAttrNames( + attrs: Array[NominalAttribute], names: Array[String]): Array[String] = { + val colName = names.head + val attrNames = attrs.head.values.get.map(colName + "_" + _) + if (attrs.length <= 1) { + attrNames + } else { + generateAttrNames(attrs.tail, names.tail).flatMap { rest => + attrNames.map(n => n + ":" + rest) + } + } + } +} + +private class NumericInteraction( + inputCols: Array[String], vectorCol: Option[String], outputCol: String) + extends Transformer { + + override val uid = Identifiable.randomUID("indexCombiner") + + override def transform(dataset: DataFrame): DataFrame = { + if (vectorCol.isDefined) { + val scale = udf { (vec: Vector, scalars: Seq[Double]) => + val k = scalars.reduce(_ * _) + val indices = ArrayBuilder.make[Int] + val values = ArrayBuilder.make[Double] + vec.foreachActive { case (i, v) => + if (v != 0.0) { + indices += i + values += v * k + } + } + Vectors.sparse(vec.size, indices.result(), values.result()).compressed + } + val group = AttributeGroup.fromStructField(dataset.schema(vectorCol.get)) + val attrs = group.attributes.get.map { attr => + attr.withName(attr.name.get 
+ ":" + inputCols.mkString(":")) + } + val metadata = new AttributeGroup(outputCol, attrs).toMetadata() + dataset.select( + col("*"), + scale( + col(vectorCol.get), + array(inputCols.map(dataset(_).cast(DoubleType)): _*)).as(outputCol, metadata)) } else { - combineLabels(labels.tail).flatMap { rest => - labels.head.map(l => l + ":" + rest) + val multiply = udf { values: Seq[Double] => + values.reduce(_ * _) } + val metadata = NumericAttribute.defaultAttr + .withName(inputCols.mkString(":")) + .toMetadata() + dataset.select( + col("*"), + multiply(array(inputCols.map(dataset(_).cast(DoubleType)): _*)).as(outputCol, metadata)) } } + + override def transformSchema(schema: StructType): StructType = { + if (vectorCol.isDefined) { + StructType(schema.fields :+ StructField(outputCol, new VectorUDT, true)) + } else { + StructType(schema.fields :+ StructField(outputCol, DoubleType, true)) + } + } + + override def copy(extra: ParamMap): IndexCombiner = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 8a6cce5558b2..aacec7bec6a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -81,14 +81,12 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R require(isDefined(formula), "Formula must be defined first.") val parsedFormula = RFormulaParser.parse($(formula)) val resolvedFormula = parsedFormula.resolve(dataset.schema) - // StringType terms and terms representing interactions need to be encoded before assembly. 
- // TODO(ekl) add support for feature interactions val encoderStages = ArrayBuffer[PipelineStage]() val tempColumns = ArrayBuffer[String]() val takenNames = mutable.Set(dataset.columns: _*) - val interactionTerms = resolvedFormula.interactions.map { interaction => + def encodeInteraction(terms: Array[String]): String = { val outputCol = { - var tmp = interaction.mkString(":") + var tmp = terms.mkString(":") while (takenNames.contains(tmp)) { tmp += "_" } @@ -96,28 +94,16 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R } takenNames.add(outputCol) encoderStages += new Interaction() - .setInputCols(interaction) + .setInputCols(terms) .setOutputCol(outputCol) + tempColumns += outputCol outputCol } + val interactionTerms = resolvedFormula.interactions.map(encodeInteraction) val standaloneTerms = resolvedFormula.terms.map { term => dataset.schema(term) match { case column if column.dataType == StringType => - val indexCol = term + "_idx_" + uid - val encodedCol = { - var tmp = term - while (takenNames.contains(tmp)) { - tmp += "_" - } - tmp - } - takenNames.add(indexCol) - takenNames.add(encodedCol) - encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol) - encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol) - tempColumns += indexCol - tempColumns += encodedCol - encodedCol + encodeInteraction(Array(term)) case _ => term } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 086917fa680f..9523682c1643 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -71,12 +71,12 @@ class VectorAssembler(override val uid: String) if (group.attributes.isDefined) { // If attributes are defined, copy them with updated names. 
group.attributes.get.map { attr => - if (attr.name.isDefined) { - // TODO: Define a rigorous naming scheme. - attr.withName(c + "_" + attr.name.get) - } else { +// if (attr.name.isDefined) { +// // TODO: Define a rigorous naming scheme. +// attr.withName(c + "_" + attr.name.get) +// } else { attr - } +// } } } else { // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 6aed3243afce..b3422cbd74e9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -118,8 +118,8 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { val expectedAttrs = new AttributeGroup( "features", Array( - new BinaryAttribute(Some("a__bar"), Some(1)), - new BinaryAttribute(Some("a__foo"), Some(2)), + new BinaryAttribute(Some("a_bar"), Some(1)), + new BinaryAttribute(Some("a_foo"), Some(2)), new NumericAttribute(Some("b"), Some(3)))) assert(attrs === expectedAttrs) } From a12e58e58e7a9c44611124d21c6172cb0a4c6b53 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 5 Aug 2015 15:58:14 -0700 Subject: [PATCH 05/27] Wed Aug 5 15:58:14 PDT 2015 --- .../apache/spark/ml/feature/Interaction.scala | 21 +++++++++---- .../apache/spark/ml/feature/RFormula.scala | 20 +++++++------ .../spark/ml/feature/RFormulaParser.scala | 30 ++++++++++++------- 3 files changed, 46 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index a7629672f384..439aacaefb65 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -163,18 +163,23 @@ private class NumericInteraction( inputCols: 
Array[String], vectorCol: Option[String], outputCol: String) extends Transformer { - override val uid = Identifiable.randomUID("indexCombiner") + override val uid = Identifiable.randomUID("numericInteraction") override def transform(dataset: DataFrame): DataFrame = { if (vectorCol.isDefined) { val scale = udf { (vec: Vector, scalars: Seq[Double]) => - val k = scalars.reduce(_ * _) + var x = 1.0 + var i = 0 + while (i < scalars.length) { + x *= scalars(i) + i += 1 + } val indices = ArrayBuilder.make[Int] val values = ArrayBuilder.make[Double] vec.foreachActive { case (i, v) => if (v != 0.0) { indices += i - values += v * k + values += v * x } } Vectors.sparse(vec.size, indices.result(), values.result()).compressed @@ -191,7 +196,13 @@ private class NumericInteraction( array(inputCols.map(dataset(_).cast(DoubleType)): _*)).as(outputCol, metadata)) } else { val multiply = udf { values: Seq[Double] => - values.reduce(_ * _) + var x = 1.0 + var i = 0 + while (i < values.length) { + x *= values(i) + i += 1 + } + x } val metadata = NumericAttribute.defaultAttr .withName(inputCols.mkString(":")) @@ -210,5 +221,5 @@ private class NumericInteraction( } } - override def copy(extra: ParamMap): IndexCombiner = defaultCopy(extra) + override def copy(extra: ParamMap): NumericInteraction = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index aacec7bec6a0..48fbe32e9ee7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -99,17 +99,19 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R tempColumns += outputCol outputCol } - val interactionTerms = resolvedFormula.interactions.map(encodeInteraction) - val standaloneTerms = resolvedFormula.terms.map { term => - dataset.schema(term) match { - case column if column.dataType == StringType => - 
encodeInteraction(Array(term)) - case _ => - term - } + val encodedCols = resolvedFormula.terms.map { + case ColumnInteraction(values) => + encodeInteraction(values) + case ColumnRef(value) => + dataset.schema(value) match { + case column if column.dataType == StringType => + encodeInteraction(Array(value)) + case _ => + value + } } encoderStages += new VectorAssembler(uid) - .setInputCols((interactionTerms ++ standaloneTerms).toArray) + .setInputCols(encodedCols.toArray) .setOutputCol($(featuresCol)) encoderStages += new ColumnPruner(tempColumns.toSet) val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index 8261e946dd4a..6f6f9d7f989e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -31,27 +31,30 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { * of the special '.' term. Duplicate terms will be removed during resolution. */ def resolve(schema: StructType): ResolvedRFormula = { - var includedTerms = Seq[String]() + var includedTerms = Seq[SimpleTerm]() terms.foreach { + case term: SimpleTerm => + includedTerms :+= term case Dot => - includedTerms ++= simpleTypes(schema).filter(_ != label.value) - case ColumnRef(value) => - includedTerms :+= value + includedTerms ++= simpleTypes(schema).filter(_ != label.value).map(ColumnRef) case Deletion(term: Term) => term match { - case ColumnRef(value) => - includedTerms = includedTerms.filter(_ != value) + case inner: SimpleTerm => + includedTerms = includedTerms.filter(_ != inner) case Dot => // e.g. 
"- .", which removes all first-order terms val fromSchema = simpleTypes(schema) - includedTerms = includedTerms.filter(fromSchema.contains(_)) + includedTerms = includedTerms.filter { + case t: ColumnRef => !includedTerms.contains(t.value) + case _ => true + } case _: Deletion => assert(false, "Deletion terms cannot be nested") case _: Intercept => } case _: Intercept => } - ResolvedRFormula(label.value, includedTerms.distinct, Nil) + ResolvedRFormula(label.value, includedTerms.distinct) } /** Whether this formula specifies fitting with an intercept term. */ @@ -79,8 +82,7 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { /** * Represents a fully evaluated and simplified R formula. */ -private[ml] case class ResolvedRFormula( - label: String, terms: Seq[String], interactions: Seq[Array[String]]) +private[ml] case class ResolvedRFormula(label: String, terms: Seq[SimpleTerm]) /** * R formula terms. See the R formula docs here for more information: @@ -88,11 +90,17 @@ private[ml] case class ResolvedRFormula( */ private[ml] sealed trait Term +/** A standalone term after formula simplification, e.g. single variable or interaction. */ +private[ml] sealed trait SimpleTerm + /* R formula reference to all available columns, e.g. "." in a formula */ private[ml] case object Dot extends Term /* R formula reference to a column, e.g. "+ Species" in a formula */ -private[ml] case class ColumnRef(value: String) extends Term +private[ml] case class ColumnRef(value: String) extends Term with SimpleTerm + +/* R formula interaction of several columns, e.g. "Sepal_Length:Species" in a formula */ +private[ml] case class ColumnInteraction(values: Array[String]) extends Term with SimpleTerm /* R formula intercept toggle, e.g. 
"+ 0" in a formula */ private[ml] case class Intercept(enabled: Boolean) extends Term From dc8801a31cd7806ccd1e4ef902568e1d9bc85e94 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 5 Aug 2015 19:59:50 -0700 Subject: [PATCH 06/27] Wed Aug 5 19:59:50 PDT 2015 --- .../apache/spark/ml/feature/RFormula.scala | 12 ++-- .../spark/ml/feature/RFormulaParser.scala | 58 ++++++++++++++----- .../ml/feature/RFormulaParserSuite.scala | 20 ++++++- 3 files changed, 68 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 48fbe32e9ee7..855486f54a7a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -84,7 +84,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R val encoderStages = ArrayBuffer[PipelineStage]() val tempColumns = ArrayBuffer[String]() val takenNames = mutable.Set(dataset.columns: _*) - def encodeInteraction(terms: Array[String]): String = { + def encodeInteraction(terms: Seq[String]): String = { val outputCol = { var tmp = terms.mkString(":") while (takenNames.contains(tmp)) { @@ -94,21 +94,21 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R } takenNames.add(outputCol) encoderStages += new Interaction() - .setInputCols(terms) + .setInputCols(terms.toArray) .setOutputCol(outputCol) tempColumns += outputCol outputCol } val encodedCols = resolvedFormula.terms.map { - case ColumnInteraction(values) => - encodeInteraction(values) - case ColumnRef(value) => + case terms @ Seq(value) => dataset.schema(value) match { case column if column.dataType == StringType => - encodeInteraction(Array(value)) + encodeInteraction(terms) case _ => value } + case terms => + encodeInteraction(terms) } encoderStages += new VectorAssembler(uid) .setInputCols(encodedCols.toArray) diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index 6f6f9d7f989e..13c557deb496 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -31,21 +31,26 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { * of the special '.' term. Duplicate terms will be removed during resolution. */ def resolve(schema: StructType): ResolvedRFormula = { - var includedTerms = Seq[SimpleTerm]() + var includedTerms = Seq[Seq[String]]() terms.foreach { - case term: SimpleTerm => - includedTerms :+= term + case term: ColumnRef => + includedTerms :+= Seq(term.value) + case ColumnInteraction(terms) => + includedTerms ++= expandInteraction(schema, terms) case Dot => - includedTerms ++= simpleTypes(schema).filter(_ != label.value).map(ColumnRef) + includedTerms ++= simpleTypes(schema).filter(_ != label.value).map(Seq(_)) case Deletion(term: Term) => term match { - case inner: SimpleTerm => - includedTerms = includedTerms.filter(_ != inner) + case inner: ColumnRef => + includedTerms = includedTerms.filter(_ != Seq(inner.value)) + case ColumnInteraction(terms) => + val fromInteraction = expandInteraction(schema, terms) + includedTerms = includedTerms.filter(!fromInteraction.contains(_)) case Dot => // e.g. 
"- .", which removes all first-order terms val fromSchema = simpleTypes(schema) includedTerms = includedTerms.filter { - case t: ColumnRef => !includedTerms.contains(t.value) + case Seq(t) => !fromSchema.contains(t) case _ => true } case _: Deletion => @@ -70,6 +75,23 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { intercept } + private def expandInteraction( + schema: StructType, terms: Seq[InteractionComponent]): Seq[Seq[String]] = { + if (terms.isEmpty) { + Seq(Nil) + } else { + terms.head match { + case Dot => + val rest = expandInteraction(schema, terms.tail) + simpleTypes(schema).filter(_ != label.value).map { t => + Seq(t) ++ rest + } + case ColumnRef(value) => + (Seq(value) ++ expandInteraction(schema, terms.tail)) + } + } + } + // the dot operator excludes complex column types private def simpleTypes(schema: StructType): Seq[String] = { schema.fields.filter(_.dataType match { @@ -82,7 +104,7 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { /** * Represents a fully evaluated and simplified R formula. */ -private[ml] case class ResolvedRFormula(label: String, terms: Seq[SimpleTerm]) +private[ml] case class ResolvedRFormula(label: String, terms: Seq[Seq[String]]) /** * R formula terms. See the R formula docs here for more information: @@ -90,17 +112,17 @@ private[ml] case class ResolvedRFormula(label: String, terms: Seq[SimpleTerm]) */ private[ml] sealed trait Term -/** A standalone term after formula simplification, e.g. single variable or interaction. */ -private[ml] sealed trait SimpleTerm +/** A term that may be part of an interaction, e.g. 'x' in 'x:y' */ +private[ml] sealed trait InteractionComponent extends Term /* R formula reference to all available columns, e.g. "." in a formula */ -private[ml] case object Dot extends Term +private[ml] case object Dot extends InteractionComponent /* R formula reference to a column, e.g. 
"+ Species" in a formula */ -private[ml] case class ColumnRef(value: String) extends Term with SimpleTerm +private[ml] case class ColumnRef(value: String) extends InteractionComponent /* R formula interaction of several columns, e.g. "Sepal_Length:Species" in a formula */ -private[ml] case class ColumnInteraction(values: Array[String]) extends Term with SimpleTerm +private[ml] case class ColumnInteraction(terms: Seq[InteractionComponent]) extends Term /* R formula intercept toggle, e.g. "+ 0" in a formula */ private[ml] case class Intercept(enabled: Boolean) extends Term @@ -118,7 +140,15 @@ private[ml] object RFormulaParser extends RegexParsers { def columnRef: Parser[ColumnRef] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) } - def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot } + def dot: Parser[InteractionComponent] = "\\.".r ^^ { case _ => Dot } + + def interaction: Parser[List[InteractionComponent]] = repsep(columnRef | dot, ":") + + def term: Parser[Term] = intercept | + interaction ^^ { + case Seq(term) => term + case terms => ColumnInteraction(terms) + } def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ { case op ~ list => list.foldLeft(List(op)) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index 436e66bab09b..a5489f772777 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -28,13 +28,21 @@ class RFormulaParserSuite extends SparkFunSuite { schema: StructType = null) { val resolved = RFormulaParser.parse(formula).resolve(schema) assert(resolved.label == label) - assert(resolved.terms == terms) + val simpleTerms = terms.map { t => + if (t.contains(":")) { + t.split(":").toSeq + } else { + Seq(t) + } + } + assert(resolved.terms == simpleTerms) } 
test("parse simple formulas") { checkParse("y ~ x", "y", Seq("x")) checkParse("y ~ x + x", "y", Seq("x")) - checkParse("y ~ ._foo ", "y", Seq("._foo")) + checkParse("y~x+z", "y", Seq("x", "z")) + checkParse("y ~ ._fo..o ", "y", Seq("._fo..o")) checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) } @@ -79,4 +87,12 @@ class RFormulaParserSuite extends SparkFunSuite { assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept) assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept) } + + test("parse interactions") { + checkParse("y ~ a:b", "y", Seq("a:b")) + checkParse("y ~ ._a:._x", "y", Seq("._a:._x")) + checkParse("y ~ foo:bar", "y", Seq("foo:bar")) + checkParse("y ~ a : b : c", "y", Seq("a:b:c")) + checkParse("y ~ q + a:b:c + b:c + c:d + z", "y", Seq("q", "a:b:c", "b:c", "c:d", "z")) + } } From 2957cb686264303344c02b96e5ce166a7a66a959 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 5 Aug 2015 22:19:28 -0700 Subject: [PATCH 07/27] fix parser --- .../spark/ml/feature/RFormulaParser.scala | 39 +++++++---- .../ml/feature/RFormulaParserSuite.scala | 66 +++++++++++++++++++ 2 files changed, 92 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index 13c557deb496..ba256e9e4ca6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import scala.collection.mutable import scala.util.parsing.combinator.RegexParsers import org.apache.spark.mllib.linalg.VectorUDT @@ -31,6 +32,7 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { * of the special '.' term. Duplicate terms will be removed during resolution. 
*/ def resolve(schema: StructType): ResolvedRFormula = { + lazy val dotTerms = expandDot(schema) var includedTerms = Seq[Seq[String]]() terms.foreach { case term: ColumnRef => @@ -38,19 +40,18 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { case ColumnInteraction(terms) => includedTerms ++= expandInteraction(schema, terms) case Dot => - includedTerms ++= simpleTypes(schema).filter(_ != label.value).map(Seq(_)) + includedTerms ++= dotTerms.map(Seq(_)) case Deletion(term: Term) => term match { case inner: ColumnRef => includedTerms = includedTerms.filter(_ != Seq(inner.value)) case ColumnInteraction(terms) => - val fromInteraction = expandInteraction(schema, terms) - includedTerms = includedTerms.filter(!fromInteraction.contains(_)) + val fromInteraction = expandInteraction(schema, terms).map(_.toSet) + includedTerms = includedTerms.filter(t => !fromInteraction.contains(t.toSet)) case Dot => // e.g. "- .", which removes all first-order terms - val fromSchema = simpleTypes(schema) includedTerms = includedTerms.filter { - case Seq(t) => !fromSchema.contains(t) + case Seq(t) => !dotTerms.contains(t) case _ => true } case _: Deletion => @@ -75,29 +76,41 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { intercept } + // expands the Dot operators in interaction terms private def expandInteraction( schema: StructType, terms: Seq[InteractionComponent]): Seq[Seq[String]] = { if (terms.isEmpty) { Seq(Nil) } else { - terms.head match { + val rest = expandInteraction(schema, terms.tail) + val validInteractions = (terms.head match { case Dot => - val rest = expandInteraction(schema, terms.tail) - simpleTypes(schema).filter(_ != label.value).map { t => - Seq(t) ++ rest + expandDot(schema).filter(_ != label.value).flatMap { t => + rest.map { r => + Seq(t) ++ r + } } case ColumnRef(value) => - (Seq(value) ++ expandInteraction(schema, terms.tail)) - } + rest.map(Seq(value) ++ _) + }).map(_.distinct) + // Deduplicates feature 
interactions, for example, a:b is the same as b:a. + var seen = mutable.Set[Set[String]]() + validInteractions.flatMap { + case t if seen.contains(t.toSet) => + None + case t => + seen += t.toSet + Some(t) + }.sortBy(_.length) } } // the dot operator excludes complex column types - private def simpleTypes(schema: StructType): Seq[String] = { + private def expandDot(schema: StructType): Seq[String] = { schema.fields.filter(_.dataType match { case _: NumericType | StringType | BooleanType | _: VectorUDT => true case _ => false - }).map(_.name) + }).map(_.name).filter(_ != label.value) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index a5489f772777..9d4856047958 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -95,4 +95,70 @@ class RFormulaParserSuite extends SparkFunSuite { checkParse("y ~ a : b : c", "y", Seq("a:b:c")) checkParse("y ~ q + a:b:c + b:c + c:d + z", "y", Seq("q", "a:b:c", "b:c", "c:d", "z")) } + + test("parse basic interactions with dot") { + val schema = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + checkParse("y ~ .:x", "y", Seq("a:x", "b:x", "c:x"), schema) + checkParse("a ~ .:x", "a", Seq("b:x", "c:x"), schema) + checkParse("a ~ x:.", "a", Seq("x:b", "x:c"), schema) + } + + // Test data generated in R with terms.formula(y ~ .:., data = iris) + test("parse all to all iris interactions") { + val schema = (new StructType) + .add("Sepal.Length", "double", true) + .add("Sepal.Width", "double", true) + .add("Petal.Length", "double", true) + .add("Petal.Width", "double", true) + .add("Species", "string", true) + checkParse( + "y ~ .:.", + "y", + Seq( + "Sepal.Length", + "Sepal.Width", + "Petal.Length", + "Petal.Width", + "Species", + "Sepal.Length:Sepal.Width", 
+ "Sepal.Length:Petal.Length", + "Sepal.Length:Petal.Width", + "Sepal.Length:Species", + "Sepal.Width:Petal.Length", + "Sepal.Width:Petal.Width", + "Sepal.Width:Species", + "Petal.Length:Petal.Width", + "Petal.Length:Species", + "Petal.Width:Species"), + schema) + } + + // Test data generated in R with terms.formula(y ~ .:. - Species:., data = iris) + test("parse interaction negation with iris") { + val schema = (new StructType) + .add("Sepal.Length", "double", true) + .add("Sepal.Width", "double", true) + .add("Petal.Length", "double", true) + .add("Petal.Width", "double", true) + .add("Species", "string", true) + checkParse("y ~ .:. - .:.", "y", Nil, schema) + checkParse( + "y ~ .:. - Species:.", + "y", + Seq( + "Sepal.Length", + "Sepal.Width", + "Petal.Length", + "Petal.Width", + "Sepal.Length:Sepal.Width", + "Sepal.Length:Petal.Length", + "Sepal.Length:Petal.Width", + "Sepal.Width:Petal.Length", + "Sepal.Width:Petal.Width", + "Petal.Length:Petal.Width"), + schema) + } } From 478ee8f2e1133901746e68cc202ecfc89de8eaa9 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 5 Aug 2015 22:57:10 -0700 Subject: [PATCH 08/27] add rformula test --- .../apache/spark/ml/feature/Interaction.scala | 12 +++- .../spark/ml/feature/RFormulaParser.scala | 2 +- .../spark/ml/feature/VectorAssembler.scala | 11 +-- .../ml/feature/RFormulaParserSuite.scala | 1 + .../spark/ml/feature/RFormulaSuite.scala | 72 +++++++++++++++++++ 5 files changed, 85 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 439aacaefb65..bb452c874f16 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -67,9 +67,15 @@ class Interaction(override val uid: String) extends Estimator[PipelineModel] $(outputCol) } encoderStages += new IndexCombiner(indexedCols, factorCols, 
combinedIndex) - encoderStages += new OneHotEncoder() - .setInputCol(combinedIndex) - .setOutputCol(encodedCol) + encoderStages += { + val encoder = new OneHotEncoder() + .setInputCol(combinedIndex) + .setOutputCol(encodedCol) + if ($(inputCols).length > 1) { + encoder.setDropLast(false) // R includes all columns for interactions. + } + encoder + } Some(encodedCol) } else { None diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index ba256e9e4ca6..2238d7ca465d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -153,7 +153,7 @@ private[ml] object RFormulaParser extends RegexParsers { def columnRef: Parser[ColumnRef] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) } - def dot: Parser[InteractionComponent] = "\\.".r ^^ { case _ => Dot } + def dot: Parser[InteractionComponent] = "\\.".r ^^ { case _ => Dot } def interaction: Parser[List[InteractionComponent]] = repsep(columnRef | dot, ":") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 9523682c1643..aa1e95907986 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -69,15 +69,8 @@ class VectorAssembler(override val uid: String) case _: VectorUDT => val group = AttributeGroup.fromStructField(field) if (group.attributes.isDefined) { - // If attributes are defined, copy them with updated names. - group.attributes.get.map { attr => -// if (attr.name.isDefined) { -// // TODO: Define a rigorous naming scheme. -// attr.withName(c + "_" + attr.name.get) -// } else { - attr -// } - } + // If attributes are defined, copy them. 
+ group.attributes.get } else { // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes // from metadata, check the first row. diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index 9d4856047958..6d36509773fb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -104,6 +104,7 @@ class RFormulaParserSuite extends SparkFunSuite { checkParse("y ~ .:x", "y", Seq("a:x", "b:x", "c:x"), schema) checkParse("a ~ .:x", "a", Seq("b:x", "c:x"), schema) checkParse("a ~ x:.", "a", Seq("x:b", "x:c"), schema) + checkParse("a ~ .:b:.:.:c:.", "a", Seq("b:c"), schema) } // Test data generated in R with terms.formula(y ~ .:., data = iris) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index b3422cbd74e9..b89457e8d22a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -123,4 +123,76 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { new NumericAttribute(Some("b"), Some(3)))) assert(attrs === expectedAttrs) } + + test("numeric interaction") { + val formula = new RFormula().setFormula("a ~ b:c") + val original = sqlContext.createDataFrame( + Seq((1, 2, 4), (2, 3, 4)) + ).toDF("a", "b", "c") + val model = formula.fit(original) + val result = model.transform(original) + val expected = sqlContext.createDataFrame( + Seq( + (1, 2, 4, Vectors.dense(8.0), 1.0), + (2, 3, 4, Vectors.dense(12.0), 2.0)) + ).toDF("id", "a", "b", "features", "label") + assert(result.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new 
AttributeGroup( + "features", + Array[Attribute](new NumericAttribute(Some("b:c"), Some(1)))) + assert(attrs === expectedAttrs) + } + + test("numeric:factor interaction") { + val formula = new RFormula().setFormula("id ~ a:b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val expected = sqlContext.createDataFrame( + Seq( + (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), + (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), + (3, "bar", 5, Vectors.dense(0.0, 5.0, 0.0), 3.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0)) + ).toDF("id", "a", "b", "features", "label") + assert(result.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new BinaryAttribute(Some("a_baz:b"), Some(1)), + new BinaryAttribute(Some("a_bar:b"), Some(2)), + new BinaryAttribute(Some("a_foo:b"), Some(3)))) + assert(attrs === expectedAttrs) + } + + test("factor:factor interaction") { + val formula = new RFormula().setFormula("id ~ a:b") + val original = sqlContext.createDataFrame( + Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) + ).toDF("id", "a", "b") + val model = formula.fit(original) + val result = model.transform(original) + val expected = sqlContext.createDataFrame( + Seq( + (1, "foo", "zq", Vectors.dense(0.0, 1.0, 0.0, 0.0), 1.0), + (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), + (3, "bar", "zz", Vectors.dense(0.0, 0.0, 1.0, 0.0), 3.0)) + ).toDF("id", "a", "b", "features", "label") + assert(result.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(result.schema("features")) + val expectedAttrs = new 
AttributeGroup( + "features", + Array[Attribute]( + new BinaryAttribute(Some("a_bar:b_zq"), Some(1)), + new BinaryAttribute(Some("a_foo:b_zq"), Some(2)), + new BinaryAttribute(Some("a_bar:b_zz"), Some(3)), + new BinaryAttribute(Some("a_foo:b_zz"), Some(4)))) + assert(attrs === expectedAttrs) + } } From 3ad5464566076570438001714582082b11981479 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 5 Aug 2015 23:08:34 -0700 Subject: [PATCH 09/27] docs --- .../apache/spark/ml/feature/Interaction.scala | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index bb452c874f16..b51dad2eda74 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -30,6 +30,17 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +/** + * :: Experimental :: + * Implements the transforms required for R-style feature interactions. In summary, once fitted to + * a dataset, this transformer jointly one-hot encodes all input factor columns, and then scales + * the encoded factors by all numeric input columns. If only numeric are columns are specified, + * the output column will simply be a vector containing their product. The last category will be + * preserved during one-hot encoding except when there is no interaction. + * + * See https://stat.ethz.ch/R-manual/R-devel/library/base/html/formula.html for more + * information about factor interactions in R formulae. 
+ */ @Experimental class Interaction(override val uid: String) extends Estimator[PipelineModel] with HasInputCols with HasOutputCol { @@ -115,6 +126,10 @@ class Interaction(override val uid: String) extends Estimator[PipelineModel] } } +/** + * This helper class combines the output of multiple string-indexed columns to simulate + * the joint indexing of tuples containing all the column values. + */ private class IndexCombiner( inputCols: Array[String], attrNames: Array[String], outputCol: String) extends Transformer { @@ -165,6 +180,10 @@ private class IndexCombiner( } } +/** + * This helper class scales the input vector column by the product of the input numeric columns. + * If no vector column is specified, the output is just the product of the numeric columns. + */ private class NumericInteraction( inputCols: Array[String], vectorCol: Option[String], outputCol: String) extends Transformer { From 5f7cb9b505e043039898df9860b999d5382c4ae0 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 5 Aug 2015 23:15:14 -0700 Subject: [PATCH 10/27] Wed Aug 5 23:15:14 PDT 2015 --- .../scala/org/apache/spark/ml/feature/RFormula.scala | 2 +- .../{Interaction.scala => RInteraction.scala} | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) rename mllib/src/main/scala/org/apache/spark/ml/feature/{Interaction.scala => RInteraction.scala} (94%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 855486f54a7a..4efbfdd2f17c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -93,7 +93,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R tmp } takenNames.add(outputCol) - encoderStages += new Interaction() + encoderStages += new RInteraction() .setInputCols(terms.toArray) .setOutputCol(outputCol) tempColumns += outputCol diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala similarity index 94% rename from mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala rename to mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala index b51dad2eda74..83cfbfa077e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala @@ -33,16 +33,16 @@ import org.apache.spark.sql.types._ /** * :: Experimental :: * Implements the transforms required for R-style feature interactions. In summary, once fitted to - * a dataset, this transformer jointly one-hot encodes all input factor columns, and then scales - * the encoded factors by all numeric input columns. If only numeric are columns are specified, - * the output column will simply be a vector containing their product. The last category will be - * preserved during one-hot encoding except when there is no interaction. + * a dataset, this transformer jointly one-hot encodes all factor input columns, then scales + * the encoded vector by all numeric input columns. If only numeric columns are specified, the + * output column will be a one-length vector containing their product. During one-hot encoding, + * the last category will be preserved unless the interaction is trivial. * * See https://stat.ethz.ch/R-manual/R-devel/library/base/html/formula.html for more * information about factor interactions in R formulae. 
*/ @Experimental -class Interaction(override val uid: String) extends Estimator[PipelineModel] +class RInteraction(override val uid: String) extends Estimator[PipelineModel] with HasInputCols with HasOutputCol { def this() = this(Identifiable.randomUID("interaction")) @@ -116,7 +116,7 @@ class Interaction(override val uid: String) extends Estimator[PipelineModel] } } - override def copy(extra: ParamMap): Interaction = defaultCopy(extra) + override def copy(extra: ParamMap): RInteraction = defaultCopy(extra) private def checkParams(): Unit = { require(isDefined(inputCols), "Input cols must be defined first.") From 4c11a773e74e677f237f23949eeb9dffa8bc43f2 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 5 Aug 2015 23:45:35 -0700 Subject: [PATCH 11/27] small nits --- R/pkg/R/mllib.R | 2 +- R/pkg/inst/tests/test_mllib.R | 10 +++- .../apache/spark/ml/feature/RFormula.scala | 10 ++-- .../spark/ml/feature/RFormulaParser.scala | 50 ++++++++++--------- .../spark/ml/feature/RInteraction.scala | 5 +- .../ml/feature/RFormulaParserSuite.scala | 2 +- .../spark/ml/feature/RFormulaSuite.scala | 18 +++---- python/pyspark/ml/feature.py | 2 +- 8 files changed, 57 insertions(+), 42 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index b524d1fd8749..50bf7669bfff 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj")) #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. #' #' @param formula A symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '+', '-', and '.'. +#' operators are supported, including '~', '.', ':', '+', and '-'. #' @param data DataFrame for training #' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. 
#' @param lambda Regularization parameter diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index f272de78ad4a..032f8ec68b9d 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -49,6 +49,14 @@ test_that("dot minus and intercept vs native glm", { expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) +test_that("feature interaction vs native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) +}) + test_that("summary coefficients match with native glm", { training <- createDataFrame(sqlContext, iris) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) @@ -57,5 +65,5 @@ test_that("summary coefficients match with native glm", { expect_true(all(abs(rCoefs - coefs) < 1e-6)) expect_true(all( as.character(stats$features) == - c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica"))) + c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) }) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 4efbfdd2f17c..2dd18c350fa4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -47,8 +47,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { /** * :: Experimental :: * Implements the transforms required for fitting a dataset against an R model formula. Currently - * we support a limited subset of the R operators, including '~' and '+'. 
Also see the R formula - * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + * we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. Also see + * the R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html */ @Experimental class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase { @@ -86,6 +86,8 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R val takenNames = mutable.Set(dataset.columns: _*) def encodeInteraction(terms: Seq[String]): String = { val outputCol = { + // TODO(ekl) this column naming should be unnecessary since we generate the right attr + // names in RInteraction, but the name is lost somewhere before VectorAssembler. var tmp = terms.mkString(":") while (takenNames.contains(tmp)) { tmp += "_" @@ -99,7 +101,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R tempColumns += outputCol outputCol } - val encodedCols = resolvedFormula.terms.map { + val encodedTerms = resolvedFormula.terms.map { case terms @ Seq(value) => dataset.schema(value) match { case column if column.dataType == StringType => @@ -111,7 +113,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R encodeInteraction(terms) } encoderStages += new VectorAssembler(uid) - .setInputCols(encodedCols.toArray) + .setInputCols(encodedTerms.toArray) .setOutputCol($(featuresCol)) encoderStages += new ColumnPruner(tempColumns.toSet) val pipelineModel = new Pipeline(uid).setStages(encoderStages.toArray).fit(dataset) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index 2238d7ca465d..c2038d5859a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -32,7 
+32,7 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { * of the special '.' term. Duplicate terms will be removed during resolution. */ def resolve(schema: StructType): ResolvedRFormula = { - lazy val dotTerms = expandDot(schema) + val dotTerms = expandDot(schema) var includedTerms = Seq[Seq[String]]() terms.foreach { case term: ColumnRef => @@ -80,29 +80,30 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { private def expandInteraction( schema: StructType, terms: Seq[InteractionComponent]): Seq[Seq[String]] = { if (terms.isEmpty) { - Seq(Nil) - } else { - val rest = expandInteraction(schema, terms.tail) - val validInteractions = (terms.head match { - case Dot => - expandDot(schema).filter(_ != label.value).flatMap { t => - rest.map { r => - Seq(t) ++ r - } - } - case ColumnRef(value) => - rest.map(Seq(value) ++ _) - }).map(_.distinct) - // Deduplicates feature interactions, for example, a:b is the same as b:a. - var seen = mutable.Set[Set[String]]() - validInteractions.flatMap { - case t if seen.contains(t.toSet) => - None - case t => - seen += t.toSet - Some(t) - }.sortBy(_.length) + return Seq(Nil) } + + val rest = expandInteraction(schema, terms.tail) + val validInteractions = (terms.head match { + case Dot => + expandDot(schema).filter(_ != label.value).flatMap { t => + rest.map { r => + Seq(t) ++ r + } + } + case ColumnRef(value) => + rest.map(Seq(value) ++ _) + }).map(_.distinct) + + // Deduplicates feature interactions, for example, a:b is the same as b:a. + var seen = mutable.Set[Set[String]]() + validInteractions.flatMap { + case t if seen.contains(t.toSet) => + None + case t => + seen += t.toSet + Some(t) + }.sortBy(_.length) } // the dot operator excludes complex column types @@ -116,6 +117,9 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { /** * Represents a fully evaluated and simplified R formula. 
+ * @param label the column name of the R formula label (response variable). + * @param terms the simplified terms of the R formula. Interactions terms are represented as Seqs + * of column names; non-interaction terms as length 1 Seqs. */ private[ml] case class ResolvedRFormula(label: String, terms: Seq[Seq[String]]) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala index 83cfbfa077e4..20bd63b5dd41 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.types._ * See https://stat.ethz.ch/R-manual/R-devel/library/base/html/formula.html for more * information about factor interactions in R formulae. */ +// TODO(ekl) it might be nice to have standalone tests for RInteraction. @Experimental class RInteraction(override val uid: String) extends Estimator[PipelineModel] with HasInputCols with HasOutputCol { @@ -127,8 +128,8 @@ class RInteraction(override val uid: String) extends Estimator[PipelineModel] } /** - * This helper class combines the output of multiple string-indexed columns to simulate - * the joint indexing of tuples containing all the column values. + * This helper class computes the joint index of multiple string-indexed columns such that the + * combined index covers the cartesian product of column values. 
*/ private class IndexCombiner( inputCols: Array[String], attrNames: Array[String], outputCol: String) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index 6d36509773fb..df8b60342eda 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -25,7 +25,7 @@ class RFormulaParserSuite extends SparkFunSuite { formula: String, label: String, terms: Seq[String], - schema: StructType = null) { + schema: StructType = new StructType) { val resolved = RFormulaParser.parse(formula).resolve(schema) assert(resolved.label == label) val simpleTerms = terms.map { t => diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index b89457e8d22a..b21b4f209294 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -125,26 +125,26 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { } test("numeric interaction") { - val formula = new RFormula().setFormula("a ~ b:c") + val formula = new RFormula().setFormula("a ~ b:c:d") val original = sqlContext.createDataFrame( - Seq((1, 2, 4), (2, 3, 4)) - ).toDF("a", "b", "c") + Seq((1, 2, 4, 2), (2, 3, 4, 1)) + ).toDF("a", "b", "c", "d") val model = formula.fit(original) val result = model.transform(original) val expected = sqlContext.createDataFrame( Seq( - (1, 2, 4, Vectors.dense(8.0), 1.0), - (2, 3, 4, Vectors.dense(12.0), 2.0)) - ).toDF("id", "a", "b", "features", "label") + (1, 2, 4, 2, Vectors.dense(16.0), 1.0), + (2, 3, 4, 1, Vectors.dense(12.0), 2.0)) + ).toDF("a", "b", "c", "d", "features", "label") assert(result.collect() === expected.collect()) val attrs = 
AttributeGroup.fromStructField(result.schema("features")) val expectedAttrs = new AttributeGroup( "features", - Array[Attribute](new NumericAttribute(Some("b:c"), Some(1)))) + Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1)))) assert(attrs === expectedAttrs) } - test("numeric:factor interaction") { + test("factor numeric interaction") { val formula = new RFormula().setFormula("id ~ a:b") val original = sqlContext.createDataFrame( Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) @@ -171,7 +171,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { assert(attrs === expectedAttrs) } - test("factor:factor interaction") { + test("factor factor interaction") { val formula = new RFormula().setFormula("id ~ a:b") val original = sqlContext.createDataFrame( Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 3f04c41ac5ab..0576ccf56f7b 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1117,7 +1117,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): Implements the transforms required for fitting a dataset against an R model formula. Currently we support a limited subset of the R - operators, including '~', '+', '-', and '.'. Also see the R formula + operators, including '~', '.', ':', '+', and '-'. 
Also see the R formula docs: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html From e5099f695a9533e402ca15434330e2f3678d30f3 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 6 Aug 2015 17:21:21 -0700 Subject: [PATCH 12/27] tests and attribute refactoring --- .../spark/ml/attribute/attributes.scala | 16 +- .../apache/spark/ml/feature/RFormula.scala | 12 +- .../spark/ml/feature/RInteraction.scala | 5 +- .../spark/ml/feature/VectorAssembler.scala | 4 +- .../spark/ml/feature/RInteractionSuite.scala | 140 ++++++++++++++++++ 5 files changed, 159 insertions(+), 18 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index e479f169021d..a7c10333c0d5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -124,18 +124,28 @@ private[attribute] trait AttributeFactory { private[attribute] def fromMetadata(metadata: Metadata): Attribute /** - * Creates an [[Attribute]] from a [[StructField]] instance. + * Creates an [[Attribute]] from a [[StructField]] instance, optionally preserving name. */ - def fromStructField(field: StructField): Attribute = { + private[ml] def decodeStructField(field: StructField, preserveName: Boolean): Attribute = { require(field.dataType.isInstanceOf[NumericType]) val metadata = field.metadata val mlAttr = AttributeKeys.ML_ATTR if (metadata.contains(mlAttr)) { - fromMetadata(metadata.getMetadata(mlAttr)).withName(field.name) + val attr = fromMetadata(metadata.getMetadata(mlAttr)) + if (preserveName) { + attr + } else { + attr.withName(field.name) + } } else { UnresolvedAttribute } } + + /** + * Creates an [[Attribute]] from a [[StructField]] instance.
+ */ + def fromStructField(field: StructField): Attribute = decodeStructField(field, false) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 2dd18c350fa4..9152ed039d87 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -83,18 +83,8 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R val resolvedFormula = parsedFormula.resolve(dataset.schema) val encoderStages = ArrayBuffer[PipelineStage]() val tempColumns = ArrayBuffer[String]() - val takenNames = mutable.Set(dataset.columns: _*) def encodeInteraction(terms: Seq[String]): String = { - val outputCol = { - // TODO(ekl) this column naming should be unnecessary since we generate the right attr - // names in RInteraction, but the name is lost somewhere before VectorAssembler. - var tmp = terms.mkString(":") - while (takenNames.contains(tmp)) { - tmp += "_" - } - tmp - } - takenNames.add(outputCol) + val outputCol = "interaction_" + uid + "_" + terms.mkString(":") encoderStages += new RInteraction() .setInputCols(terms.toArray) .setOutputCol(outputCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala index 20bd63b5dd41..a801b27f63b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala @@ -41,7 +41,6 @@ import org.apache.spark.sql.types._ * See https://stat.ethz.ch/R-manual/R-devel/library/base/html/formula.html for more * information about factor interactions in R formulae. */ -// TODO(ekl) it might be nice to have standalone tests for RInteraction. 
@Experimental class RInteraction(override val uid: String) extends Estimator[PipelineModel] with HasInputCols with HasOutputCol { @@ -120,8 +119,8 @@ class RInteraction(override val uid: String) extends Estimator[PipelineModel] override def copy(extra: ParamMap): RInteraction = defaultCopy(extra) private def checkParams(): Unit = { - require(isDefined(inputCols), "Input cols must be defined first.") - require(isDefined(outputCol), "Output col must be defined first.") + require(get(inputCols).isDefined, "Input cols must be defined first.") + require(get(outputCol).isDefined, "Output col must be defined first.") require($(inputCols).length > 0, "Input cols must have non-zero length.") require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index aa1e95907986..b27edfcab696 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -56,10 +56,12 @@ class VectorAssembler(override val uid: String) val index = schema.fieldIndex(c) field.dataType match { case DoubleType => - val attr = Attribute.fromStructField(field) + val attr = Attribute.decodeStructField(field, preserveName = true) // If the input column doesn't have ML attribute, assume numeric. 
if (attr == UnresolvedAttribute) { Some(NumericAttribute.defaultAttr.withName(c)) + } else if (attr.name.isDefined) { + Some(attr) } else { Some(attr.withName(c)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala new file mode 100644 index 000000000000..bf1eecef033a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class RInteractionSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new RInteraction()) + } + + test("parameter validation") { + val data = sqlContext.createDataFrame( + Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) + ).toDF("id", "a", "b") + def check(inputCols: Array[String], outputCol: String, expectOk: Boolean): Unit = { + val interaction = new RInteraction() + if (inputCols != null) { + interaction.setInputCols(inputCols) + } + if (outputCol != null) { + interaction.setOutputCol(outputCol) + } + if (expectOk) { + interaction.transformSchema(data.schema) + interaction.fit(data).transform(data).collect() + } else { + intercept[IllegalArgumentException] { + interaction.fit(data) + } + intercept[IllegalArgumentException] { + interaction.transformSchema(data.schema) + } + } + } + check(Array("a", "b"), "test", true) + check(Array("id"), "test", true) + check(Array("b"), "test", true) + check(Array("b"), "test", true) + check(Array(), "test", false) + check(Array("a", "b", "b"), "id", false) + check(Array("a", "b"), null, false) + check(null, "test", false) + } + + test("numeric interaction") { + val interaction = new RInteraction() + .setInputCols(Array("b", "c", "d")) + .setOutputCol("test") + val original = sqlContext.createDataFrame( + Seq((1, 2, 4, 2), (2, 3, 4, 1)) + ).toDF("a", "b", "c", "d") + val model = interaction.fit(original) + val result = model.transform(original) + val expected = sqlContext.createDataFrame( + Seq( + (1, 2, 4, 2, 16.0), + (2, 3, 4, 1, 12.0)) + ).toDF("a", "b", "c", "d", "test") + assert(result.collect() === expected.collect()) + val attr = Attribute.decodeStructField(result.schema("test"), 
preserveName = true) + val expectedAttr = new NumericAttribute(Some("b:c:d"), None) + assert(attr === expectedAttr) + } + + test("factor interaction") { + val interaction = new RInteraction() + .setInputCols(Array("a", "b")) + .setOutputCol("test") + val original = sqlContext.createDataFrame( + Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) + ).toDF("id", "a", "b") + val model = interaction.fit(original) + val result = model.transform(original) + val expected = sqlContext.createDataFrame( + Seq( + (1, "foo", "zq", Vectors.dense(0.0, 1.0, 0.0, 0.0)), + (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0)), + (3, "bar", "zz", Vectors.dense(0.0, 0.0, 1.0, 0.0))) + ).toDF("id", "a", "b", "test") + assert(result.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(result.schema("test")) + val expectedAttrs = new AttributeGroup( + "test", + Array[Attribute]( + new BinaryAttribute(Some("a_bar:b_zq"), Some(1)), + new BinaryAttribute(Some("a_foo:b_zq"), Some(2)), + new BinaryAttribute(Some("a_bar:b_zz"), Some(3)), + new BinaryAttribute(Some("a_foo:b_zz"), Some(4)))) + assert(attrs === expectedAttrs) + } + + test("factor numeric interaction") { + val interaction = new RInteraction() + .setInputCols(Array("a", "b")) + .setOutputCol("test") + val original = sqlContext.createDataFrame( + Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) + ).toDF("id", "a", "b") + val model = interaction.fit(original) + val result = model.transform(original) + val expected = sqlContext.createDataFrame( + Seq( + (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0)), + (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0)), + (3, "bar", 5, Vectors.dense(0.0, 5.0, 0.0)), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0)), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0)), + (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0))) + ).toDF("id", "a", "b", "test") + assert(result.collect() === expected.collect()) + val attrs = 
AttributeGroup.fromStructField(result.schema("test")) + val expectedAttrs = new AttributeGroup( + "test", + Array[Attribute]( + new BinaryAttribute(Some("a_baz:b"), Some(1)), + new BinaryAttribute(Some("a_bar:b"), Some(2)), + new BinaryAttribute(Some("a_foo:b"), Some(3)))) + assert(attrs === expectedAttrs) + } +} From 26b692522a2177338f64dba5422e113f8a992cf5 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 6 Aug 2015 17:25:07 -0700 Subject: [PATCH 13/27] Revert user-facing R changes --- R/pkg/R/mllib.R | 2 +- R/pkg/inst/tests/test_mllib.R | 10 +-- .../apache/spark/ml/feature/RFormula.scala | 48 +++++----- .../spark/ml/feature/RFormulaParser.scala | 84 +++-------------- .../spark/ml/feature/RInteraction.scala | 8 +- .../spark/ml/feature/VectorAssembler.scala | 15 ++-- .../ml/feature/RFormulaParserSuite.scala | 89 +------------------ .../spark/ml/feature/RFormulaSuite.scala | 76 +--------------- python/pyspark/ml/feature.py | 2 +- 9 files changed, 63 insertions(+), 271 deletions(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 50bf7669bfff..b524d1fd8749 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -27,7 +27,7 @@ setClass("PipelineModel", representation(model = "jobj")) #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. #' #' @param formula A symbolic description of the model to be fitted. Currently only a few formula -#' operators are supported, including '~', '.', ':', '+', and '-'. +#' operators are supported, including '~', '+', '-', and '.'. #' @param data DataFrame for training #' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. 
#' @param lambda Regularization parameter diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R index 032f8ec68b9d..f272de78ad4a 100644 --- a/R/pkg/inst/tests/test_mllib.R +++ b/R/pkg/inst/tests/test_mllib.R @@ -49,14 +49,6 @@ test_that("dot minus and intercept vs native glm", { expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) -test_that("feature interaction vs native glm", { - training <- createDataFrame(sqlContext, iris) - model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) - test_that("summary coefficients match with native glm", { training <- createDataFrame(sqlContext, iris) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) @@ -65,5 +57,5 @@ test_that("summary coefficients match with native glm", { expect_true(all(abs(rCoefs - coefs) < 1e-6)) expect_true(all( as.character(stats$features) == - c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) + c("(Intercept)", "Sepal_Length", "Species__versicolor", "Species__virginica"))) }) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 9152ed039d87..d5360c9217ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -47,8 +47,8 @@ private[feature] trait RFormulaBase extends HasFeaturesCol with HasLabelCol { /** * :: Experimental :: * Implements the transforms required for fitting a dataset against an R model formula. Currently - * we support a limited subset of the R operators, including '~', '.', ':', '+', and '-'. 
Also see - * the R formula docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html + * we support a limited subset of the R operators, including '~' and '+'. Also see the R formula + * docs here: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html */ @Experimental class RFormula(override val uid: String) extends Estimator[RFormulaModel] with RFormulaBase { @@ -81,26 +81,32 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R require(isDefined(formula), "Formula must be defined first.") val parsedFormula = RFormulaParser.parse($(formula)) val resolvedFormula = parsedFormula.resolve(dataset.schema) + // StringType terms and terms representing interactions need to be encoded before assembly. + // TODO(ekl) add support for feature interactions val encoderStages = ArrayBuffer[PipelineStage]() val tempColumns = ArrayBuffer[String]() - def encodeInteraction(terms: Seq[String]): String = { - val outputCol = "interaction_" + uid + "_" + terms.mkString(":") - encoderStages += new RInteraction() - .setInputCols(terms.toArray) - .setOutputCol(outputCol) - tempColumns += outputCol - outputCol - } - val encodedTerms = resolvedFormula.terms.map { - case terms @ Seq(value) => - dataset.schema(value) match { - case column if column.dataType == StringType => - encodeInteraction(terms) - case _ => - value - } - case terms => - encodeInteraction(terms) + val takenNames = mutable.Set(dataset.columns: _*) + val encodedTerms = resolvedFormula.terms.map { term => + dataset.schema(term) match { + case column if column.dataType == StringType => + val indexCol = term + "_idx_" + uid + val encodedCol = { + var tmp = term + while (takenNames.contains(tmp)) { + tmp += "_" + } + tmp + } + takenNames.add(indexCol) + takenNames.add(encodedCol) + encoderStages += new StringIndexer().setInputCol(term).setOutputCol(indexCol) + encoderStages += new OneHotEncoder().setInputCol(indexCol).setOutputCol(encodedCol) + tempColumns += 
indexCol + tempColumns += encodedCol + encodedCol + case _ => + term + } } encoderStages += new VectorAssembler(uid) .setInputCols(encodedTerms.toArray) @@ -197,7 +203,7 @@ class RFormulaModel private[feature]( * Utility transformer for removing temporary columns from a DataFrame. * TODO(ekl) make this a public transformer */ -private[feature] class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { +private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { override val uid = Identifiable.randomUID("columnPruner") override def transform(dataset: DataFrame): DataFrame = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala index c2038d5859a9..1ca3b92a7d92 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormulaParser.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.feature -import scala.collection.mutable import scala.util.parsing.combinator.RegexParsers import org.apache.spark.mllib.linalg.VectorUDT @@ -32,28 +31,20 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { * of the special '.' term. Duplicate terms will be removed during resolution. 
*/ def resolve(schema: StructType): ResolvedRFormula = { - val dotTerms = expandDot(schema) - var includedTerms = Seq[Seq[String]]() + var includedTerms = Seq[String]() terms.foreach { - case term: ColumnRef => - includedTerms :+= Seq(term.value) - case ColumnInteraction(terms) => - includedTerms ++= expandInteraction(schema, terms) case Dot => - includedTerms ++= dotTerms.map(Seq(_)) + includedTerms ++= simpleTypes(schema).filter(_ != label.value) + case ColumnRef(value) => + includedTerms :+= value case Deletion(term: Term) => term match { - case inner: ColumnRef => - includedTerms = includedTerms.filter(_ != Seq(inner.value)) - case ColumnInteraction(terms) => - val fromInteraction = expandInteraction(schema, terms).map(_.toSet) - includedTerms = includedTerms.filter(t => !fromInteraction.contains(t.toSet)) + case ColumnRef(value) => + includedTerms = includedTerms.filter(_ != value) case Dot => // e.g. "- .", which removes all first-order terms - includedTerms = includedTerms.filter { - case Seq(t) => !dotTerms.contains(t) - case _ => true - } + val fromSchema = simpleTypes(schema) + includedTerms = includedTerms.filter(fromSchema.contains(_)) case _: Deletion => assert(false, "Deletion terms cannot be nested") case _: Intercept => @@ -76,52 +67,19 @@ private[ml] case class ParsedRFormula(label: ColumnRef, terms: Seq[Term]) { intercept } - // expands the Dot operators in interaction terms - private def expandInteraction( - schema: StructType, terms: Seq[InteractionComponent]): Seq[Seq[String]] = { - if (terms.isEmpty) { - return Seq(Nil) - } - - val rest = expandInteraction(schema, terms.tail) - val validInteractions = (terms.head match { - case Dot => - expandDot(schema).filter(_ != label.value).flatMap { t => - rest.map { r => - Seq(t) ++ r - } - } - case ColumnRef(value) => - rest.map(Seq(value) ++ _) - }).map(_.distinct) - - // Deduplicates feature interactions, for example, a:b is the same as b:a. 
- var seen = mutable.Set[Set[String]]() - validInteractions.flatMap { - case t if seen.contains(t.toSet) => - None - case t => - seen += t.toSet - Some(t) - }.sortBy(_.length) - } - // the dot operator excludes complex column types - private def expandDot(schema: StructType): Seq[String] = { + private def simpleTypes(schema: StructType): Seq[String] = { schema.fields.filter(_.dataType match { case _: NumericType | StringType | BooleanType | _: VectorUDT => true case _ => false - }).map(_.name).filter(_ != label.value) + }).map(_.name) } } /** * Represents a fully evaluated and simplified R formula. - * @param label the column name of the R formula label (response variable). - * @param terms the simplified terms of the R formula. Interactions terms are represented as Seqs - * of column names; non-interaction terms as length 1 Seqs. */ -private[ml] case class ResolvedRFormula(label: String, terms: Seq[Seq[String]]) +private[ml] case class ResolvedRFormula(label: String, terms: Seq[String]) /** * R formula terms. See the R formula docs here for more information: @@ -129,17 +87,11 @@ private[ml] case class ResolvedRFormula(label: String, terms: Seq[Seq[String]]) */ private[ml] sealed trait Term -/** A term that may be part of an interaction, e.g. 'x' in 'x:y' */ -private[ml] sealed trait InteractionComponent extends Term - /* R formula reference to all available columns, e.g. "." in a formula */ -private[ml] case object Dot extends InteractionComponent +private[ml] case object Dot extends Term /* R formula reference to a column, e.g. "+ Species" in a formula */ -private[ml] case class ColumnRef(value: String) extends InteractionComponent - -/* R formula interaction of several columns, e.g. "Sepal_Length:Species" in a formula */ -private[ml] case class ColumnInteraction(terms: Seq[InteractionComponent]) extends Term +private[ml] case class ColumnRef(value: String) extends Term /* R formula intercept toggle, e.g. 
"+ 0" in a formula */ private[ml] case class Intercept(enabled: Boolean) extends Term @@ -157,15 +109,7 @@ private[ml] object RFormulaParser extends RegexParsers { def columnRef: Parser[ColumnRef] = "([a-zA-Z]|\\.[a-zA-Z_])[a-zA-Z0-9._]*".r ^^ { case a => ColumnRef(a) } - def dot: Parser[InteractionComponent] = "\\.".r ^^ { case _ => Dot } - - def interaction: Parser[List[InteractionComponent]] = repsep(columnRef | dot, ":") - - def term: Parser[Term] = intercept | - interaction ^^ { - case Seq(term) => term - case terms => ColumnInteraction(terms) - } + def term: Parser[Term] = intercept | columnRef | "\\.".r ^^ { case _ => Dot } def terms: Parser[List[Term]] = (term ~ rep("+" ~ term | "-" ~ term)) ^^ { case op ~ list => list.foldLeft(List(op)) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala index a801b27f63b7..35746b0ed31b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala @@ -127,8 +127,8 @@ class RInteraction(override val uid: String) extends Estimator[PipelineModel] } /** - * This helper class computes the joint index of multiple string-indexed columns such that the - * combined index covers the cartesian product of column values. + * Computes the joint index of multiple string-indexed columns such that the combined index + * covers the cartesian product of column values. */ private class IndexCombiner( inputCols: Array[String], attrNames: Array[String], outputCol: String) @@ -181,8 +181,8 @@ private class IndexCombiner( } /** - * This helper class scales the input vector column by the product of the input numeric columns. - * If no vector column is specified, the output is just the product of the numeric columns. + * Scales the input vector column by the product of the input numeric columns. 
If no vector column + * is specified, the output is just the product of the numeric columns. */ private class NumericInteraction( inputCols: Array[String], vectorCol: Option[String], outputCol: String) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index b27edfcab696..086917fa680f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -56,12 +56,10 @@ class VectorAssembler(override val uid: String) val index = schema.fieldIndex(c) field.dataType match { case DoubleType => - val attr = Attribute.decodeStructField(field, preserveName = true) + val attr = Attribute.fromStructField(field) // If the input column doesn't have ML attribute, assume numeric. if (attr == UnresolvedAttribute) { Some(NumericAttribute.defaultAttr.withName(c)) - } else if (attr.name.isDefined) { - Some(attr) } else { Some(attr.withName(c)) } @@ -71,8 +69,15 @@ class VectorAssembler(override val uid: String) case _: VectorUDT => val group = AttributeGroup.fromStructField(field) if (group.attributes.isDefined) { - // If attributes are defined, copy them. - group.attributes.get + // If attributes are defined, copy them with updated names. + group.attributes.get.map { attr => + if (attr.name.isDefined) { + // TODO: Define a rigorous naming scheme. + attr.withName(c + "_" + attr.name.get) + } else { + attr + } + } } else { // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes // from metadata, check the first row. 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala index df8b60342eda..436e66bab09b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaParserSuite.scala @@ -25,24 +25,16 @@ class RFormulaParserSuite extends SparkFunSuite { formula: String, label: String, terms: Seq[String], - schema: StructType = new StructType) { + schema: StructType = null) { val resolved = RFormulaParser.parse(formula).resolve(schema) assert(resolved.label == label) - val simpleTerms = terms.map { t => - if (t.contains(":")) { - t.split(":").toSeq - } else { - Seq(t) - } - } - assert(resolved.terms == simpleTerms) + assert(resolved.terms == terms) } test("parse simple formulas") { checkParse("y ~ x", "y", Seq("x")) checkParse("y ~ x + x", "y", Seq("x")) - checkParse("y~x+z", "y", Seq("x", "z")) - checkParse("y ~ ._fo..o ", "y", Seq("._fo..o")) + checkParse("y ~ ._foo ", "y", Seq("._foo")) checkParse("resp ~ A_VAR + B + c123", "resp", Seq("A_VAR", "B", "c123")) } @@ -87,79 +79,4 @@ class RFormulaParserSuite extends SparkFunSuite { assert(!RFormulaParser.parse("a ~ b - 1").hasIntercept) assert(!RFormulaParser.parse("a ~ b + 1 - 1").hasIntercept) } - - test("parse interactions") { - checkParse("y ~ a:b", "y", Seq("a:b")) - checkParse("y ~ ._a:._x", "y", Seq("._a:._x")) - checkParse("y ~ foo:bar", "y", Seq("foo:bar")) - checkParse("y ~ a : b : c", "y", Seq("a:b:c")) - checkParse("y ~ q + a:b:c + b:c + c:d + z", "y", Seq("q", "a:b:c", "b:c", "c:d", "z")) - } - - test("parse basic interactions with dot") { - val schema = (new StructType) - .add("a", "int", true) - .add("b", "long", false) - .add("c", "string", true) - checkParse("y ~ .:x", "y", Seq("a:x", "b:x", "c:x"), schema) - checkParse("a ~ .:x", "a", Seq("b:x", "c:x"), schema) - checkParse("a ~ x:.", "a", Seq("x:b", "x:c"), schema) - 
checkParse("a ~ .:b:.:.:c:.", "a", Seq("b:c"), schema) - } - - // Test data generated in R with terms.formula(y ~ .:., data = iris) - test("parse all to all iris interactions") { - val schema = (new StructType) - .add("Sepal.Length", "double", true) - .add("Sepal.Width", "double", true) - .add("Petal.Length", "double", true) - .add("Petal.Width", "double", true) - .add("Species", "string", true) - checkParse( - "y ~ .:.", - "y", - Seq( - "Sepal.Length", - "Sepal.Width", - "Petal.Length", - "Petal.Width", - "Species", - "Sepal.Length:Sepal.Width", - "Sepal.Length:Petal.Length", - "Sepal.Length:Petal.Width", - "Sepal.Length:Species", - "Sepal.Width:Petal.Length", - "Sepal.Width:Petal.Width", - "Sepal.Width:Species", - "Petal.Length:Petal.Width", - "Petal.Length:Species", - "Petal.Width:Species"), - schema) - } - - // Test data generated in R with terms.formula(y ~ .:. - Species:., data = iris) - test("parse interaction negation with iris") { - val schema = (new StructType) - .add("Sepal.Length", "double", true) - .add("Sepal.Width", "double", true) - .add("Petal.Length", "double", true) - .add("Petal.Width", "double", true) - .add("Species", "string", true) - checkParse("y ~ .:. - .:.", "y", Nil, schema) - checkParse( - "y ~ .:. 
- Species:.", - "y", - Seq( - "Sepal.Length", - "Sepal.Width", - "Petal.Length", - "Petal.Width", - "Sepal.Length:Sepal.Width", - "Sepal.Length:Petal.Length", - "Sepal.Length:Petal.Width", - "Sepal.Width:Petal.Length", - "Sepal.Width:Petal.Width", - "Petal.Length:Petal.Width"), - schema) - } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index b21b4f209294..6aed3243afce 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -118,81 +118,9 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { val expectedAttrs = new AttributeGroup( "features", Array( - new BinaryAttribute(Some("a_bar"), Some(1)), - new BinaryAttribute(Some("a_foo"), Some(2)), + new BinaryAttribute(Some("a__bar"), Some(1)), + new BinaryAttribute(Some("a__foo"), Some(2)), new NumericAttribute(Some("b"), Some(3)))) assert(attrs === expectedAttrs) } - - test("numeric interaction") { - val formula = new RFormula().setFormula("a ~ b:c:d") - val original = sqlContext.createDataFrame( - Seq((1, 2, 4, 2), (2, 3, 4, 1)) - ).toDF("a", "b", "c", "d") - val model = formula.fit(original) - val result = model.transform(original) - val expected = sqlContext.createDataFrame( - Seq( - (1, 2, 4, 2, Vectors.dense(16.0), 1.0), - (2, 3, 4, 1, Vectors.dense(12.0), 2.0)) - ).toDF("a", "b", "c", "d", "features", "label") - assert(result.collect() === expected.collect()) - val attrs = AttributeGroup.fromStructField(result.schema("features")) - val expectedAttrs = new AttributeGroup( - "features", - Array[Attribute](new NumericAttribute(Some("b:c:d"), Some(1)))) - assert(attrs === expectedAttrs) - } - - test("factor numeric interaction") { - val formula = new RFormula().setFormula("id ~ a:b") - val original = sqlContext.createDataFrame( - Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 
5), (4, "baz", 5), (4, "baz", 5)) - ).toDF("id", "a", "b") - val model = formula.fit(original) - val result = model.transform(original) - val expected = sqlContext.createDataFrame( - Seq( - (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0), 1.0), - (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0), 2.0), - (3, "bar", 5, Vectors.dense(0.0, 5.0, 0.0), 3.0), - (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), - (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0), - (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0), 4.0)) - ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) - val attrs = AttributeGroup.fromStructField(result.schema("features")) - val expectedAttrs = new AttributeGroup( - "features", - Array[Attribute]( - new BinaryAttribute(Some("a_baz:b"), Some(1)), - new BinaryAttribute(Some("a_bar:b"), Some(2)), - new BinaryAttribute(Some("a_foo:b"), Some(3)))) - assert(attrs === expectedAttrs) - } - - test("factor factor interaction") { - val formula = new RFormula().setFormula("id ~ a:b") - val original = sqlContext.createDataFrame( - Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) - ).toDF("id", "a", "b") - val model = formula.fit(original) - val result = model.transform(original) - val expected = sqlContext.createDataFrame( - Seq( - (1, "foo", "zq", Vectors.dense(0.0, 1.0, 0.0, 0.0), 1.0), - (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0), 2.0), - (3, "bar", "zz", Vectors.dense(0.0, 0.0, 1.0, 0.0), 3.0)) - ).toDF("id", "a", "b", "features", "label") - assert(result.collect() === expected.collect()) - val attrs = AttributeGroup.fromStructField(result.schema("features")) - val expectedAttrs = new AttributeGroup( - "features", - Array[Attribute]( - new BinaryAttribute(Some("a_bar:b_zq"), Some(1)), - new BinaryAttribute(Some("a_foo:b_zq"), Some(2)), - new BinaryAttribute(Some("a_bar:b_zz"), Some(3)), - new BinaryAttribute(Some("a_foo:b_zz"), Some(4)))) - assert(attrs === expectedAttrs) - } } diff --git a/python/pyspark/ml/feature.py 
b/python/pyspark/ml/feature.py index 2c45034471b9..cb4dfa21298c 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1157,7 +1157,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol): Implements the transforms required for fitting a dataset against an R model formula. Currently we support a limited subset of the R - operators, including '~', '.', ':', '+', and '-'. Also see the R formula + operators, including '~', '+', '-', and '.'. Also see the R formula docs: http://stat.ethz.ch/R-manual/R-patched/library/stats/html/formula.html From 8396677674022c9b29b54c7519c5fba71eb3aa03 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 14 Sep 2015 19:17:23 -0700 Subject: [PATCH 14/27] Mon Sep 14 19:17:23 PDT 2015 --- .../spark/ml/feature/RInteraction.scala | 331 ++++++++++++++++++ 1 file changed, 331 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala new file mode 100644 index 000000000000..7bc14693888c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.{ArrayBuffer, ArrayBuilder} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Implements the transforms required for R-style feature interactions. In summary, once fitted to + * a dataset, this transformer jointly one-hot encodes all factor input columns, then scales + * the encoded vector by all numeric input columns. If only numeric columns are specified, the + * output column will be a one-length vector containing their product. During one-hot encoding, + * the last category will be preserved unless the interaction is trivial. + * + * See https://stat.ethz.ch/R-manual/R-devel/library/base/html/formula.html for more + * information about factor interactions in R formulae. 
+ */ +@Experimental +class RInteraction(override val uid: String) extends Estimator[PipelineModel] + with HasInputCols with HasOutputCol { + + def this() = this(Identifiable.randomUID("interaction")) + + /** @group setParam */ + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: DataFrame): PipelineModel = { + checkParams() + val encoderStages = ArrayBuffer[PipelineStage]() + val tempColumns = ArrayBuffer[String]() + val (factorCols, nonFactorCols) = $(inputCols) + .partition(input => dataset.schema(input).dataType == StringType) + + val encodedFactors: Option[String] = + if (factorCols.length > 0) { + val indexedCols = factorCols.map { input => + val output = input + "_idx_" + uid + encoderStages += new StringIndexer() + .setInputCol(input) + .setOutputCol(output) + tempColumns += output + output + } + val combinedIndex = "combined_idx_" + uid + tempColumns += combinedIndex + val encodedCol = if (nonFactorCols.length > 0) { + "factors_" + uid + } else { + $(outputCol) + } + encoderStages += new IndexCombiner(indexedCols, factorCols, combinedIndex) + encoderStages += { + val encoder = new OneHotEncoder() + .setInputCol(combinedIndex) + .setOutputCol(encodedCol) + if ($(inputCols).length > 1) { + encoder.setDropLast(false) // R includes all columns for interactions. 
+ } + encoder + } + Some(encodedCol) + } else { + None + } + + if (nonFactorCols.length > 0) { + encoderStages += new NumericInteraction(nonFactorCols, encodedFactors, $(outputCol)) + if (encodedFactors.isDefined) { + tempColumns += encodedFactors.get + } + } + + encoderStages += new ColumnPruner(tempColumns.toSet) + new Pipeline(uid) + .setStages(encoderStages.toArray) + .fit(dataset) + .setParent(this) + } + + // optimistic schema; does not contain any ML attributes + override def transformSchema(schema: StructType): StructType = { + checkParams() + if ($(inputCols).exists(col => schema(col).dataType == StringType)) { + StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, true)) + } else { + StructType(schema.fields :+ StructField($(outputCol), DoubleType, true)) + } + } + + override def copy(extra: ParamMap): RInteraction = defaultCopy(extra) + + private def checkParams(): Unit = { + require(get(inputCols).isDefined, "Input cols must be defined first.") + require(get(outputCol).isDefined, "Output col must be defined first.") + require($(inputCols).length > 0, "Input cols must have non-zero length.") + require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") + } +} + +class Interaction(override val uid: String) extends Transformer + with HasInputCols with HasOutputCol { + + def this() = this(Identifiable.randomUID("interaction")) + + /** @group setParam */ + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transformSchema(schema: StructType): StructType = { + checkParams() + if ($(inputCols).exists(col => schema(col).dataType == StringType)) { + StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, true)) + } else { + StructType(schema.fields :+ StructField($(outputCol), DoubleType, true)) + } + } + + override def transform(dataset: DataFrame): DataFrame = { + 
val isCategorical: Array[Boolean] = $(inputCols).map { col => + dataset.schema(col).dataType match { + case DoubleType => + UnaryIterator() + case _: VectorUDT => + OneHotIterator() + } + } + inputCols.foreach { + } + dataset + dataset.select( + col("*"), + encode(array(inputCols).cast(DoubleType)).as(outputColName, metadata)) + } + + override def copy(extra: ParamMap): RInteraction = defaultCopy(extra) + + private def checkParams(): Unit = { + require(get(inputCols).isDefined, "Input cols must be defined first.") + require(get(outputCol).isDefined, "Output col must be defined first.") + require($(inputCols).length > 0, "Input cols must have non-zero length.") + require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") + } +} + +private object Interaction { + private[feature] def interact(vv: Any*): Vector = { + val indices = ArrayBuilder.make[Int] + val values = ArrayBuilder.make[Double] + var cur = 0 + vv.foreach { + for (item in vector) { + for (item in vector) { + } + } + case v: Double => + if (v != 0.0) { + indices += cur + values += v + } + cur += 1 + case vec: Vector => + vec.foreachActive { case (i, v) => + if (v != 0.0) { + indices += cur + i + values += v + } + } + cur += vec.size + case null => + // TODO: output Double.NaN? + throw new SparkException("Values to assemble cannot be null.") + case o => + throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") + } + Vectors.sparse(cur, indices.result(), values.result()).compressed + } +} + +/** + * Computes the joint index of multiple string-indexed columns such that the combined index + * covers the cartesian product of column values. 
+ */ +private class IndexCombiner( + inputCols: Array[String], attrNames: Array[String], outputCol: String) + extends Transformer { + + override val uid = Identifiable.randomUID("indexCombiner") + + override def transform(dataset: DataFrame): DataFrame = { + val inputMetadata = inputCols.map(col => + Attribute.fromStructField(dataset.schema(col)).asInstanceOf[NominalAttribute]) + val cardinalities = inputMetadata.map(_.values.get.length) + val combiner = udf { values: Seq[Double] => + var offset = 1 + var res = 0.0 + var i = 0 + while (i < values.length) { + res += values(i) * offset + offset *= cardinalities(i) + i += 1 + } + res + } + val metadata = NominalAttribute.defaultAttr + .withName(outputCol) + .withValues(generateAttrNames(inputMetadata, attrNames)) + .toMetadata() + dataset.select( + col("*"), + combiner(array(inputCols.map(dataset(_)): _*)).as(outputCol, metadata)) + } + + override def transformSchema(schema: StructType): StructType = { + StructType(schema.fields :+ StructField(outputCol, DoubleType, true)) + } + + override def copy(extra: ParamMap): IndexCombiner = defaultCopy(extra) + + private def generateAttrNames( + attrs: Array[NominalAttribute], names: Array[String]): Array[String] = { + val colName = names.head + val attrNames = attrs.head.values.get.map(colName + "_" + _) + if (attrs.length <= 1) { + attrNames + } else { + generateAttrNames(attrs.tail, names.tail).flatMap { rest => + attrNames.map(n => n + ":" + rest) + } + } + } +} + +/** + * Scales the input vector column by the product of the input numeric columns. If no vector column + * is specified, the output is just the product of the numeric columns. 
+ */ +private class NumericInteraction( + inputCols: Array[String], vectorCol: Option[String], outputCol: String) + extends Transformer { + + override val uid = Identifiable.randomUID("numericInteraction") + + override def transform(dataset: DataFrame): DataFrame = { + if (vectorCol.isDefined) { + val scale = udf { (vec: Vector, scalars: Seq[Double]) => + var x = 1.0 + var i = 0 + while (i < scalars.length) { + x *= scalars(i) + i += 1 + } + val indices = ArrayBuilder.make[Int] + val values = ArrayBuilder.make[Double] + vec.foreachActive { case (i, v) => + if (v != 0.0) { + indices += i + values += v * x + } + } + Vectors.sparse(vec.size, indices.result(), values.result()).compressed + } + val group = AttributeGroup.fromStructField(dataset.schema(vectorCol.get)) + val attrs = group.attributes.get.map { attr => + attr.withName(attr.name.get + ":" + inputCols.mkString(":")) + } + val metadata = new AttributeGroup(outputCol, attrs).toMetadata() + dataset.select( + col("*"), + scale( + col(vectorCol.get), + array(inputCols.map(dataset(_).cast(DoubleType)): _*)).as(outputCol, metadata)) + } else { + val multiply = udf { values: Seq[Double] => + var x = 1.0 + var i = 0 + while (i < values.length) { + x *= values(i) + i += 1 + } + x + } + val metadata = NumericAttribute.defaultAttr + .withName(inputCols.mkString(":")) + .toMetadata() + dataset.select( + col("*"), + multiply(array(inputCols.map(dataset(_).cast(DoubleType)): _*)).as(outputCol, metadata)) + } + } + + override def transformSchema(schema: StructType): StructType = { + if (vectorCol.isDefined) { + StructType(schema.fields :+ StructField(outputCol, new VectorUDT, true)) + } else { + StructType(schema.fields :+ StructField(outputCol, DoubleType, true)) + } + } + + override def copy(extra: ParamMap): NumericInteraction = defaultCopy(extra) +} From ebdc7427a7fcf9618eb21331ab1e544442c0f805 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 14 Sep 2015 20:14:02 -0700 Subject: [PATCH 15/27] Mon Sep 14 20:14:02 PDT 
2015 --- .../spark/ml/feature/RInteraction.scala | 70 ++++++++++++------- 1 file changed, 45 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala index 7bc14693888c..7712fd006aba 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala @@ -174,36 +174,56 @@ class Interaction(override val uid: String) extends Transformer } private object Interaction { - private[feature] def interact(vv: Any*): Vector = { - val indices = ArrayBuilder.make[Int] - val values = ArrayBuilder.make[Double] - var cur = 0 - vv.foreach { - for (item in vector) { - for (item in vector) { - } - } - case v: Double => - if (v != 0.0) { + def getIterator(v: Any) = v match { + case d: Double => + Vectors.dense(d) + case vec: Vector => + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + var cur = 0 + vec.foreachActive { case (i, v) => + // TODO(ekl) precompute cardinality from the ml attrs + if (CARDINALITY(i) > 0) { + indices += cur + v.toInt + values += 1.0 + cur += CARDINALITY(i) + } else { indices += cur values += v + cur += 1 } - cur += 1 - case vec: Vector => - vec.foreachActive { case (i, v) => - if (v != 0.0) { - indices += cur + i - values += v - } + } + Vectors.sparse(cur, indices.result(), values.result()) + case null => + throw new SparkException("Values to interact cannot be null.") + case o => + throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") + } + + def interact(vv: Any*): Vector = { + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + var size = 1 + indices += 1 + values += 1.0 + vv.foreach { v => + val prevIndices = indices.result() + val prevValues = values.result() + val prevSize = size + val currentVector = getIterator(v) + indices = ArrayBuilder.make[Int] + values = 
ArrayBuilder.make[Double] + size *= currentVector.size + currentVector.foreachActive { (i, a) => + var j = 0 + while (j < prevIndices.length) { + indices += prevIndices(j) + i * prevSize + values += prevValues(j) * a + j += 1 } - cur += vec.size - case null => - // TODO: output Double.NaN? - throw new SparkException("Values to assemble cannot be null.") - case o => - throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") + } } - Vectors.sparse(cur, indices.result(), values.result()).compressed + Vectors.sparse(size, indices.result(), values.result()).compressed } } From 183586a9bca3abade8f91aa9c12a3a6f4ae34b20 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 15 Sep 2015 13:36:00 -0700 Subject: [PATCH 16/27] Tue Sep 15 13:36:00 PDT 2015 --- .../spark/ml/feature/RInteraction.scala | 152 ++++++++++-------- .../spark/ml/feature/RInteractionSuite.scala | 41 +++++ 2 files changed, 128 insertions(+), 65 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala index 7712fd006aba..f35e78d1c659 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala @@ -19,14 +19,15 @@ package org.apache.spark.ml.feature import scala.collection.mutable.{ArrayBuffer, ArrayBuilder} +import org.apache.spark.SparkException import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.sql.DataFrame +import 
org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -147,23 +148,98 @@ class Interaction(override val uid: String) extends Transformer } override def transform(dataset: DataFrame): DataFrame = { - val isCategorical: Array[Boolean] = $(inputCols).map { col => - dataset.schema(col).dataType match { - case DoubleType => - UnaryIterator() + checkParams() + + val numValues: Array[Array[Int]] = $(inputCols).map { col => + val field = dataset.schema(col) + field.dataType match { + case _: NumericType | BooleanType => Array[Int]() case _: VectorUDT => - OneHotIterator() + val group = AttributeGroup.fromStructField(field) + assert(group.attributes.isDefined, "TODO") + val cardinalities = group.attributes.get.map { + case nominal: NominalAttribute => nominal.getNumValues.get + case _ => 0 + } + cardinalities.toArray } } - inputCols.foreach { + assert(numValues.length == $(inputCols).length) + + def getIterator(fieldIndex: Int, v: Any) = v match { + case d: Double => + Vectors.dense(d) + case vec: Vector => + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + var cur = 0 + vec.foreachActive { case (i, v) => + val numColumns = numValues(fieldIndex)(i) + if (numColumns > 0) { + // One-hot encode the field. TODO(ekl) support drop-last? + indices += cur + v.toInt + values += 1.0 + cur += numColumns + } else { + // Copy the field verbatim. 
+ indices += cur + values += v + cur += 1 + } + } + Vectors.sparse(cur, indices.result(), values.result()) + case null => + throw new SparkException("Values to interact cannot be null.") + case o => + throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") } - dataset + + def interact(vv: Any*): Vector = { + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + var size = 1 + indices += 0 + values += 1.0 + var fieldIndex = 0 + while (fieldIndex < vv.length) { + val prevIndices = indices.result() + val prevValues = values.result() + val prevSize = size + val currentVector = getIterator(fieldIndex, vv(fieldIndex)) + indices = ArrayBuilder.make[Int] + values = ArrayBuilder.make[Double] + size *= currentVector.size + currentVector.foreachActive { (i, a) => + var j = 0 + while (j < prevIndices.length) { + indices += prevIndices(j) + i * prevSize + values += prevValues(j) * a + j += 1 + } + } + fieldIndex += 1 + } + Vectors.sparse(size, indices.result(), values.result()).compressed + } + + val interactFunc = udf { r: Row => + interact(r.toSeq: _*) + } + + val args = $(inputCols).map { c => + dataset.schema(c).dataType match { + case DoubleType => dataset(c) + case _: VectorUDT => dataset(c) + case _: NumericType | BooleanType => dataset(c).cast(DoubleType) + } + } + dataset.select( col("*"), - encode(array(inputCols).cast(DoubleType)).as(outputColName, metadata)) + interactFunc(struct(args: _*)).as($(outputCol))) } - override def copy(extra: ParamMap): RInteraction = defaultCopy(extra) + override def copy(extra: ParamMap): Interaction = defaultCopy(extra) private def checkParams(): Unit = { require(get(inputCols).isDefined, "Input cols must be defined first.") @@ -173,60 +249,6 @@ class Interaction(override val uid: String) extends Transformer } } -private object Interaction { - def getIterator(v: Any) = v match { - case d: Double => - Vectors.dense(d) - case vec: Vector => - var indices = ArrayBuilder.make[Int] - var 
values = ArrayBuilder.make[Double] - var cur = 0 - vec.foreachActive { case (i, v) => - // TODO(ekl) precompute cardinality from the ml attrs - if (CARDINALITY(i) > 0) { - indices += cur + v.toInt - values += 1.0 - cur += CARDINALITY(i) - } else { - indices += cur - values += v - cur += 1 - } - } - Vectors.sparse(cur, indices.result(), values.result()) - case null => - throw new SparkException("Values to interact cannot be null.") - case o => - throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") - } - - def interact(vv: Any*): Vector = { - var indices = ArrayBuilder.make[Int] - var values = ArrayBuilder.make[Double] - var size = 1 - indices += 1 - values += 1.0 - vv.foreach { v => - val prevIndices = indices.result() - val prevValues = values.result() - val prevSize = size - val currentVector = getIterator(v) - indices = ArrayBuilder.make[Int] - values = ArrayBuilder.make[Double] - size *= currentVector.size - currentVector.foreachActive { (i, a) => - var j = 0 - while (j < prevIndices.length) { - indices += prevIndices(j) + i * prevSize - values += prevValues(j) * a - j += 1 - } - } - } - Vectors.sparse(size, indices.result(), values.result()).compressed - } -} - /** * Computes the joint index of multiple string-indexed columns such that the combined index * covers the cartesian product of column values. 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala index bf1eecef033a..7c614108f64a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala @@ -22,12 +22,53 @@ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.functions.col class RInteractionSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { ParamsSuite.checkParams(new RInteraction()) } + test("new interaction") { + val data = sqlContext.createDataFrame( + Seq( + (1, "foo", true, 4, Vectors.dense(0.0, 0.0, 1.0), Vectors.dense(5.0, 3.0)), + (2, "bar", true, 4, Vectors.dense(1.0, 4.0, 2.0), Vectors.dense(4.0, 3.0)), + (3, "bar", true, 5, Vectors.dense(2.0, 5.0, 3.0), Vectors.dense(5.0, 3.0)), + (4, "baz", true, 5, Vectors.dense(3.0, 8.0, 4.0), Vectors.dense(5.0, 2.0)), + (4, "baz", false, 5, Vectors.dense(4.0, 9.0, 8.0), Vectors.dense(7.0, 1.0)), + (4, "baz", false, 5, Vectors.dense(5.0, 2.0, 9.0), Vectors.dense(2.0, 0.0))) + ).toDF("id", "a", "bin", "b", "test", "test2") + val attrs = new AttributeGroup( + "test", + Array[Attribute]( + NominalAttribute.defaultAttr.withValues(Array("a", "b", "c", "d", "e", "f")), + NumericAttribute.defaultAttr.withName("magnitude"), + NominalAttribute.defaultAttr.withValues( + Array("green", "blue", "red", "violet", "yellow", + "orange", "black", "white", "azure", "gray")))) + val attrs2 = new AttributeGroup( + "test2", + Array[Attribute]( + NumericAttribute.defaultAttr.withName("species"), + NominalAttribute.defaultAttr.withValues(Array("one", "two", "three", "four")))) + val df = data.select( + col("id"), col("b"), col("bin"), col("test").as("test", attrs.toMetadata()), + 
col("test2").as("test2", attrs.toMetadata())) + df.collect.foreach(println) + println(df.schema) + df.schema.foreach { field => + println(field.metadata) + } + val trans = new Interaction().setInputCols(Array("id", "test2", "test")).setOutputCol("feature") + val res = trans.transform(df) + res.collect.foreach(println) + println(res.schema) + res.schema.foreach { field => + println(field.metadata) + } + } + test("parameter validation") { val data = sqlContext.createDataFrame( Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) From 50b21abaa0b128d9811a86cafaa4b4abbf712c41 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 15 Sep 2015 15:59:26 -0700 Subject: [PATCH 17/27] Tue Sep 15 15:59:26 PDT 2015 --- .../spark/ml/feature/RInteraction.scala | 69 +++++++++++++++++-- .../spark/ml/feature/RInteractionSuite.scala | 22 +++--- 2 files changed, 76 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala index f35e78d1c659..e8ad531d631b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala @@ -150,10 +150,18 @@ class Interaction(override val uid: String) extends Transformer override def transform(dataset: DataFrame): DataFrame = { checkParams() + val attrs = genAttrs($(inputCols).map(col => dataset.schema(col))) + val numValues: Array[Array[Int]] = $(inputCols).map { col => val field = dataset.schema(col) field.dataType match { - case _: NumericType | BooleanType => Array[Int]() + case _: NumericType | BooleanType => + Attribute.fromStructField(field) match { + case nominal: NominalAttribute => + Array(nominal.getNumValues.get) + case _ => + Array(0) + } case _: VectorUDT => val group = AttributeGroup.fromStructField(field) assert(group.attributes.isDefined, "TODO") @@ -165,21 +173,28 @@ class Interaction(override val uid: String) extends 
Transformer } } assert(numValues.length == $(inputCols).length) + println("num input cols: " + $(inputCols).mkString(", ")) + println("numValues: " + numValues.map(_.mkString(":")).mkString(", ")) def getIterator(fieldIndex: Int, v: Any) = v match { case d: Double => - Vectors.dense(d) + val numOutputCols = numValues(fieldIndex)(0) + if (numOutputCols > 0) { + Vectors.sparse(numOutputCols, Array(d.toInt), Array(1.0)) + } else { + Vectors.dense(d) + } case vec: Vector => var indices = ArrayBuilder.make[Int] var values = ArrayBuilder.make[Double] var cur = 0 vec.foreachActive { case (i, v) => - val numColumns = numValues(fieldIndex)(i) - if (numColumns > 0) { + val numOutputCols = numValues(fieldIndex)(i) + if (numOutputCols > 0) { // One-hot encode the field. TODO(ekl) support drop-last? indices += cur + v.toInt values += 1.0 - cur += numColumns + cur += numOutputCols } else { // Copy the field verbatim. indices += cur @@ -236,7 +251,49 @@ class Interaction(override val uid: String) extends Transformer dataset.select( col("*"), - interactFunc(struct(args: _*)).as($(outputCol))) + interactFunc(struct(args: _*)).as($(outputCol), attrs.toMetadata())) + } + + private def genAttrs(schema: Seq[StructField]): AttributeGroup = { + var attrs = Seq[Attribute]() + schema.foreach { field => + val attrIterator = field.dataType match { + case _: NumericType | BooleanType => + val attr = Attribute.fromStructField(field) + encodedAttrIterator(None, Seq(attr)) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(field) + encodedAttrIterator(Some(group.name), group.attributes.get) + } + attrs = attrIterator.flatMap { attr => + if (attrs.isEmpty) { + Seq(attr) + } else { + attrs.map(prev => prev.withName(prev.name.getOrElse("UNKNOWN") + ":" + attr.name.get)) + } + } + } + println("Num attrs: " + attrs.length) + attrs.foreach(a => println("a: " + a.toMetadata)) + new AttributeGroup($(outputCol), attrs.toArray) + } + + private def encodedAttrIterator(groupName: 
Option[String], attrs: Seq[Attribute]): Seq[Attribute] = { + def format(i: Int, attrName: Option[String], value: Option[String]): String = { + Seq(groupName, Some(attrName.getOrElse(i.toString)), value).flatten.mkString("_") + } + attrs.zipWithIndex.flatMap { + case (nominal: NominalAttribute, i) => + if (nominal.values.isDefined) { + nominal.values.get.map( + v => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(v)))) + } else { + Array.tabulate(nominal.getNumValues.get)( + j => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(j.toString)))) + } + case (a: Attribute, i) => + Seq(NumericAttribute.defaultAttr.withName(format(i, a.name, None))) + } } override def copy(extra: ParamMap): Interaction = defaultCopy(extra) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala index 7c614108f64a..51e61966a1f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala @@ -33,34 +33,38 @@ class RInteractionSuite extends SparkFunSuite with MLlibTestSparkContext { val data = sqlContext.createDataFrame( Seq( (1, "foo", true, 4, Vectors.dense(0.0, 0.0, 1.0), Vectors.dense(5.0, 3.0)), - (2, "bar", true, 4, Vectors.dense(1.0, 4.0, 2.0), Vectors.dense(4.0, 3.0)), - (3, "bar", true, 5, Vectors.dense(2.0, 5.0, 3.0), Vectors.dense(5.0, 3.0)), - (4, "baz", true, 5, Vectors.dense(3.0, 8.0, 4.0), Vectors.dense(5.0, 2.0)), - (4, "baz", false, 5, Vectors.dense(4.0, 9.0, 8.0), Vectors.dense(7.0, 1.0)), - (4, "baz", false, 5, Vectors.dense(5.0, 2.0, 9.0), Vectors.dense(2.0, 0.0))) + (1, "bar", true, 4, Vectors.dense(1.0, 4.0, 2.0), Vectors.dense(4.0, 3.0)), + (1, "bar", true, 5, Vectors.dense(2.0, 5.0, 3.0), Vectors.dense(5.0, 3.0)), + (1, "baz", true, 5, Vectors.dense(3.0, 8.0, 4.0), Vectors.dense(5.0, 2.0)), + (1, "baz", false, 5, 
Vectors.dense(4.0, 9.0, 8.0), Vectors.dense(7.0, 1.0)), + (2, "baz", false, 5, Vectors.dense(5.0, 2.0, 9.0), Vectors.dense(2.0, 0.0))) ).toDF("id", "a", "bin", "b", "test", "test2") val attrs = new AttributeGroup( "test", Array[Attribute]( NominalAttribute.defaultAttr.withValues(Array("a", "b", "c", "d", "e", "f")), NumericAttribute.defaultAttr.withName("magnitude"), - NominalAttribute.defaultAttr.withValues( + NominalAttribute.defaultAttr.withName("colors").withValues( Array("green", "blue", "red", "violet", "yellow", "orange", "black", "white", "azure", "gray")))) + + val idAttr = NominalAttribute.defaultAttr.withValues(Array("red", "blue")) val attrs2 = new AttributeGroup( "test2", Array[Attribute]( - NumericAttribute.defaultAttr.withName("species"), + NumericAttribute.defaultAttr, NominalAttribute.defaultAttr.withValues(Array("one", "two", "three", "four")))) val df = data.select( - col("id"), col("b"), col("bin"), col("test").as("test", attrs.toMetadata()), - col("test2").as("test2", attrs.toMetadata())) + col("id").as("id", idAttr.toMetadata()), col("b"), col("bin"), + col("test").as("test", attrs.toMetadata()), + col("test2").as("test2", attrs2.toMetadata())) df.collect.foreach(println) println(df.schema) df.schema.foreach { field => println(field.metadata) } val trans = new Interaction().setInputCols(Array("id", "test2", "test")).setOutputCol("feature") +// val trans = new Interaction().setInputCols(Array("id", "test")).setOutputCol("feature") val res = trans.transform(df) res.collect.foreach(println) println(res.schema) From 1ceefd27bbd820f289d922e9144f4d3cf4e40c6a Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 15 Sep 2015 16:22:48 -0700 Subject: [PATCH 18/27] Tue Sep 15 16:22:48 PDT 2015 --- .../spark/ml/feature/RInteraction.scala | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala index e8ad531d631b..0abbfa40c849 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala @@ -209,6 +209,7 @@ class Interaction(override val uid: String) extends Transformer throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") } + // TODO(ekl) make in same order as attrs def interact(vv: Any*): Vector = { var indices = ArrayBuilder.make[Int] var values = ArrayBuilder.make[Double] @@ -255,8 +256,18 @@ class Interaction(override val uid: String) extends Transformer } private def genAttrs(schema: Seq[StructField]): AttributeGroup = { - var attrs = Seq[Attribute]() - schema.foreach { field => + def gen(iterators: Seq[Seq[Attribute]]): Seq[Attribute] = { + if (iterators.length == 1) { + iterators.head + } else { + iterators.head.flatMap { head => + gen(iterators.tail).map { tail => + NumericAttribute.defaultAttr.withName(head.name.get + ":" + tail.name.get) + } + } + } + } + val iterators = schema.map { field => val attrIterator = field.dataType match { case _: NumericType | BooleanType => val attr = Attribute.fromStructField(field) @@ -265,22 +276,15 @@ class Interaction(override val uid: String) extends Transformer val group = AttributeGroup.fromStructField(field) encodedAttrIterator(Some(group.name), group.attributes.get) } - attrs = attrIterator.flatMap { attr => - if (attrs.isEmpty) { - Seq(attr) - } else { - attrs.map(prev => prev.withName(prev.name.getOrElse("UNKNOWN") + ":" + attr.name.get)) - } - } + attrIterator } - println("Num attrs: " + attrs.length) - attrs.foreach(a => println("a: " + a.toMetadata)) - new AttributeGroup($(outputCol), attrs.toArray) + new AttributeGroup($(outputCol), gen(iterators).toArray) } private def encodedAttrIterator(groupName: Option[String], attrs: Seq[Attribute]): Seq[Attribute] = { def format(i: Int, attrName: Option[String], value: 
Option[String]): String = { - Seq(groupName, Some(attrName.getOrElse(i.toString)), value).flatten.mkString("_") + val parts = Seq(groupName, Some(attrName.getOrElse(i.toString)), value) + parts.flatten.mkString("_") } attrs.zipWithIndex.flatMap { case (nominal: NominalAttribute, i) => From f357aca099e4c4fcf45b921a1fbdc1644a86df5f Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 15 Sep 2015 17:50:03 -0700 Subject: [PATCH 19/27] Tue Sep 15 17:50:03 PDT 2015 --- .../spark/ml/feature/RInteraction.scala | 33 +++++++++---------- .../spark/ml/feature/RInteractionSuite.scala | 16 ++++----- 2 files changed, 23 insertions(+), 26 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala index 0abbfa40c849..53c338d72572 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala @@ -209,15 +209,14 @@ class Interaction(override val uid: String) extends Transformer throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") } - // TODO(ekl) make in same order as attrs def interact(vv: Any*): Vector = { var indices = ArrayBuilder.make[Int] var values = ArrayBuilder.make[Double] var size = 1 indices += 0 values += 1.0 - var fieldIndex = 0 - while (fieldIndex < vv.length) { + var fieldIndex = vv.length - 1 + while (fieldIndex >= 0) { val prevIndices = indices.result() val prevValues = values.result() val prevSize = size @@ -233,7 +232,7 @@ class Interaction(override val uid: String) extends Transformer j += 1 } } - fieldIndex += 1 + fieldIndex -= 1 } Vectors.sparse(size, indices.result(), values.result()).compressed } @@ -256,18 +255,8 @@ class Interaction(override val uid: String) extends Transformer } private def genAttrs(schema: Seq[StructField]): AttributeGroup = { - def gen(iterators: Seq[Seq[Attribute]]): Seq[Attribute] = { - if (iterators.length == 1) { 
- iterators.head - } else { - iterators.head.flatMap { head => - gen(iterators.tail).map { tail => - NumericAttribute.defaultAttr.withName(head.name.get + ":" + tail.name.get) - } - } - } - } - val iterators = schema.map { field => + var attrs: Seq[Attribute] = Nil + schema.reverse.map { field => val attrIterator = field.dataType match { case _: NumericType | BooleanType => val attr = Attribute.fromStructField(field) @@ -276,9 +265,17 @@ class Interaction(override val uid: String) extends Transformer val group = AttributeGroup.fromStructField(field) encodedAttrIterator(Some(group.name), group.attributes.get) } - attrIterator + if (attrs.isEmpty) { + attrs = attrIterator + } else { + attrs = attrIterator.flatMap { head => + attrs.map { tail => + NumericAttribute.defaultAttr.withName(head.name.get + ":" + tail.name.get) + } + } + } } - new AttributeGroup($(outputCol), gen(iterators).toArray) + new AttributeGroup($(outputCol), attrs.toArray) } private def encodedAttrIterator(groupName: Option[String], attrs: Seq[Attribute]): Seq[Attribute] = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala index 51e61966a1f7..f14a3293159a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala @@ -32,12 +32,12 @@ class RInteractionSuite extends SparkFunSuite with MLlibTestSparkContext { test("new interaction") { val data = sqlContext.createDataFrame( Seq( - (1, "foo", true, 4, Vectors.dense(0.0, 0.0, 1.0), Vectors.dense(5.0, 3.0)), - (1, "bar", true, 4, Vectors.dense(1.0, 4.0, 2.0), Vectors.dense(4.0, 3.0)), - (1, "bar", true, 5, Vectors.dense(2.0, 5.0, 3.0), Vectors.dense(5.0, 3.0)), - (1, "baz", true, 5, Vectors.dense(3.0, 8.0, 4.0), Vectors.dense(5.0, 2.0)), - (1, "baz", false, 5, Vectors.dense(4.0, 9.0, 8.0), Vectors.dense(7.0, 1.0)), - (2, "baz", false, 5, 
Vectors.dense(5.0, 2.0, 9.0), Vectors.dense(2.0, 0.0))) + (0, "foo", true, 4, Vectors.dense(0.0, 0.0, 1.0), Vectors.dense(5.0, 3.0)), + (0, "bar", true, 4, Vectors.dense(1.0, 4.0, 2.0), Vectors.dense(4.0, 3.0)), + (0, "bar", true, 5, Vectors.dense(2.0, 5.0, 3.0), Vectors.dense(5.0, 3.0)), + (0, "baz", true, 5, Vectors.dense(3.0, 8.0, 4.0), Vectors.dense(5.0, 2.0)), + (0, "baz", false, 5, Vectors.dense(4.0, 9.0, 8.0), Vectors.dense(7.0, 1.0)), + (1, "baz", false, 5, Vectors.dense(5.0, 2.0, 9.0), Vectors.dense(2.0, 0.0))) ).toDF("id", "a", "bin", "b", "test", "test2") val attrs = new AttributeGroup( "test", @@ -63,8 +63,8 @@ class RInteractionSuite extends SparkFunSuite with MLlibTestSparkContext { df.schema.foreach { field => println(field.metadata) } - val trans = new Interaction().setInputCols(Array("id", "test2", "test")).setOutputCol("feature") -// val trans = new Interaction().setInputCols(Array("id", "test")).setOutputCol("feature") +// val trans = new Interaction().setInputCols(Array("id", "test2", "test")).setOutputCol("feature") + val trans = new Interaction().setInputCols(Array("id", "test2")).setOutputCol("feature") val res = trans.transform(df) res.collect.foreach(println) println(res.schema) From 569428d39d07f0073ce52864c813326910c01f46 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 15 Sep 2015 19:03:51 -0700 Subject: [PATCH 20/27] Tue Sep 15 19:03:51 PDT 2015 --- .../apache/spark/ml/feature/Interaction.scala | 238 ++++++++++ .../spark/ml/feature/RInteraction.scala | 431 ------------------ .../spark/ml/feature/InteractionSuite.scala | 211 +++++++++ .../spark/ml/feature/RInteractionSuite.scala | 185 -------- 4 files changed, 449 insertions(+), 616 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala delete mode 100644 
mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala new file mode 100644 index 000000000000..d47207ae11f4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.{ArrayBuffer, ArrayBuilder} + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Implements the transforms required for R-style feature interactions. 
In summary, once fitted to + * a dataset, this transformer jointly one-hot encodes all factor input columns, then scales + * the encoded vector by all numeric input columns. If only numeric columns are specified, the + * output column will be a one-length vector containing their product. During one-hot encoding, + * the last category will be preserved unless the interaction is trivial. + * + * See https://stat.ethz.ch/R-manual/R-devel/library/base/html/formula.html for more + * information about factor interactions in R formulae. + */ +@Experimental +class Interaction(override val uid: String) extends Transformer + with HasInputCols with HasOutputCol { + + def this() = this(Identifiable.randomUID("interaction")) + + /** @group setParam */ + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + // optimistic schema; does not contain any ML attributes + override def transformSchema(schema: StructType): StructType = { + checkParams() + StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, true)) + } + + override def transform(dataset: DataFrame): DataFrame = { + checkParams() + val fieldIterators = getIterators(dataset) + + def interact(vv: Any*): Vector = { + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + var size = 1 + indices += 0 + values += 1.0 + var fieldIndex = vv.length - 1 + while (fieldIndex >= 0) { + val prevIndices = indices.result() + val prevValues = values.result() + val prevSize = size + val currentIterator = fieldIterators(fieldIndex) + indices = ArrayBuilder.make[Int] + values = ArrayBuilder.make[Double] + size *= currentIterator.size + currentIterator.foreachActive(vv(fieldIndex), (i, a) => { + var j = 0 + while (j < prevIndices.length) { + indices += prevIndices(j) + i * prevSize + values += prevValues(j) * a + j += 1 + } + }) + fieldIndex -= 1 + } + Vectors.sparse(size, 
indices.result(), values.result()).compressed + } + + val interactFunc = udf { r: Row => + interact(r.toSeq: _*) + } + val args = $(inputCols).map { c => + dataset.schema(c).dataType match { + case DoubleType => dataset(c) + case _: VectorUDT => dataset(c) + case _: NumericType | BooleanType => dataset(c).cast(DoubleType) + } + } + val attrs = generateAttrs($(inputCols).map(col => dataset.schema(col))) + dataset.select( + col("*"), + interactFunc(struct(args: _*)).as($(outputCol), attrs.toMetadata())) + } + + private def getIterators(dataset: DataFrame): Array[EncodingIterator] = { + def getCardinality(attr: Attribute): Int = { + attr match { + case nominal: NominalAttribute => + nominal.getNumValues.getOrElse( + throw new SparkException("Nominal fields must have attr numValues defined.")) + case _ => + 0 // treated as numeric + } + } + $(inputCols).map { col => + val field = dataset.schema(col) + val cardinalities = field.dataType match { + case _: NumericType | BooleanType => + Array(getCardinality(Attribute.fromStructField(field))) + case _: VectorUDT => + val attrs = AttributeGroup.fromStructField(field).attributes.getOrElse( + throw new SparkException("Vector fields must have attributes defined.")) + attrs.map(getCardinality).toArray + } + new EncodingIterator(cardinalities) + }.toArray + } + + private def generateAttrs(schema: Seq[StructField]): AttributeGroup = { + var attrs: Seq[Attribute] = Nil + schema.reverse.map { field => + val attrIterator = field.dataType match { + case _: NumericType | BooleanType => + val attr = Attribute.fromStructField(field) + encodedAttrIterator(None, Seq(attr)) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(field) + encodedAttrIterator(Some(group.name), group.attributes.get) + } + if (attrs.isEmpty) { + attrs = attrIterator + } else { + attrs = attrIterator.flatMap { head => + attrs.map { tail => + NumericAttribute.defaultAttr.withName(head.name.get + ":" + tail.name.get) + } + } + } + } + new 
AttributeGroup($(outputCol), attrs.toArray) + } + + private def encodedAttrIterator(groupName: Option[String], attrs: Seq[Attribute]): Seq[Attribute] = { + def format(i: Int, attrName: Option[String], value: Option[String]): String = { + val parts = Seq(groupName, Some(attrName.getOrElse(i.toString)), value) + parts.flatten.mkString("_") + } + attrs.zipWithIndex.flatMap { + case (nominal: NominalAttribute, i) => + if (nominal.values.isDefined) { + nominal.values.get.map( + v => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(v)))) + } else { + Array.tabulate(nominal.getNumValues.get)( + j => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(j.toString)))) + } + case (a: Attribute, i) => + Seq(NumericAttribute.defaultAttr.withName(format(i, a.name, None))) + } + } + + override def copy(extra: ParamMap): Interaction = defaultCopy(extra) + + private def checkParams(): Unit = { + require(get(inputCols).isDefined, "Input cols must be defined first.") + require(get(outputCol).isDefined, "Output col must be defined first.") + require($(inputCols).length > 0, "Input cols must have non-zero length.") + require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") + } +} + +/** + * An iterator over VectorUDT or Double that one-hot encodes nominal fields in the output. + * + * @param cardinalities An array defining the cardinality of each vector sub-field, or a single + * value if the field is numeric. Fields with zero cardinality will not be + * one-hot encoded (output verbatim). + */ +// TODO(ekl) do we want to support drop-last like OneHotEncoder does? +private[ml] class EncodingIterator(cardinalities: Array[Int]) { + /** The size of the output vector. */ + val size = cardinalities.map(i => if (i > 0) i else 1).sum + + /** + * @param value The row to iterate over, either a Double or Vector. + * @param f The callback to invoke on each non-zero (index, value) output pair. 
+ */ + def foreachActive(value: Any, f: (Int, Double) => Unit) = value match { + case d: Double => + assert(cardinalities.length == 1) + val numOutputCols = cardinalities(0) + if (numOutputCols > 0) { + assert(d >= 0.0 && d == d.toInt, s"Values from column must be indices, but got $d.") + f(d.toInt, 1.0) + } else { + f(0, d) + } + case vec: Vector => + assert(cardinalities.length == vec.size, + s"Vector column size was ${vec.size}, expected ${cardinalities.length}") + val dense = vec.toDense + var i = 0 + var cur = 0 + while (i < dense.size) { + val numOutputCols = cardinalities(i) + if (numOutputCols > 0) { + val x = dense.values(i).toInt + assert(x >= 0.0 && x == x.toInt, s"Values from column must be indices, but got $x.") + f(cur + x, 1.0) + cur += numOutputCols + } else { + f(cur, dense.values(i)) + cur += 1 + } + i += 1 + } + case null => + throw new SparkException("Values to interact cannot be null.") + case o => + throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala deleted file mode 100644 index 53c338d72572..000000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RInteraction.scala +++ /dev/null @@ -1,431 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml.feature - -import scala.collection.mutable.{ArrayBuffer, ArrayBuilder} - -import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable -import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types._ - -/** - * :: Experimental :: - * Implements the transforms required for R-style feature interactions. In summary, once fitted to - * a dataset, this transformer jointly one-hot encodes all factor input columns, then scales - * the encoded vector by all numeric input columns. If only numeric columns are specified, the - * output column will be a one-length vector containing their product. During one-hot encoding, - * the last category will be preserved unless the interaction is trivial. - * - * See https://stat.ethz.ch/R-manual/R-devel/library/base/html/formula.html for more - * information about factor interactions in R formulae. 
- */ -@Experimental -class RInteraction(override val uid: String) extends Estimator[PipelineModel] - with HasInputCols with HasOutputCol { - - def this() = this(Identifiable.randomUID("interaction")) - - /** @group setParam */ - def setInputCols(values: Array[String]): this.type = set(inputCols, values) - - /** @group setParam */ - def setOutputCol(value: String): this.type = set(outputCol, value) - - override def fit(dataset: DataFrame): PipelineModel = { - checkParams() - val encoderStages = ArrayBuffer[PipelineStage]() - val tempColumns = ArrayBuffer[String]() - val (factorCols, nonFactorCols) = $(inputCols) - .partition(input => dataset.schema(input).dataType == StringType) - - val encodedFactors: Option[String] = - if (factorCols.length > 0) { - val indexedCols = factorCols.map { input => - val output = input + "_idx_" + uid - encoderStages += new StringIndexer() - .setInputCol(input) - .setOutputCol(output) - tempColumns += output - output - } - val combinedIndex = "combined_idx_" + uid - tempColumns += combinedIndex - val encodedCol = if (nonFactorCols.length > 0) { - "factors_" + uid - } else { - $(outputCol) - } - encoderStages += new IndexCombiner(indexedCols, factorCols, combinedIndex) - encoderStages += { - val encoder = new OneHotEncoder() - .setInputCol(combinedIndex) - .setOutputCol(encodedCol) - if ($(inputCols).length > 1) { - encoder.setDropLast(false) // R includes all columns for interactions. 
- } - encoder - } - Some(encodedCol) - } else { - None - } - - if (nonFactorCols.length > 0) { - encoderStages += new NumericInteraction(nonFactorCols, encodedFactors, $(outputCol)) - if (encodedFactors.isDefined) { - tempColumns += encodedFactors.get - } - } - - encoderStages += new ColumnPruner(tempColumns.toSet) - new Pipeline(uid) - .setStages(encoderStages.toArray) - .fit(dataset) - .setParent(this) - } - - // optimistic schema; does not contain any ML attributes - override def transformSchema(schema: StructType): StructType = { - checkParams() - if ($(inputCols).exists(col => schema(col).dataType == StringType)) { - StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, true)) - } else { - StructType(schema.fields :+ StructField($(outputCol), DoubleType, true)) - } - } - - override def copy(extra: ParamMap): RInteraction = defaultCopy(extra) - - private def checkParams(): Unit = { - require(get(inputCols).isDefined, "Input cols must be defined first.") - require(get(outputCol).isDefined, "Output col must be defined first.") - require($(inputCols).length > 0, "Input cols must have non-zero length.") - require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") - } -} - -class Interaction(override val uid: String) extends Transformer - with HasInputCols with HasOutputCol { - - def this() = this(Identifiable.randomUID("interaction")) - - /** @group setParam */ - def setInputCols(values: Array[String]): this.type = set(inputCols, values) - - /** @group setParam */ - def setOutputCol(value: String): this.type = set(outputCol, value) - - override def transformSchema(schema: StructType): StructType = { - checkParams() - if ($(inputCols).exists(col => schema(col).dataType == StringType)) { - StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, true)) - } else { - StructType(schema.fields :+ StructField($(outputCol), DoubleType, true)) - } - } - - override def transform(dataset: DataFrame): DataFrame = { - 
checkParams() - - val attrs = genAttrs($(inputCols).map(col => dataset.schema(col))) - - val numValues: Array[Array[Int]] = $(inputCols).map { col => - val field = dataset.schema(col) - field.dataType match { - case _: NumericType | BooleanType => - Attribute.fromStructField(field) match { - case nominal: NominalAttribute => - Array(nominal.getNumValues.get) - case _ => - Array(0) - } - case _: VectorUDT => - val group = AttributeGroup.fromStructField(field) - assert(group.attributes.isDefined, "TODO") - val cardinalities = group.attributes.get.map { - case nominal: NominalAttribute => nominal.getNumValues.get - case _ => 0 - } - cardinalities.toArray - } - } - assert(numValues.length == $(inputCols).length) - println("num input cols: " + $(inputCols).mkString(", ")) - println("numValues: " + numValues.map(_.mkString(":")).mkString(", ")) - - def getIterator(fieldIndex: Int, v: Any) = v match { - case d: Double => - val numOutputCols = numValues(fieldIndex)(0) - if (numOutputCols > 0) { - Vectors.sparse(numOutputCols, Array(d.toInt), Array(1.0)) - } else { - Vectors.dense(d) - } - case vec: Vector => - var indices = ArrayBuilder.make[Int] - var values = ArrayBuilder.make[Double] - var cur = 0 - vec.foreachActive { case (i, v) => - val numOutputCols = numValues(fieldIndex)(i) - if (numOutputCols > 0) { - // One-hot encode the field. TODO(ekl) support drop-last? - indices += cur + v.toInt - values += 1.0 - cur += numOutputCols - } else { - // Copy the field verbatim. 
- indices += cur - values += v - cur += 1 - } - } - Vectors.sparse(cur, indices.result(), values.result()) - case null => - throw new SparkException("Values to interact cannot be null.") - case o => - throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") - } - - def interact(vv: Any*): Vector = { - var indices = ArrayBuilder.make[Int] - var values = ArrayBuilder.make[Double] - var size = 1 - indices += 0 - values += 1.0 - var fieldIndex = vv.length - 1 - while (fieldIndex >= 0) { - val prevIndices = indices.result() - val prevValues = values.result() - val prevSize = size - val currentVector = getIterator(fieldIndex, vv(fieldIndex)) - indices = ArrayBuilder.make[Int] - values = ArrayBuilder.make[Double] - size *= currentVector.size - currentVector.foreachActive { (i, a) => - var j = 0 - while (j < prevIndices.length) { - indices += prevIndices(j) + i * prevSize - values += prevValues(j) * a - j += 1 - } - } - fieldIndex -= 1 - } - Vectors.sparse(size, indices.result(), values.result()).compressed - } - - val interactFunc = udf { r: Row => - interact(r.toSeq: _*) - } - - val args = $(inputCols).map { c => - dataset.schema(c).dataType match { - case DoubleType => dataset(c) - case _: VectorUDT => dataset(c) - case _: NumericType | BooleanType => dataset(c).cast(DoubleType) - } - } - - dataset.select( - col("*"), - interactFunc(struct(args: _*)).as($(outputCol), attrs.toMetadata())) - } - - private def genAttrs(schema: Seq[StructField]): AttributeGroup = { - var attrs: Seq[Attribute] = Nil - schema.reverse.map { field => - val attrIterator = field.dataType match { - case _: NumericType | BooleanType => - val attr = Attribute.fromStructField(field) - encodedAttrIterator(None, Seq(attr)) - case _: VectorUDT => - val group = AttributeGroup.fromStructField(field) - encodedAttrIterator(Some(group.name), group.attributes.get) - } - if (attrs.isEmpty) { - attrs = attrIterator - } else { - attrs = attrIterator.flatMap { head => - attrs.map { tail => 
- NumericAttribute.defaultAttr.withName(head.name.get + ":" + tail.name.get) - } - } - } - } - new AttributeGroup($(outputCol), attrs.toArray) - } - - private def encodedAttrIterator(groupName: Option[String], attrs: Seq[Attribute]): Seq[Attribute] = { - def format(i: Int, attrName: Option[String], value: Option[String]): String = { - val parts = Seq(groupName, Some(attrName.getOrElse(i.toString)), value) - parts.flatten.mkString("_") - } - attrs.zipWithIndex.flatMap { - case (nominal: NominalAttribute, i) => - if (nominal.values.isDefined) { - nominal.values.get.map( - v => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(v)))) - } else { - Array.tabulate(nominal.getNumValues.get)( - j => BinaryAttribute.defaultAttr.withName(format(i, nominal.name, Some(j.toString)))) - } - case (a: Attribute, i) => - Seq(NumericAttribute.defaultAttr.withName(format(i, a.name, None))) - } - } - - override def copy(extra: ParamMap): Interaction = defaultCopy(extra) - - private def checkParams(): Unit = { - require(get(inputCols).isDefined, "Input cols must be defined first.") - require(get(outputCol).isDefined, "Output col must be defined first.") - require($(inputCols).length > 0, "Input cols must have non-zero length.") - require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") - } -} - -/** - * Computes the joint index of multiple string-indexed columns such that the combined index - * covers the cartesian product of column values. 
- */ -private class IndexCombiner( - inputCols: Array[String], attrNames: Array[String], outputCol: String) - extends Transformer { - - override val uid = Identifiable.randomUID("indexCombiner") - - override def transform(dataset: DataFrame): DataFrame = { - val inputMetadata = inputCols.map(col => - Attribute.fromStructField(dataset.schema(col)).asInstanceOf[NominalAttribute]) - val cardinalities = inputMetadata.map(_.values.get.length) - val combiner = udf { values: Seq[Double] => - var offset = 1 - var res = 0.0 - var i = 0 - while (i < values.length) { - res += values(i) * offset - offset *= cardinalities(i) - i += 1 - } - res - } - val metadata = NominalAttribute.defaultAttr - .withName(outputCol) - .withValues(generateAttrNames(inputMetadata, attrNames)) - .toMetadata() - dataset.select( - col("*"), - combiner(array(inputCols.map(dataset(_)): _*)).as(outputCol, metadata)) - } - - override def transformSchema(schema: StructType): StructType = { - StructType(schema.fields :+ StructField(outputCol, DoubleType, true)) - } - - override def copy(extra: ParamMap): IndexCombiner = defaultCopy(extra) - - private def generateAttrNames( - attrs: Array[NominalAttribute], names: Array[String]): Array[String] = { - val colName = names.head - val attrNames = attrs.head.values.get.map(colName + "_" + _) - if (attrs.length <= 1) { - attrNames - } else { - generateAttrNames(attrs.tail, names.tail).flatMap { rest => - attrNames.map(n => n + ":" + rest) - } - } - } -} - -/** - * Scales the input vector column by the product of the input numeric columns. If no vector column - * is specified, the output is just the product of the numeric columns. 
- */ -private class NumericInteraction( - inputCols: Array[String], vectorCol: Option[String], outputCol: String) - extends Transformer { - - override val uid = Identifiable.randomUID("numericInteraction") - - override def transform(dataset: DataFrame): DataFrame = { - if (vectorCol.isDefined) { - val scale = udf { (vec: Vector, scalars: Seq[Double]) => - var x = 1.0 - var i = 0 - while (i < scalars.length) { - x *= scalars(i) - i += 1 - } - val indices = ArrayBuilder.make[Int] - val values = ArrayBuilder.make[Double] - vec.foreachActive { case (i, v) => - if (v != 0.0) { - indices += i - values += v * x - } - } - Vectors.sparse(vec.size, indices.result(), values.result()).compressed - } - val group = AttributeGroup.fromStructField(dataset.schema(vectorCol.get)) - val attrs = group.attributes.get.map { attr => - attr.withName(attr.name.get + ":" + inputCols.mkString(":")) - } - val metadata = new AttributeGroup(outputCol, attrs).toMetadata() - dataset.select( - col("*"), - scale( - col(vectorCol.get), - array(inputCols.map(dataset(_).cast(DoubleType)): _*)).as(outputCol, metadata)) - } else { - val multiply = udf { values: Seq[Double] => - var x = 1.0 - var i = 0 - while (i < values.length) { - x *= values(i) - i += 1 - } - x - } - val metadata = NumericAttribute.defaultAttr - .withName(inputCols.mkString(":")) - .toMetadata() - dataset.select( - col("*"), - multiply(array(inputCols.map(dataset(_).cast(DoubleType)): _*)).as(outputCol, metadata)) - } - } - - override def transformSchema(schema: StructType): StructType = { - if (vectorCol.isDefined) { - StructType(schema.fields :+ StructField(outputCol, new VectorUDT, true)) - } else { - StructType(schema.fields :+ StructField(outputCol, DoubleType, true)) - } - } - - override def copy(extra: ParamMap): NumericInteraction = defaultCopy(extra) -} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala new file mode 
100644 index 000000000000..b89bce6dc11e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.functions.col + +class RInteractionSuite extends SparkFunSuite with MLlibTestSparkContext { + test("params") { + ParamsSuite.checkParams(new Interaction()) + } + + test("encoding iterator") { + def encode(cardinalities: Array[Int], value: Any): Vector = { + var indices = ArrayBuilder.make[Int] + var values = ArrayBuilder.make[Double] + val iter = new EncodingIterator(cardinalities) + iter.foreachActive(value, (i, v) => { + indices += i + values += v + }) + Vectors.sparse(iter.size, indices.result(), values.result()).compressed + } + assert(encode(Array(0), 2.2) === Vectors.dense(2.2)) + assert(encode(Array(3), Vectors.dense(1)) === Vectors.dense(0, 1, 
0)) + assert(encode(Array(0, 0), Vectors.dense(1.1, 2.2)) === Vectors.dense(1.1, 2.2)) + assert(encode(Array(3, 0), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 0, 2.2)) + assert(encode(Array(2, 0), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 2.2)) + assert(encode(Array(2, 0, 1), Vectors.dense(0, 2.2, 0)) === Vectors.dense(1, 0, 2.2, 1)) + intercept[SparkException] { encode(Array(0), "foo") } + intercept[SparkException] { encode(Array(0), null) } + intercept[AssertionError] { encode(Array(1), 2.2) } + intercept[AssertionError] { encode(Array(1), Vectors.dense(2.2)) } + intercept[AssertionError] { encode(Array(1), Vectors.dense(1.0, 2.0, 3.0)) } + } + + test("new interaction") { + val data = sqlContext.createDataFrame( + Seq( + (0, "foo", true, 4, Vectors.dense(0.0, 0.0, 1.0), Vectors.dense(5.0, 3.0)), + (0, "bar", true, 4, Vectors.dense(1.0, 4.0, 2.0), Vectors.dense(4.0, 3.0)), + (0, "bar", true, 5, Vectors.dense(2.0, 5.0, 3.0), Vectors.dense(5.0, 3.0)), + (0, "baz", true, 5, Vectors.dense(3.0, 8.0, 4.0), Vectors.dense(5.0, 2.0)), + (0, "baz", false, 5, Vectors.dense(4.0, 9.0, 8.0), Vectors.dense(7.0, 1.0)), + (1, "baz", false, 5, Vectors.dense(5.0, 2.0, 9.0), Vectors.dense(2.0, 0.0))) + ).toDF("id", "a", "bin", "b", "test", "test2") + val attrs = new AttributeGroup( + "test", + Array[Attribute]( + NominalAttribute.defaultAttr.withValues(Array("a", "b", "c", "d", "e", "f")), + NumericAttribute.defaultAttr.withName("magnitude"), + NominalAttribute.defaultAttr.withName("colors").withValues( + Array("green", "blue", "red", "violet", "yellow", + "orange", "black", "white", "azure", "gray")))) + + val idAttr = NominalAttribute.defaultAttr.withValues(Array("red", "blue")) + val attrs2 = new AttributeGroup( + "test2", + Array[Attribute]( + NumericAttribute.defaultAttr, + NominalAttribute.defaultAttr.withValues(Array("one", "two", "three", "four")))) + val df = data.select( + col("id").as("id", idAttr.toMetadata()), col("b"), col("bin"), + col("test").as("test", 
attrs.toMetadata()), + col("test2").as("test2", attrs2.toMetadata())) + df.collect.foreach(println) + println(df.schema) + df.schema.foreach { field => + println(field.metadata) + } +// val trans = new Interaction().setInputCols(Array("id", "test2", "test")).setOutputCol("feature") + val trans = new Interaction().setInputCols(Array("id", "test2")).setOutputCol("feature") + val res = trans.transform(df) + res.collect.foreach(println) + println(res.schema) + res.schema.foreach { field => + println(field.metadata) + } + } +// +// test("parameter validation") { +// val data = sqlContext.createDataFrame( +// Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) +// ).toDF("id", "a", "b") +// def check(inputCols: Array[String], outputCol: String, expectOk: Boolean): Unit = { +// val interaction = new RInteraction() +// if (inputCols != null) { +// interaction.setInputCols(inputCols) +// } +// if (outputCol != null) { +// interaction.setOutputCol(outputCol) +// } +// if (expectOk) { +// interaction.transformSchema(data.schema) +// interaction.fit(data).transform(data).collect() +// } else { +// intercept[IllegalArgumentException] { +// interaction.fit(data) +// } +// intercept[IllegalArgumentException] { +// interaction.transformSchema(data.schema) +// } +// } +// } +// check(Array("a", "b"), "test", true) +// check(Array("id"), "test", true) +// check(Array("b"), "test", true) +// check(Array("b"), "test", true) +// check(Array(), "test", false) +// check(Array("a", "b", "b"), "id", false) +// check(Array("a", "b"), null, false) +// check(null, "test", false) +// } +// +// test("numeric interaction") { +// val interaction = new RInteraction() +// .setInputCols(Array("b", "c", "d")) +// .setOutputCol("test") +// val original = sqlContext.createDataFrame( +// Seq((1, 2, 4, 2), (2, 3, 4, 1)) +// ).toDF("a", "b", "c", "d") +// val model = interaction.fit(original) +// val result = model.transform(original) +// val expected = sqlContext.createDataFrame( +// Seq( +// (1, 
2, 4, 2, 16.0), +// (2, 3, 4, 1, 12.0)) +// ).toDF("a", "b", "c", "d", "test") +// assert(result.collect() === expected.collect()) +// val attr = Attribute.decodeStructField(result.schema("test"), preserveName = true) +// val expectedAttr = new NumericAttribute(Some("b:c:d"), None) +// assert(attr === expectedAttr) +// } +// +// test("factor interaction") { +// val interaction = new RInteraction() +// .setInputCols(Array("a", "b")) +// .setOutputCol("test") +// val original = sqlContext.createDataFrame( +// Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) +// ).toDF("id", "a", "b") +// val model = interaction.fit(original) +// val result = model.transform(original) +// val expected = sqlContext.createDataFrame( +// Seq( +// (1, "foo", "zq", Vectors.dense(0.0, 1.0, 0.0, 0.0)), +// (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0)), +// (3, "bar", "zz", Vectors.dense(0.0, 0.0, 1.0, 0.0))) +// ).toDF("id", "a", "b", "test") +// assert(result.collect() === expected.collect()) +// val attrs = AttributeGroup.fromStructField(result.schema("test")) +// val expectedAttrs = new AttributeGroup( +// "test", +// Array[Attribute]( +// new BinaryAttribute(Some("a_bar:b_zq"), Some(1)), +// new BinaryAttribute(Some("a_foo:b_zq"), Some(2)), +// new BinaryAttribute(Some("a_bar:b_zz"), Some(3)), +// new BinaryAttribute(Some("a_foo:b_zz"), Some(4)))) +// assert(attrs === expectedAttrs) +// } +// +// test("factor numeric interaction") { +// val interaction = new RInteraction() +// .setInputCols(Array("a", "b")) +// .setOutputCol("test") +// val original = sqlContext.createDataFrame( +// Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) +// ).toDF("id", "a", "b") +// val model = interaction.fit(original) +// val result = model.transform(original) +// val expected = sqlContext.createDataFrame( +// Seq( +// (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0)), +// (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0)), +// (3, "bar", 5, Vectors.dense(0.0, 
5.0, 0.0)), +// (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0)), +// (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0)), +// (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0))) +// ).toDF("id", "a", "b", "test") +// assert(result.collect() === expected.collect()) +// val attrs = AttributeGroup.fromStructField(result.schema("test")) +// val expectedAttrs = new AttributeGroup( +// "test", +// Array[Attribute]( +// new BinaryAttribute(Some("a_baz:b"), Some(1)), +// new BinaryAttribute(Some("a_bar:b"), Some(2)), +// new BinaryAttribute(Some("a_foo:b"), Some(3)))) +// assert(attrs === expectedAttrs) +// } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala deleted file mode 100644 index f14a3293159a..000000000000 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RInteractionSuite.scala +++ /dev/null @@ -1,185 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.ml.feature - -import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.functions.col - -class RInteractionSuite extends SparkFunSuite with MLlibTestSparkContext { - test("params") { - ParamsSuite.checkParams(new RInteraction()) - } - - test("new interaction") { - val data = sqlContext.createDataFrame( - Seq( - (0, "foo", true, 4, Vectors.dense(0.0, 0.0, 1.0), Vectors.dense(5.0, 3.0)), - (0, "bar", true, 4, Vectors.dense(1.0, 4.0, 2.0), Vectors.dense(4.0, 3.0)), - (0, "bar", true, 5, Vectors.dense(2.0, 5.0, 3.0), Vectors.dense(5.0, 3.0)), - (0, "baz", true, 5, Vectors.dense(3.0, 8.0, 4.0), Vectors.dense(5.0, 2.0)), - (0, "baz", false, 5, Vectors.dense(4.0, 9.0, 8.0), Vectors.dense(7.0, 1.0)), - (1, "baz", false, 5, Vectors.dense(5.0, 2.0, 9.0), Vectors.dense(2.0, 0.0))) - ).toDF("id", "a", "bin", "b", "test", "test2") - val attrs = new AttributeGroup( - "test", - Array[Attribute]( - NominalAttribute.defaultAttr.withValues(Array("a", "b", "c", "d", "e", "f")), - NumericAttribute.defaultAttr.withName("magnitude"), - NominalAttribute.defaultAttr.withName("colors").withValues( - Array("green", "blue", "red", "violet", "yellow", - "orange", "black", "white", "azure", "gray")))) - - val idAttr = NominalAttribute.defaultAttr.withValues(Array("red", "blue")) - val attrs2 = new AttributeGroup( - "test2", - Array[Attribute]( - NumericAttribute.defaultAttr, - NominalAttribute.defaultAttr.withValues(Array("one", "two", "three", "four")))) - val df = data.select( - col("id").as("id", idAttr.toMetadata()), col("b"), col("bin"), - col("test").as("test", attrs.toMetadata()), - col("test2").as("test2", attrs2.toMetadata())) - df.collect.foreach(println) - println(df.schema) - df.schema.foreach { field => - println(field.metadata) - } -// val trans = new 
Interaction().setInputCols(Array("id", "test2", "test")).setOutputCol("feature") - val trans = new Interaction().setInputCols(Array("id", "test2")).setOutputCol("feature") - val res = trans.transform(df) - res.collect.foreach(println) - println(res.schema) - res.schema.foreach { field => - println(field.metadata) - } - } - - test("parameter validation") { - val data = sqlContext.createDataFrame( - Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) - ).toDF("id", "a", "b") - def check(inputCols: Array[String], outputCol: String, expectOk: Boolean): Unit = { - val interaction = new RInteraction() - if (inputCols != null) { - interaction.setInputCols(inputCols) - } - if (outputCol != null) { - interaction.setOutputCol(outputCol) - } - if (expectOk) { - interaction.transformSchema(data.schema) - interaction.fit(data).transform(data).collect() - } else { - intercept[IllegalArgumentException] { - interaction.fit(data) - } - intercept[IllegalArgumentException] { - interaction.transformSchema(data.schema) - } - } - } - check(Array("a", "b"), "test", true) - check(Array("id"), "test", true) - check(Array("b"), "test", true) - check(Array("b"), "test", true) - check(Array(), "test", false) - check(Array("a", "b", "b"), "id", false) - check(Array("a", "b"), null, false) - check(null, "test", false) - } - - test("numeric interaction") { - val interaction = new RInteraction() - .setInputCols(Array("b", "c", "d")) - .setOutputCol("test") - val original = sqlContext.createDataFrame( - Seq((1, 2, 4, 2), (2, 3, 4, 1)) - ).toDF("a", "b", "c", "d") - val model = interaction.fit(original) - val result = model.transform(original) - val expected = sqlContext.createDataFrame( - Seq( - (1, 2, 4, 2, 16.0), - (2, 3, 4, 1, 12.0)) - ).toDF("a", "b", "c", "d", "test") - assert(result.collect() === expected.collect()) - val attr = Attribute.decodeStructField(result.schema("test"), preserveName = true) - val expectedAttr = new NumericAttribute(Some("b:c:d"), None) - assert(attr === 
expectedAttr) - } - - test("factor interaction") { - val interaction = new RInteraction() - .setInputCols(Array("a", "b")) - .setOutputCol("test") - val original = sqlContext.createDataFrame( - Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) - ).toDF("id", "a", "b") - val model = interaction.fit(original) - val result = model.transform(original) - val expected = sqlContext.createDataFrame( - Seq( - (1, "foo", "zq", Vectors.dense(0.0, 1.0, 0.0, 0.0)), - (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0)), - (3, "bar", "zz", Vectors.dense(0.0, 0.0, 1.0, 0.0))) - ).toDF("id", "a", "b", "test") - assert(result.collect() === expected.collect()) - val attrs = AttributeGroup.fromStructField(result.schema("test")) - val expectedAttrs = new AttributeGroup( - "test", - Array[Attribute]( - new BinaryAttribute(Some("a_bar:b_zq"), Some(1)), - new BinaryAttribute(Some("a_foo:b_zq"), Some(2)), - new BinaryAttribute(Some("a_bar:b_zz"), Some(3)), - new BinaryAttribute(Some("a_foo:b_zz"), Some(4)))) - assert(attrs === expectedAttrs) - } - - test("factor numeric interaction") { - val interaction = new RInteraction() - .setInputCols(Array("a", "b")) - .setOutputCol("test") - val original = sqlContext.createDataFrame( - Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) - ).toDF("id", "a", "b") - val model = interaction.fit(original) - val result = model.transform(original) - val expected = sqlContext.createDataFrame( - Seq( - (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0)), - (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0)), - (3, "bar", 5, Vectors.dense(0.0, 5.0, 0.0)), - (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0)), - (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0)), - (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0))) - ).toDF("id", "a", "b", "test") - assert(result.collect() === expected.collect()) - val attrs = AttributeGroup.fromStructField(result.schema("test")) - val expectedAttrs = new AttributeGroup( - "test", - Array[Attribute]( - new 
BinaryAttribute(Some("a_baz:b"), Some(1)), - new BinaryAttribute(Some("a_bar:b"), Some(2)), - new BinaryAttribute(Some("a_foo:b"), Some(3)))) - assert(attrs === expectedAttrs) - } -} From b1bbaaed5a1ec50a9b16580829b6fd43960b057e Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 15 Sep 2015 19:57:31 -0700 Subject: [PATCH 21/27] Tue Sep 15 19:57:31 PDT 2015 --- .../apache/spark/ml/feature/Interaction.scala | 39 ++--- .../spark/ml/feature/InteractionSuite.scala | 139 +++++++++++++----- 2 files changed, 123 insertions(+), 55 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index d47207ae11f4..8fdc704fcb52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -33,14 +33,17 @@ import org.apache.spark.sql.types._ /** * :: Experimental :: - * Implements the transforms required for R-style feature interactions. In summary, once fitted to - * a dataset, this transformer jointly one-hot encodes all factor input columns, then scales - * the encoded vector by all numeric input columns. If only numeric columns are specified, the - * output column will be a one-length vector containing their product. During one-hot encoding, - * the last category will be preserved unless the interaction is trivial. + * Implements the transforms required for R-style feature interactions. This transformer takes in + * Double and Vector columns and outputs a flattened vector of interactions. To handle interaction, + * we first one-hot encode nominal columns. Then, a vector of all their cross-products is + * produced. + * + * For example, given the inputs Double(2), Vector(3, 4), the result would be Vector(6, 8) if + * all columns were numeric. If the first was nominal with three different values, the result + * would then be Vector(0, 0, 3, 0, 0, 4). 
* * See https://stat.ethz.ch/R-manual/R-devel/library/base/html/formula.html for more - * information about factor interactions in R formulae. + * information about interactions in R formulae. */ @Experimental class Interaction(override val uid: String) extends Transformer @@ -64,13 +67,13 @@ class Interaction(override val uid: String) extends Transformer checkParams() val fieldIterators = getIterators(dataset) - def interact(vv: Any*): Vector = { + def interactFunc = udf { row: Row => var indices = ArrayBuilder.make[Int] var values = ArrayBuilder.make[Double] var size = 1 indices += 0 values += 1.0 - var fieldIndex = vv.length - 1 + var fieldIndex = row.length - 1 while (fieldIndex >= 0) { val prevIndices = indices.result() val prevValues = values.result() @@ -79,7 +82,7 @@ class Interaction(override val uid: String) extends Transformer indices = ArrayBuilder.make[Int] values = ArrayBuilder.make[Double] size *= currentIterator.size - currentIterator.foreachActive(vv(fieldIndex), (i, a) => { + currentIterator.foreachActive(row(fieldIndex), (i, a) => { var j = 0 while (j < prevIndices.length) { indices += prevIndices(j) + i * prevSize @@ -92,9 +95,6 @@ class Interaction(override val uid: String) extends Transformer Vectors.sparse(size, indices.result(), values.result()).compressed } - val interactFunc = udf { r: Row => - interact(r.toSeq: _*) - } val args = $(inputCols).map { c => dataset.schema(c).dataType match { case DoubleType => dataset(c) @@ -125,7 +125,7 @@ class Interaction(override val uid: String) extends Transformer Array(getCardinality(Attribute.fromStructField(field))) case _: VectorUDT => val attrs = AttributeGroup.fromStructField(field).attributes.getOrElse( - throw new SparkException("Vector fields must have attributes defined.")) + throw new SparkException("Vector attributes must be defined for interaction.")) attrs.map(getCardinality).toArray } new EncodingIterator(cardinalities) @@ -138,10 +138,10 @@ class Interaction(override val uid: String) 
extends Transformer val attrIterator = field.dataType match { case _: NumericType | BooleanType => val attr = Attribute.fromStructField(field) - encodedAttrIterator(None, Seq(attr)) + getAttributesAfterEncoding(None, Seq(attr)) case _: VectorUDT => val group = AttributeGroup.fromStructField(field) - encodedAttrIterator(Some(group.name), group.attributes.get) + getAttributesAfterEncoding(Some(group.name), group.attributes.get) } if (attrs.isEmpty) { attrs = attrIterator @@ -156,7 +156,8 @@ class Interaction(override val uid: String) extends Transformer new AttributeGroup($(outputCol), attrs.toArray) } - private def encodedAttrIterator(groupName: Option[String], attrs: Seq[Attribute]): Seq[Attribute] = { + private def getAttributesAfterEncoding( + groupName: Option[String], attrs: Seq[Attribute]): Seq[Attribute] = { def format(i: Int, attrName: Option[String], value: Option[String]): String = { val parts = Seq(groupName, Some(attrName.getOrElse(i.toString)), value) parts.flatten.mkString("_") @@ -192,7 +193,7 @@ class Interaction(override val uid: String) extends Transformer * value if the field is numeric. Fields with zero cardinality will not be * one-hot encoded (output verbatim). */ -// TODO(ekl) do we want to support drop-last like OneHotEncoder does? +// TODO(ekl) support drop-last option like OneHotEncoder does private[ml] class EncodingIterator(cardinalities: Array[Int]) { /** The size of the output vector. 
*/ val size = cardinalities.map(i => if (i > 0) i else 1).sum @@ -220,9 +221,9 @@ private[ml] class EncodingIterator(cardinalities: Array[Int]) { while (i < dense.size) { val numOutputCols = cardinalities(i) if (numOutputCols > 0) { - val x = dense.values(i).toInt + val x = dense.values(i) assert(x >= 0.0 && x == x.toInt, s"Values from column must be indices, but got $x.") - f(cur + x, 1.0) + f(cur + x.toInt, 1.0) cur += numOutputCols } else { f(cur, dense.values(i)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index b89bce6dc11e..67feaafd7767 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -51,53 +51,120 @@ class RInteractionSuite extends SparkFunSuite with MLlibTestSparkContext { intercept[SparkException] { encode(Array(0), "foo") } intercept[SparkException] { encode(Array(0), null) } intercept[AssertionError] { encode(Array(1), 2.2) } - intercept[AssertionError] { encode(Array(1), Vectors.dense(2.2)) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(2.2)) } intercept[AssertionError] { encode(Array(1), Vectors.dense(1.0, 2.0, 3.0)) } } - test("new interaction") { + test("numeric interaction") { val data = sqlContext.createDataFrame( Seq( - (0, "foo", true, 4, Vectors.dense(0.0, 0.0, 1.0), Vectors.dense(5.0, 3.0)), - (0, "bar", true, 4, Vectors.dense(1.0, 4.0, 2.0), Vectors.dense(4.0, 3.0)), - (0, "bar", true, 5, Vectors.dense(2.0, 5.0, 3.0), Vectors.dense(5.0, 3.0)), - (0, "baz", true, 5, Vectors.dense(3.0, 8.0, 4.0), Vectors.dense(5.0, 2.0)), - (0, "baz", false, 5, Vectors.dense(4.0, 9.0, 8.0), Vectors.dense(7.0, 1.0)), - (1, "baz", false, 5, Vectors.dense(5.0, 2.0, 9.0), Vectors.dense(2.0, 0.0))) - ).toDF("id", "a", "bin", "b", "test", "test2") - val attrs = new AttributeGroup( - "test", + (2, Vectors.dense(3.0, 4.0)), + 
(1, Vectors.dense(1.0, 5.0))) + ).toDF("a", "b") + val groupAttr = new AttributeGroup( + "b", Array[Attribute]( - NominalAttribute.defaultAttr.withValues(Array("a", "b", "c", "d", "e", "f")), - NumericAttribute.defaultAttr.withName("magnitude"), - NominalAttribute.defaultAttr.withName("colors").withValues( - Array("green", "blue", "red", "violet", "yellow", - "orange", "black", "white", "azure", "gray")))) + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) + val df = data.select( + col("a").as("a", NumericAttribute.defaultAttr.toMetadata()), + col("b").as("b", groupAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(6.0, 8.0)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(1.0, 5.0))) + ).toDF("a", "b", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a:b_foo"), Some(1)), + new NumericAttribute(Some("a:b_bar"), Some(2)))) + assert(attrs === expectedAttrs) + } - val idAttr = NominalAttribute.defaultAttr.withValues(Array("red", "blue")) - val attrs2 = new AttributeGroup( - "test2", + test("nominal interaction") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0)), + (1, Vectors.dense(1.0, 5.0))) + ).toDF("a", "b") + val groupAttr = new AttributeGroup( + "b", Array[Attribute]( - NumericAttribute.defaultAttr, - NominalAttribute.defaultAttr.withValues(Array("one", "two", "three", "four")))) + NumericAttribute.defaultAttr.withName("foo"), + NumericAttribute.defaultAttr.withName("bar"))) val df = data.select( - col("id").as("id", idAttr.toMetadata()), col("b"), col("bin"), - col("test").as("test", attrs.toMetadata()), - 
col("test2").as("test2", attrs2.toMetadata())) - df.collect.foreach(println) - println(df.schema) - df.schema.foreach { field => - println(field.metadata) - } -// val trans = new Interaction().setInputCols(Array("id", "test2", "test")).setOutputCol("feature") - val trans = new Interaction().setInputCols(Array("id", "test2")).setOutputCol("feature") + col("a").as("a", + NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()), + col("b").as("b", groupAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") val res = trans.transform(df) - res.collect.foreach(println) - println(res.schema) - res.schema.foreach { field => - println(field.metadata) - } + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(3.0, 4.0), Vectors.dense(0, 0, 0, 0, 3, 4)), + (1, Vectors.dense(1.0, 5.0), Vectors.dense(0, 0, 1, 5, 0, 0))) + ).toDF("a", "b", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_up:b_foo"), Some(1)), + new NumericAttribute(Some("a_up:b_bar"), Some(2)), + new NumericAttribute(Some("a_down:b_foo"), Some(3)), + new NumericAttribute(Some("a_down:b_bar"), Some(4)), + new NumericAttribute(Some("a_left:b_foo"), Some(5)), + new NumericAttribute(Some("a_left:b_bar"), Some(6)))) + assert(attrs === expectedAttrs) } + +// test("new interaction") { +// val data = sqlContext.createDataFrame( +// Seq( +// (0, "foo", true, 4, Vectors.dense(0.0, 0.0, 1.0), Vectors.dense(5.0, 3.0)), +// (0, "bar", true, 4, Vectors.dense(1.0, 4.0, 2.0), Vectors.dense(4.0, 3.0)), +// (0, "bar", true, 5, Vectors.dense(2.0, 5.0, 3.0), Vectors.dense(5.0, 3.0)), +// (0, "baz", true, 5, Vectors.dense(3.0, 8.0, 4.0), Vectors.dense(5.0, 2.0)), +// (0, "baz", false, 5, Vectors.dense(4.0, 9.0, 8.0), Vectors.dense(7.0, 1.0)), +// (1, "baz", 
false, 5, Vectors.dense(5.0, 2.0, 9.0), Vectors.dense(2.0, 0.0))) +// ).toDF("id", "a", "bin", "b", "test", "test2") +// val attrs = new AttributeGroup( +// "test", +// Array[Attribute]( +// NominalAttribute.defaultAttr.withValues(Array("a", "b", "c", "d", "e", "f")), +// NumericAttribute.defaultAttr.withName("magnitude"), +// NominalAttribute.defaultAttr.withName("colors").withValues( +// Array("green", "blue", "red", "violet", "yellow", +// "orange", "black", "white", "azure", "gray")))) +// +// val idAttr = NominalAttribute.defaultAttr.withValues(Array("red", "blue")) +// val attrs2 = new AttributeGroup( +// "test2", +// Array[Attribute]( +// NumericAttribute.defaultAttr, +// NominalAttribute.defaultAttr.withValues(Array("one", "two", "three", "four")))) +// val df = data.select( +// col("id").as("id", idAttr.toMetadata()), col("b"), col("bin"), +// col("test").as("test", attrs.toMetadata()), +// col("test2").as("test2", attrs2.toMetadata())) +// df.collect.foreach(println) +// println(df.schema) +// df.schema.foreach { field => +// println(field.metadata) +// } +//// val trans = new Interaction().setInputCols(Array("id", "test2", "test")).setOutputCol("feature") +// val trans = new Interaction().setInputCols(Array("id", "test2")).setOutputCol("feature") +// val res = trans.transform(df) +// res.collect.foreach(println) +// println(res.schema) +// res.schema.foreach { field => +// println(field.metadata) +// } +// } // // test("parameter validation") { // val data = sqlContext.createDataFrame( From e258426099c91ec5ad50733df1046fd74e5cf764 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 15 Sep 2015 20:12:54 -0700 Subject: [PATCH 22/27] Tue Sep 15 20:12:54 PDT 2015 --- .../apache/spark/ml/feature/Interaction.scala | 8 +- .../spark/ml/feature/InteractionSuite.scala | 195 ++++-------------- 2 files changed, 46 insertions(+), 157 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 8fdc704fcb52..087735350630 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -207,7 +207,9 @@ private[ml] class EncodingIterator(cardinalities: Array[Int]) { assert(cardinalities.length == 1) val numOutputCols = cardinalities(0) if (numOutputCols > 0) { - assert(d >= 0.0 && d == d.toInt, s"Values from column must be indices, but got $d.") + assert( + d >= 0.0 && d == d.toInt && d < numOutputCols, + s"Values from column must be indices, but got $d.") f(d.toInt, 1.0) } else { f(0, d) @@ -222,7 +224,9 @@ private[ml] class EncodingIterator(cardinalities: Array[Int]) { val numOutputCols = cardinalities(i) if (numOutputCols > 0) { val x = dense.values(i) - assert(x >= 0.0 && x == x.toInt, s"Values from column must be indices, but got $x.") + assert( + x >= 0.0 && x == x.toInt && x < numOutputCols, + s"Values from column must be indices, but got $x.") f(cur + x.toInt, 1.0) cur += numOutputCols } else { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 67feaafd7767..e8918b40e2da 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -98,8 +98,8 @@ class RInteractionSuite extends SparkFunSuite with MLlibTestSparkContext { NumericAttribute.defaultAttr.withName("foo"), NumericAttribute.defaultAttr.withName("bar"))) val df = data.select( - col("a").as("a", - NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()), + col("a").as( + "a", NominalAttribute.defaultAttr.withValues(Array("up", "down", "left")).toMetadata()), col("b").as("b", groupAttr.toMetadata())) val trans = new Interaction().setInputCols(Array("a", "b")).setOutputCol("features") val res 
= trans.transform(df) @@ -122,157 +122,42 @@ class RInteractionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(attrs === expectedAttrs) } -// test("new interaction") { -// val data = sqlContext.createDataFrame( -// Seq( -// (0, "foo", true, 4, Vectors.dense(0.0, 0.0, 1.0), Vectors.dense(5.0, 3.0)), -// (0, "bar", true, 4, Vectors.dense(1.0, 4.0, 2.0), Vectors.dense(4.0, 3.0)), -// (0, "bar", true, 5, Vectors.dense(2.0, 5.0, 3.0), Vectors.dense(5.0, 3.0)), -// (0, "baz", true, 5, Vectors.dense(3.0, 8.0, 4.0), Vectors.dense(5.0, 2.0)), -// (0, "baz", false, 5, Vectors.dense(4.0, 9.0, 8.0), Vectors.dense(7.0, 1.0)), -// (1, "baz", false, 5, Vectors.dense(5.0, 2.0, 9.0), Vectors.dense(2.0, 0.0))) -// ).toDF("id", "a", "bin", "b", "test", "test2") -// val attrs = new AttributeGroup( -// "test", -// Array[Attribute]( -// NominalAttribute.defaultAttr.withValues(Array("a", "b", "c", "d", "e", "f")), -// NumericAttribute.defaultAttr.withName("magnitude"), -// NominalAttribute.defaultAttr.withName("colors").withValues( -// Array("green", "blue", "red", "violet", "yellow", -// "orange", "black", "white", "azure", "gray")))) -// -// val idAttr = NominalAttribute.defaultAttr.withValues(Array("red", "blue")) -// val attrs2 = new AttributeGroup( -// "test2", -// Array[Attribute]( -// NumericAttribute.defaultAttr, -// NominalAttribute.defaultAttr.withValues(Array("one", "two", "three", "four")))) -// val df = data.select( -// col("id").as("id", idAttr.toMetadata()), col("b"), col("bin"), -// col("test").as("test", attrs.toMetadata()), -// col("test2").as("test2", attrs2.toMetadata())) -// df.collect.foreach(println) -// println(df.schema) -// df.schema.foreach { field => -// println(field.metadata) -// } -//// val trans = new Interaction().setInputCols(Array("id", "test2", "test")).setOutputCol("feature") -// val trans = new Interaction().setInputCols(Array("id", "test2")).setOutputCol("feature") -// val res = trans.transform(df) -// res.collect.foreach(println) 
-// println(res.schema) -// res.schema.foreach { field => -// println(field.metadata) -// } -// } -// -// test("parameter validation") { -// val data = sqlContext.createDataFrame( -// Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) -// ).toDF("id", "a", "b") -// def check(inputCols: Array[String], outputCol: String, expectOk: Boolean): Unit = { -// val interaction = new RInteraction() -// if (inputCols != null) { -// interaction.setInputCols(inputCols) -// } -// if (outputCol != null) { -// interaction.setOutputCol(outputCol) -// } -// if (expectOk) { -// interaction.transformSchema(data.schema) -// interaction.fit(data).transform(data).collect() -// } else { -// intercept[IllegalArgumentException] { -// interaction.fit(data) -// } -// intercept[IllegalArgumentException] { -// interaction.transformSchema(data.schema) -// } -// } -// } -// check(Array("a", "b"), "test", true) -// check(Array("id"), "test", true) -// check(Array("b"), "test", true) -// check(Array("b"), "test", true) -// check(Array(), "test", false) -// check(Array("a", "b", "b"), "id", false) -// check(Array("a", "b"), null, false) -// check(null, "test", false) -// } -// -// test("numeric interaction") { -// val interaction = new RInteraction() -// .setInputCols(Array("b", "c", "d")) -// .setOutputCol("test") -// val original = sqlContext.createDataFrame( -// Seq((1, 2, 4, 2), (2, 3, 4, 1)) -// ).toDF("a", "b", "c", "d") -// val model = interaction.fit(original) -// val result = model.transform(original) -// val expected = sqlContext.createDataFrame( -// Seq( -// (1, 2, 4, 2, 16.0), -// (2, 3, 4, 1, 12.0)) -// ).toDF("a", "b", "c", "d", "test") -// assert(result.collect() === expected.collect()) -// val attr = Attribute.decodeStructField(result.schema("test"), preserveName = true) -// val expectedAttr = new NumericAttribute(Some("b:c:d"), None) -// assert(attr === expectedAttr) -// } -// -// test("factor interaction") { -// val interaction = new RInteraction() -// 
.setInputCols(Array("a", "b")) -// .setOutputCol("test") -// val original = sqlContext.createDataFrame( -// Seq((1, "foo", "zq"), (2, "bar", "zq"), (3, "bar", "zz")) -// ).toDF("id", "a", "b") -// val model = interaction.fit(original) -// val result = model.transform(original) -// val expected = sqlContext.createDataFrame( -// Seq( -// (1, "foo", "zq", Vectors.dense(0.0, 1.0, 0.0, 0.0)), -// (2, "bar", "zq", Vectors.dense(1.0, 0.0, 0.0, 0.0)), -// (3, "bar", "zz", Vectors.dense(0.0, 0.0, 1.0, 0.0))) -// ).toDF("id", "a", "b", "test") -// assert(result.collect() === expected.collect()) -// val attrs = AttributeGroup.fromStructField(result.schema("test")) -// val expectedAttrs = new AttributeGroup( -// "test", -// Array[Attribute]( -// new BinaryAttribute(Some("a_bar:b_zq"), Some(1)), -// new BinaryAttribute(Some("a_foo:b_zq"), Some(2)), -// new BinaryAttribute(Some("a_bar:b_zz"), Some(3)), -// new BinaryAttribute(Some("a_foo:b_zz"), Some(4)))) -// assert(attrs === expectedAttrs) -// } -// -// test("factor numeric interaction") { -// val interaction = new RInteraction() -// .setInputCols(Array("a", "b")) -// .setOutputCol("test") -// val original = sqlContext.createDataFrame( -// Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5), (4, "baz", 5), (4, "baz", 5)) -// ).toDF("id", "a", "b") -// val model = interaction.fit(original) -// val result = model.transform(original) -// val expected = sqlContext.createDataFrame( -// Seq( -// (1, "foo", 4, Vectors.dense(0.0, 0.0, 4.0)), -// (2, "bar", 4, Vectors.dense(0.0, 4.0, 0.0)), -// (3, "bar", 5, Vectors.dense(0.0, 5.0, 0.0)), -// (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0)), -// (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0)), -// (4, "baz", 5, Vectors.dense(5.0, 0.0, 0.0))) -// ).toDF("id", "a", "b", "test") -// assert(result.collect() === expected.collect()) -// val attrs = AttributeGroup.fromStructField(result.schema("test")) -// val expectedAttrs = new AttributeGroup( -// "test", -// Array[Attribute]( -// new 
BinaryAttribute(Some("a_baz:b"), Some(1)), -// new BinaryAttribute(Some("a_bar:b"), Some(2)), -// new BinaryAttribute(Some("a_foo:b"), Some(3)))) -// assert(attrs === expectedAttrs) -// } + test("default attr names") { + val data = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(0.0, 4.0), 1.0), + (1, Vectors.dense(1.0, 5.0), 10.0)) + ).toDF("a", "b", "c") + val groupAttr = new AttributeGroup( + "b", + Array[Attribute]( + NominalAttribute.defaultAttr.withNumValues(2), + NumericAttribute.defaultAttr)) + val df = data.select( + col("a").as("a", NominalAttribute.defaultAttr.withNumValues(3).toMetadata()), + col("b").as("b", groupAttr.toMetadata()), + col("c").as("c", NumericAttribute.defaultAttr.toMetadata())) + val trans = new Interaction().setInputCols(Array("a", "b", "c")).setOutputCol("features") + val res = trans.transform(df) + val expected = sqlContext.createDataFrame( + Seq( + (2, Vectors.dense(0.0, 4.0), 1.0, Vectors.dense(0, 0, 0, 0, 0, 0, 1, 0, 4)), + (1, Vectors.dense(1.0, 5.0), 10.0, Vectors.dense(0, 0, 0, 0, 10, 50, 0, 0, 0))) + ).toDF("a", "b", "c", "features") + assert(res.collect() === expected.collect()) + val attrs = AttributeGroup.fromStructField(res.schema("features")) + val expectedAttrs = new AttributeGroup( + "features", + Array[Attribute]( + new NumericAttribute(Some("a_0:b_0_0:c"), Some(1)), + new NumericAttribute(Some("a_0:b_0_1:c"), Some(2)), + new NumericAttribute(Some("a_0:b_1:c"), Some(3)), + new NumericAttribute(Some("a_1:b_0_0:c"), Some(4)), + new NumericAttribute(Some("a_1:b_0_1:c"), Some(5)), + new NumericAttribute(Some("a_1:b_1:c"), Some(6)), + new NumericAttribute(Some("a_2:b_0_0:c"), Some(7)), + new NumericAttribute(Some("a_2:b_0_1:c"), Some(8)), + new NumericAttribute(Some("a_2:b_1:c"), Some(9)))) + assert(attrs === expectedAttrs) + } } From 060f9d708fb9075446cea97c1d688fef31f3a91e Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 15 Sep 2015 20:13:45 -0700 Subject: [PATCH 23/27] Tue Sep 15 20:13:45 PDT 2015 --- 
.../main/scala/org/apache/spark/ml/feature/Interaction.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 087735350630..24de7b7d8814 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -193,7 +193,7 @@ class Interaction(override val uid: String) extends Transformer * value if the field is numeric. Fields with zero cardinality will not be * one-hot encoded (output verbatim). */ -// TODO(ekl) support drop-last option like OneHotEncoder does +// TODO(ekl) support drop-last option like OneHotEncoder does? private[ml] class EncodingIterator(cardinalities: Array[Int]) { /** The size of the output vector. */ val size = cardinalities.map(i => if (i > 0) i else 1).sum From c3da8f8bf22a15b5bcbe91c9eddfef66abbb6faf Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 15 Sep 2015 20:14:26 -0700 Subject: [PATCH 24/27] Tue Sep 15 20:14:25 PDT 2015 --- .../main/scala/org/apache/spark/ml/feature/Interaction.scala | 4 ++-- .../scala/org/apache/spark/ml/feature/InteractionSuite.scala | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 24de7b7d8814..b6953f91f724 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -39,8 +39,8 @@ import org.apache.spark.sql.types._ * produced. * * For example, given the inputs Double(2), Vector(3, 4), the result would be Vector(6, 8) if - * all columns were numeric. If the first was nominal with three different values, the result - * would then be Vector(0, 0, 3, 0, 0, 4). + * all columns were numeric. 
If the first input was nominal with four different values, the result + * would then be Vector(0, 0, 0, 0, 3, 4, 0, 0). * * See https://stat.ethz.ch/R-manual/R-devel/library/base/html/formula.html for more * information about interactions in R formulae. diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index e8918b40e2da..e7ac58dca6b6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.functions.col -class RInteractionSuite extends SparkFunSuite with MLlibTestSparkContext { +class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { ParamsSuite.checkParams(new Interaction()) } @@ -53,6 +53,8 @@ class RInteractionSuite extends SparkFunSuite with MLlibTestSparkContext { intercept[AssertionError] { encode(Array(1), 2.2) } intercept[AssertionError] { encode(Array(3), Vectors.dense(2.2)) } intercept[AssertionError] { encode(Array(1), Vectors.dense(1.0, 2.0, 3.0)) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(-1)) } + intercept[AssertionError] { encode(Array(3), Vectors.dense(3)) } } test("numeric interaction") { From 92c828710cdd4ad4580dc06ea1b9ba51e2b5ed8f Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 15 Sep 2015 20:18:24 -0700 Subject: [PATCH 25/27] revert attributes change --- .../apache/spark/ml/attribute/attributes.scala | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index a7c10333c0d5..e479f169021d 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -124,28 +124,18 @@ private[attribute] trait AttributeFactory { private[attribute] def fromMetadata(metadata: Metadata): Attribute /** - * Creates an [[Attribute]] from a [[StructField]] instance, optionally preserving name. + * Creates an [[Attribute]] from a [[StructField]] instance. */ - private[ml] def decodeStructField(field: StructField, preserveName: Boolean): Attribute = { + def fromStructField(field: StructField): Attribute = { require(field.dataType.isInstanceOf[NumericType]) val metadata = field.metadata val mlAttr = AttributeKeys.ML_ATTR if (metadata.contains(mlAttr)) { - val attr = fromMetadata(metadata.getMetadata(mlAttr)) - if (preserveName) { - attr - } else { - attr.withName(field.name) - } + fromMetadata(metadata.getMetadata(mlAttr)).withName(field.name) } else { UnresolvedAttribute } } - - /** - * Creates an [[Attribute]] from a [[StructField]] instance. - */ - def fromStructField(field: StructField): Attribute = decodeStructField(field, false) } /** From 09cba2c1315e3f5c1f8579a4537c164018b47408 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 17 Sep 2015 00:35:09 -0700 Subject: [PATCH 26/27] first pass clean up validate params --- .../apache/spark/ml/feature/Interaction.scala | 205 ++++++++++-------- .../spark/ml/feature/InteractionSuite.scala | 24 +- 2 files changed, 131 insertions(+), 98 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index b6953f91f724..456b49eed1a3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -33,17 +33,14 @@ import org.apache.spark.sql.types._ /** * :: Experimental :: - * Implements the transforms required for R-style feature interactions. 
This transformer takes in - * Double and Vector columns and outputs a flattened vector of interactions. To handle interaction, - * we first one-hot encode nominal columns. Then, a vector of all their cross-products is + * Implements the feature interaction transform. This transformer takes in Double and Vector type + * columns and outputs a flattened vector of their feature interactions. To handle interaction, + * we first one-hot encode any nominal features. Then, a vector of the feature cross-products is * produced. * - * For example, given the inputs Double(2), Vector(3, 4), the result would be Vector(6, 8) if - * all columns were numeric. If the first input was nominal with four different values, the result - * would then be Vector(0, 0, 0, 0, 3, 4, 0, 0). - * - * See https://stat.ethz.ch/R-manual/R-devel/library/base/html/formula.html for more - * information about interactions in R formulae. + * For example, given the input feature values `Double(2)` and `Vector(3, 4)`, the output would be + * `Vector(6, 8)` if all input features were numeric. If the first feature was instead nominal + * with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`. 
*/ @Experimental class Interaction(override val uid: String) extends Transformer @@ -59,13 +56,15 @@ class Interaction(override val uid: String) extends Transformer // optimistic schema; does not contain any ML attributes override def transformSchema(schema: StructType): StructType = { - checkParams() - StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, true)) + validateParams() + StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false)) } override def transform(dataset: DataFrame): DataFrame = { - checkParams() - val fieldIterators = getIterators(dataset) + validateParams() + val inputFeatures = $(inputCols).map(c => dataset.schema(c)) + val featureEncoders = getFeatureEncoders(inputFeatures) + val featureAttrs = getFeatureAttrs(inputFeatures) def interactFunc = udf { row: Row => var indices = ArrayBuilder.make[Int] @@ -73,16 +72,16 @@ class Interaction(override val uid: String) extends Transformer var size = 1 indices += 0 values += 1.0 - var fieldIndex = row.length - 1 - while (fieldIndex >= 0) { + var featureIndex = row.length - 1 + while (featureIndex >= 0) { val prevIndices = indices.result() val prevValues = values.result() val prevSize = size - val currentIterator = fieldIterators(fieldIndex) + val currentEncoder = featureEncoders(featureIndex) indices = ArrayBuilder.make[Int] values = ArrayBuilder.make[Double] - size *= currentIterator.size - currentIterator.foreachActive(row(fieldIndex), (i, a) => { + size *= currentEncoder.outputSize + currentEncoder.foreachNonzeroOutput(row(featureIndex), (i, a) => { var j = 0 while (j < prevIndices.length) { indices += prevIndices(j) + i * prevSize @@ -90,79 +89,106 @@ class Interaction(override val uid: String) extends Transformer j += 1 } }) - fieldIndex -= 1 + featureIndex -= 1 } Vectors.sparse(size, indices.result(), values.result()).compressed } - val args = $(inputCols).map { c => - dataset.schema(c).dataType match { - case DoubleType => dataset(c) - case _: VectorUDT => 
dataset(c) - case _: NumericType | BooleanType => dataset(c).cast(DoubleType) + val featureCols = inputFeatures.map { f => + f.dataType match { + case DoubleType => dataset(f.name) + case _: VectorUDT => dataset(f.name) + case _: NumericType | BooleanType => dataset(f.name).cast(DoubleType) } } - val attrs = generateAttrs($(inputCols).map(col => dataset.schema(col))) dataset.select( col("*"), - interactFunc(struct(args: _*)).as($(outputCol), attrs.toMetadata())) + interactFunc(struct(featureCols: _*)).as($(outputCol), featureAttrs.toMetadata())) } - private def getIterators(dataset: DataFrame): Array[EncodingIterator] = { - def getCardinality(attr: Attribute): Int = { + /** + * Creates a feature encoder for each input column, which supports efficient iteration over + * one-hot encoded feature values. See also the class-level comment of [[FeatureEncoder]]. + * + * @param features The input feature columns to create encoders for. + */ + private def getFeatureEncoders(features: Seq[StructField]): Array[FeatureEncoder] = { + def getNumFeatures(attr: Attribute): Int = { attr match { case nominal: NominalAttribute => - nominal.getNumValues.getOrElse( - throw new SparkException("Nominal fields must have attr numValues defined.")) + math.max(1, nominal.getNumValues.getOrElse( + throw new SparkException("Nominal features must have attr numValues defined."))) case _ => - 0 // treated as numeric + 1 // numeric feature } } - $(inputCols).map { col => - val field = dataset.schema(col) - val cardinalities = field.dataType match { + features.map { f => + val numFeatures = f.dataType match { case _: NumericType | BooleanType => - Array(getCardinality(Attribute.fromStructField(field))) + Array(getNumFeatures(Attribute.fromStructField(f))) case _: VectorUDT => - val attrs = AttributeGroup.fromStructField(field).attributes.getOrElse( + val attrs = AttributeGroup.fromStructField(f).attributes.getOrElse( throw new SparkException("Vector attributes must be defined for interaction.")) - 
attrs.map(getCardinality).toArray + attrs.map(getNumFeatures).toArray } - new EncodingIterator(cardinalities) + new FeatureEncoder(numFeatures) }.toArray } - private def generateAttrs(schema: Seq[StructField]): AttributeGroup = { - var attrs: Seq[Attribute] = Nil - schema.reverse.map { field => - val attrIterator = field.dataType match { + /** + * Generates ML attributes for the output vector of all feature interactions. We make a best + * effort to generate reasonable names for output features, based on the concatenation of the + * interacting feature names and values delimited with `_`. When no feature name is specified, + * we fall back to using the feature index (e.g. `foo:bar_2_0` may indicate an interaction + * between the numeric `foo` feature and a nominal third feature from column `bar`). + * + * @param features The input feature columns to the Interaction transformer. + */ + private def getFeatureAttrs(features: Seq[StructField]): AttributeGroup = { + var featureAttrs: Seq[Attribute] = Nil + features.reverse.foreach { f => + val encodedAttrs = f.dataType match { case _: NumericType | BooleanType => - val attr = Attribute.fromStructField(field) - getAttributesAfterEncoding(None, Seq(attr)) + val attr = Attribute.fromStructField(f) + encodedFeatureAttrs(Seq(attr), None) case _: VectorUDT => - val group = AttributeGroup.fromStructField(field) - getAttributesAfterEncoding(Some(group.name), group.attributes.get) + val group = AttributeGroup.fromStructField(f) + encodedFeatureAttrs(group.attributes.get, Some(group.name)) } - if (attrs.isEmpty) { - attrs = attrIterator + if (featureAttrs.isEmpty) { + featureAttrs = encodedAttrs } else { - attrs = attrIterator.flatMap { head => - attrs.map { tail => + featureAttrs = encodedAttrs.flatMap { head => + featureAttrs.map { tail => NumericAttribute.defaultAttr.withName(head.name.get + ":" + tail.name.get) } } } } - new AttributeGroup($(outputCol), attrs.toArray) + new AttributeGroup($(outputCol), featureAttrs.toArray) } 
- private def getAttributesAfterEncoding( - groupName: Option[String], attrs: Seq[Attribute]): Seq[Attribute] = { - def format(i: Int, attrName: Option[String], value: Option[String]): String = { - val parts = Seq(groupName, Some(attrName.getOrElse(i.toString)), value) + /** + * Generates the output ML attributes for a single input feature. Each output feature name has + * up to three parts: the group name, feature name, and category name (for nominal features), + * each separated by an underscore. + * + * @param inputAttrs The attributes of the input feature. + * @param groupName Optional name of the input feature group (for Vector type features). + */ + private def encodedFeatureAttrs( + inputAttrs: Seq[Attribute], + groupName: Option[String]): Seq[Attribute] = { + + def format( + index: Int, + attrName: Option[String], + categoryName: Option[String]): String = { + val parts = Seq(groupName, Some(attrName.getOrElse(index.toString)), categoryName) parts.flatten.mkString("_") } - attrs.zipWithIndex.flatMap { + + inputAttrs.zipWithIndex.flatMap { case (nominal: NominalAttribute, i) => if (nominal.values.isDefined) { nominal.values.get.map( @@ -178,7 +204,7 @@ class Interaction(override val uid: String) extends Transformer override def copy(extra: ParamMap): Interaction = defaultCopy(extra) - private def checkParams(): Unit = { + override def validateParams(): Unit = { require(get(inputCols).isDefined, "Input cols must be defined first.") require(get(outputCol).isDefined, "Output col must be defined first.") require($(inputCols).length > 0, "Input cols must have non-zero length.") @@ -186,27 +212,41 @@ class Interaction(override val uid: String) extends Transformer } } -/** - * An iterator over VectorUDT or Double that one-hot encodes nominal fields in the output. +/** + * This class performs on-the-fly one-hot encoding of features as you iterate over them. 
To + * indicate which input features should be one-hot encoded, an array of the feature counts + * must be passed in ahead of time. * - * @param cardinalities An array defining the cardinality of each vector sub-field, or a single - * value if the field is numeric. Fields with zero cardinality will not be - * one-hot encoded (output verbatim). + * @param numFeatures Array of feature counts for each input feature. For nominal features this + * count is equal to the number of categories. For numeric features the count + * should be set to 1. */ -// TODO(ekl) support drop-last option like OneHotEncoder does? -private[ml] class EncodingIterator(cardinalities: Array[Int]) { +private[ml] class FeatureEncoder(numFeatures: Array[Int]) { + assert(numFeatures.forall(_ > 0), "Features counts must all be positive.") + /** The size of the output vector. */ - val size = cardinalities.map(i => if (i > 0) i else 1).sum + val outputSize = numFeatures.sum + + /** Precomputed offsets for the location of each output feature. */ + private val outputOffsets = { + val arr = new Array[Int](numFeatures.length) + for (i <- 1 until arr.length) { + arr(i) = arr(i - 1) + numFeatures(i - 1) + } + arr + } /** - * @param value The row to iterate over, either a Double or Vector. + * Given an input row of features, invokes the specific function for every non-zero output. + * + * @param value The row value to encode, either a Double or Vector. * @param f The callback to invoke on each non-zero (index, value) output pair. 
*/ - def foreachActive(value: Any, f: (Int, Double) => Unit) = value match { + def foreachNonzeroOutput(value: Any, f: (Int, Double) => Unit): Unit = value match { case d: Double => - assert(cardinalities.length == 1) - val numOutputCols = cardinalities(0) - if (numOutputCols > 0) { + assert(numFeatures.length == 1, "DoubleType columns should only contain one feature.") + val numOutputCols = numFeatures.head + if (numOutputCols > 1) { assert( d >= 0.0 && d == d.toInt && d < numOutputCols, s"Values from column must be indices, but got $d.") @@ -215,25 +255,18 @@ private[ml] class EncodingIterator(cardinalities: Array[Int]) { f(0, d) } case vec: Vector => - assert(cardinalities.length == vec.size, - s"Vector column size was ${vec.size}, expected ${cardinalities.length}") - val dense = vec.toDense - var i = 0 - var cur = 0 - while (i < dense.size) { - val numOutputCols = cardinalities(i) - if (numOutputCols > 0) { - val x = dense.values(i) + assert(numFeatures.length == vec.size, + s"Vector column size was ${vec.size}, expected ${numFeatures.length}") + vec.foreachActive { (i, v) => + val numOutputCols = numFeatures(i) + if (numOutputCols > 1) { assert( - x >= 0.0 && x == x.toInt && x < numOutputCols, - s"Values from column must be indices, but got $x.") - f(cur + x.toInt, 1.0) - cur += numOutputCols + v >= 0.0 && v == v.toInt && v < numOutputCols, + s"Values from column must be indices, but got $v.") + f(outputOffsets(i) + v.toInt, 1.0) } else { - f(cur, dense.values(i)) - cur += 1 + f(outputOffsets(i), v) } - i += 1 } case null => throw new SparkException("Values to interact cannot be null.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index e7ac58dca6b6..2beb62ca0823 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -31,26 +31,26 @@ class 
InteractionSuite extends SparkFunSuite with MLlibTestSparkContext { ParamsSuite.checkParams(new Interaction()) } - test("encoding iterator") { + test("feature encoder") { def encode(cardinalities: Array[Int], value: Any): Vector = { var indices = ArrayBuilder.make[Int] var values = ArrayBuilder.make[Double] - val iter = new EncodingIterator(cardinalities) - iter.foreachActive(value, (i, v) => { + val encoder = new FeatureEncoder(cardinalities) + encoder.foreachNonzeroOutput(value, (i, v) => { indices += i values += v }) - Vectors.sparse(iter.size, indices.result(), values.result()).compressed + Vectors.sparse(encoder.outputSize, indices.result(), values.result()).compressed } - assert(encode(Array(0), 2.2) === Vectors.dense(2.2)) + assert(encode(Array(1), 2.2) === Vectors.dense(2.2)) assert(encode(Array(3), Vectors.dense(1)) === Vectors.dense(0, 1, 0)) - assert(encode(Array(0, 0), Vectors.dense(1.1, 2.2)) === Vectors.dense(1.1, 2.2)) - assert(encode(Array(3, 0), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 0, 2.2)) - assert(encode(Array(2, 0), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 2.2)) - assert(encode(Array(2, 0, 1), Vectors.dense(0, 2.2, 0)) === Vectors.dense(1, 0, 2.2, 1)) - intercept[SparkException] { encode(Array(0), "foo") } - intercept[SparkException] { encode(Array(0), null) } - intercept[AssertionError] { encode(Array(1), 2.2) } + assert(encode(Array(1, 1), Vectors.dense(1.1, 2.2)) === Vectors.dense(1.1, 2.2)) + assert(encode(Array(3, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 0, 2.2)) + assert(encode(Array(2, 1), Vectors.dense(1, 2.2)) === Vectors.dense(0, 1, 2.2)) + assert(encode(Array(2, 1, 1), Vectors.dense(0, 2.2, 0)) === Vectors.dense(1, 0, 2.2, 0)) + intercept[SparkException] { encode(Array(1), "foo") } + intercept[SparkException] { encode(Array(1), null) } + intercept[AssertionError] { encode(Array(2), 2.2) } intercept[AssertionError] { encode(Array(3), Vectors.dense(2.2)) } intercept[AssertionError] { encode(Array(1), 
Vectors.dense(1.0, 2.0, 3.0)) } intercept[AssertionError] { encode(Array(3), Vectors.dense(-1)) } From 1ae9ef0db969a774d6fdf19a154b78529859a3fb Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Thu, 17 Sep 2015 13:21:10 -0700 Subject: [PATCH 27/27] comments 2 --- .../scala/org/apache/spark/ml/feature/Interaction.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 456b49eed1a3..9194763fb32f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import scala.collection.mutable.{ArrayBuffer, ArrayBuilder} +import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException import org.apache.spark.annotation.Experimental @@ -25,7 +25,7 @@ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.Identifiable -import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} +import org.apache.spark.ml.Transformer import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ @@ -230,8 +230,10 @@ private[ml] class FeatureEncoder(numFeatures: Array[Int]) { /** Precomputed offsets for the location of each output feature. */ private val outputOffsets = { val arr = new Array[Int](numFeatures.length) - for (i <- 1 until arr.length) { + var i = 1 + while (i < arr.length) { arr(i) = arr(i - 1) + numFeatures(i - 1) + i += 1 } arr }