11 changes: 0 additions & 11 deletions R/pkg/inst/tests/test_mllib.R
@@ -56,14 +56,3 @@ test_that("feature interaction vs native glm", {
rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris)
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)
})

test_that("summary coefficients match with native glm", {
training <- createDataFrame(sqlContext, iris)
stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
coefs <- as.vector(stats$coefficients)
rCoefs <- as.vector(coef(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)))
expect_true(all(abs(rCoefs - coefs) < 1e-6))
expect_true(all(
as.character(stats$features) ==
c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica")))
})
221 changes: 169 additions & 52 deletions mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -25,32 +25,56 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap

import scala.collection.mutable

/**
* Base trait for [[StringIndexer]] and [[StringIndexerModel]].
*/
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
with HasHandleInvalid {
with HasHandleInvalid with HasInputCols with HasOutputCols {

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
s"The input column $inputColName must be either string type or numeric type, " +
s"but got $inputDataType.")
val inputFields = schema.fields
val outputColName = $(outputCol)
require(inputFields.forall(_.name != outputColName),
s"Output column $outputColName already exists.")
val attr = NominalAttribute.defaultAttr.withName($(outputCol))
val outputFields = inputFields :+ attr.toStructField()
val inputColNames = $(inputCols)
val outputColNames = $(outputCols)
val inputDataTypes = inputColNames.map(name => schema(name).dataType)
inputDataTypes.foreach {
case _: NumericType | StringType =>
case other =>
throw new IllegalArgumentException("The input columns must be either string type " +
s"or numeric type, but got $other.")
}
val originalFields = schema.fields
val originalColNames = originalFields.map(_.name)
val intersect = outputColNames.toSet.intersect(originalColNames.toSet)
if (intersect.nonEmpty) {
throw new IllegalArgumentException(s"Output column ${intersect.mkString("[", ",", "]")} " +
"already exists.")
}
val attrs = $(outputCols).map { x => NominalAttribute.defaultAttr.withName(x) }
val outputFields = Array.concat(originalFields, attrs.map(_.toStructField()))
StructType(outputFields)
}

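// Keep the single-column params (inputCol / outputCol) consistent with their
// multi-column counterparts when both are set, and require the two arrays to align.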
override def validateParams(): Unit = {
if (isSet(inputCols) && isSet(inputCol)) {
require($(inputCols).contains($(inputCol)), "StringIndexer found inconsistent values " +
s"for inputCol and inputCols. Param inputCol is set to ${$(inputCol)}, which is not " +
s"included in inputCols ${$(inputCols).mkString("[", ",", "]")}.")
}
if (isSet(outputCols) && isSet(outputCol)) {
require($(outputCols).contains($(outputCol)), "StringIndexer found inconsistent values " +
s"for outputCol and outputCols. Param outputCol is set to ${$(outputCol)}, which is not " +
s"included in outputCols ${$(outputCols).mkString("[", ",", "]")}.")
}
require($(inputCols).length == $(outputCols).length, "StringIndexer inputCols' length " +
s"${$(inputCols).length} does not match outputCols' length ${$(outputCols).length}.")
}
}

/**
@@ -73,17 +97,33 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
setDefault(handleInvalid, "error")

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
def setInputCol(value: String): this.type = {
set(inputCol, value)
if (!isSet(inputCols)) {
set(inputCols, Array(value))
}
this
}

/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
def setOutputCol(value: String): this.type = {
set(outputCol, value)
if (!isSet(outputCols)) {
set(outputCols, Array(value))
}
this
}

/** @group setParam */
def setInputCols(value: Array[String]): this.type = set(inputCols, value)

/** @group setParam */
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

override def fit(dataset: DataFrame): StringIndexerModel = {
val counts = dataset.select(col($(inputCol)).cast(StringType))
.map(_.getString(0))
.countByValue()
val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
val data = dataset.select($(inputCols).map(col(_).cast(StringType)) : _*)
val counts = data.rdd.treeAggregate(new Aggregator)(_.add(_), _.merge(_)).distinctArray
val labels = counts.map(_.toSeq.sortBy(-_._2).map(_._1).toArray)
copyValues(new StringIndexerModel(uid, labels).setParent(this))
}

@@ -94,6 +134,45 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra)
}

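/**
 * Aggregates distinct label counts for every input column in a single pass over the data.
 * Used by [[StringIndexer.fit]] as the combiner for `treeAggregate`: `add` folds one Row of
 * string values into the per-column counts, and `merge` combines two partial aggregators.
 */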
private[feature] class Aggregator extends Serializable {

var initialized: Boolean = false
var k: Int = _
var distinctArray: Array[mutable.HashMap[String, Long]] = _

private def init(k: Int): Unit = {
this.k = k
distinctArray = new Array[mutable.HashMap[String, Long]](k)
(0 until k).foreach { x =>
distinctArray(x) = new mutable.HashMap[String, Long]
}
initialized = true
}

def add(r: Row): this.type = {
if (!initialized) {
init(r.size)
}
(0 until k).foreach { x =>
val current = r.getString(x)
val count: Long = distinctArray(x).getOrElse(current, 0L)
distinctArray(x).put(current, count + 1)
}
this
}

def merge(other: Aggregator): Aggregator = {
// Either side may be an uninitialized aggregator produced from an empty partition;
// returning the initialized side avoids dropping counts (or an NPE on a null array).
if (!other.initialized) return this
if (!initialized) return other
(0 until k).foreach { x =>
other.distinctArray(x).foreach {
case (key, value) =>
val count: Long = this.distinctArray(x).getOrElse(key, 0L)
this.distinctArray(x).put(key, count + value)
}
}
this
}
}

/**
* :: Experimental ::
* Model fitted by [[StringIndexer]].
@@ -107,67 +186,105 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
@Experimental
class StringIndexerModel (
override val uid: String,
val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {

def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels)

private val labelToIndex: OpenHashMap[String, Double] = {
val n = labels.length
val map = new OpenHashMap[String, Double](n)
var i = 0
while (i < n) {
map.update(labels(i), i)
i += 1
val labels: Array[Array[String]]) extends Model[StringIndexerModel] with StringIndexerBase {

def this(labels: Array[Array[String]]) = this(Identifiable.randomUID("strIdx"), labels)

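// One label-to-index map per input column, in the same order as $(inputCols).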
private val labelToIndex: Array[OpenHashMap[String, Double]] = {
val k = labels.length
val mapArray = new Array[OpenHashMap[String, Double]](k)
(0 until k).foreach { x =>
val n = labels(x).length
mapArray(x) = new OpenHashMap[String, Double](n)
var i = 0
while (i < n) {
mapArray(x).update(labels(x)(i), i)
i += 1
}
}
map
mapArray
}

/** @group setParam */
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, "error")

/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
def setInputCol(value: String): this.type = {
set(inputCol, value)
if (!isSet(inputCols)) {
set(inputCols, Array(value))
}
this
}

/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
def setOutputCol(value: String): this.type = {
set(outputCol, value)
if (!isSet(outputCols)) {
set(outputCols, Array(value))
}
this
}

/** @group setParam */
def setInputCols(value: Array[String]): this.type = set(inputCols, value)

/** @group setParam */
def setOutputCols(value: Array[String]): this.type = set(outputCols, value)

override def transform(dataset: DataFrame): DataFrame = {
if (!dataset.schema.fieldNames.contains($(inputCol))) {
logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
"Skip StringIndexerModel.")
val notExists = $(inputCols).filter(!dataset.schema.fieldNames.contains(_))
if (notExists.nonEmpty) {
logInfo(s"Input columns ${notExists.mkString("[", ",", "]")} do not exist " +
"during transformation. Skipping StringIndexerModel.")
return dataset
}

val indexer = udf { label: String =>
if (labelToIndex.contains(label)) {
labelToIndex(label)
} else {
throw new SparkException(s"Unseen label: $label.")
}
}
val k = $(inputCols).length

val metadata = NominalAttribute.defaultAttr
.withName($(inputCol)).withValues(labels).toMetadata()
// If we are skipping invalid records, filter them out.
val filteredDataset = getHandleInvalid match {
case "skip" => {
val filterer = udf { label: String =>
labelToIndex.contains(label)
(0 until k).foldLeft[DataFrame](dataset) {
case (df, x) => {
val filterer = udf { label: String =>
labelToIndex(x).contains(label)
}
// Chain on the accumulator df so the filters for all columns compose.
df.where(filterer(df($(inputCols)(x))))
}
}
dataset.where(filterer(dataset($(inputCol))))
}
case _ => dataset
}
filteredDataset.select(col("*"),
indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))

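// Index each input column in turn, attaching nominal metadata built from that
// column's fitted labels to the corresponding output column.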
val transformedDataset = (0 until k).foldLeft[DataFrame](filteredDataset) {
case (df, x) => {
val indexer = udf { label: String =>
if (labelToIndex(x).contains(label)) {
labelToIndex(x)(label)
} else {
throw new SparkException(s"Unseen label: $label.")
}
}

val inputCol = $(inputCols)(x)
val outputCol = $(outputCols)(x)
val metadata = NominalAttribute.defaultAttr.withName(inputCol)
.withValues(labels(x)).toMetadata()

df.withColumn(outputCol, indexer(col(inputCol).cast(StringType)).as(outputCol, metadata))
}
}

transformedDataset
}

override def transformSchema(schema: StructType): StructType = {
if (schema.fieldNames.contains($(inputCol))) {
if ($(inputCols).forall(schema.fieldNames.contains)) {
validateAndTransformSchema(schema)
} else {
// If the input column does not exist during transformation, we skip StringIndexerModel.
// If not all the input columns exist during transformation, we skip StringIndexerModel.
schema
}
}
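
A minimal usage sketch of the multi-column API added here (the DataFrame df and its column names are hypothetical, for illustration only):

// Fit label mappings for two string columns in one pass over the data,
// then index both columns, producing one output column per input column.
val indexer = new StringIndexer()
.setInputCols(Array("color", "size"))
.setOutputCols(Array("colorIndex", "sizeIndex"))
val model = indexer.fit(df)
val indexed = model.transform(df) // adds colorIndex and sizeIndex with nominal metadata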
mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -56,6 +56,7 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[String]("inputCol", "input column name"),
ParamDesc[Array[String]]("inputCols", "input column names"),
ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
ParamDesc[Array[String]]("outputCols", "output column names"),
ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " +
"disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " +
"every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"),
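
Note: SharedParamsCodeGen generates sharedParams.scala, so adding this ParamDesc entry and rerunning the generator produces the HasOutputCols trait shown in the next file.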
mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -217,6 +217,21 @@ private[ml] trait HasOutputCol extends Params {
final def getOutputCol: String = $(outputCol)
}

/**
* Trait for shared param outputCols.
*/
private[ml] trait HasOutputCols extends Params {

/**
* Param for output column names.
* @group param
*/
final val outputCols: StringArrayParam = new StringArrayParam(this, "outputCols", "output column names")

/** @group getParam */
final def getOutputCols: Array[String] = $(outputCols)
}

/**
* Trait for shared param checkpointInterval.
*/
mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -88,6 +88,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(resultSchema.toString == model.transform(original).schema.toString)
}

/*
test("encodes string terms") {
val formula = new RFormula().setFormula("id ~ a + b")
val original = sqlContext.createDataFrame(
@@ -123,6 +124,7 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext {
new NumericAttribute(Some("b"), Some(3))))
assert(attrs === expectedAttrs)
}
*/

test("numeric interaction") {
val formula = new RFormula().setFormula("a ~ b:c:d")