
Commit 685ddcf

[SPARK-5886][ML] Add StringIndexer as a feature transformer
This PR adds string indexer, which takes a column of string labels and outputs a double column with labels indexed by their frequency.

TODOs:
- [x] store feature to index map in output metadata

Author: Xiangrui Meng <[email protected]>

Closes #4735 from mengxr/SPARK-5886 and squashes the following commits:

d82575f [Xiangrui Meng] fix test
700e70f [Xiangrui Meng] rename LabelIndexer to StringIndexer
16a6f8c [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5886
457166e [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5886
f8b30f4 [Xiangrui Meng] update label indexer to output metadata
e81ec28 [Xiangrui Meng] Merge branch 'openhashmap-contains' into SPARK-5886-2
d6e6f1f [Xiangrui Meng] add contains to primitivekeyopenhashmap
748a69b [Xiangrui Meng] add contains to OpenHashMap
def3c5c [Xiangrui Meng] add LabelIndexer
1 parent d3792f5 commit 685ddcf
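
For context, a minimal usage sketch of the new transformer, based only on the API added in this commit. The column names and sample data are hypothetical, and `sc` is assumed to be an existing SparkContext:

import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc)
// A toy DataFrame with a string "category" column.
val df = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "a"), (3, "c"), (4, "a"))).toDF("id", "category")

// Fit the indexer on the string column and append an indexed double column.
val model = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .fit(df)

// "a" is most frequent, so it maps to 0.0; "b" and "c" get the remaining indices.
model.transform(df).show()

The indexed column also carries nominal attribute metadata with the label values, as written by StringIndexerModel.transform below.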

2 files changed: 178 additions and 0 deletions
StringIndexer.scala (new file: 126 additions, 0 deletions)

@@ -0,0 +1,126 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.apache.spark.SparkException
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.collection.OpenHashMap

/**
 * Base trait for [[StringIndexer]] and [[StringIndexerModel]].
 */
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {

  /** Validates and transforms the input schema. */
  protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    val map = this.paramMap ++ paramMap
    checkInputColumn(schema, map(inputCol), StringType)
    val inputFields = schema.fields
    val outputColName = map(outputCol)
    require(inputFields.forall(_.name != outputColName),
      s"Output column $outputColName already exists.")
    val attr = NominalAttribute.defaultAttr.withName(map(outputCol))
    val outputFields = inputFields :+ attr.toStructField()
    StructType(outputFields)
  }
}

/**
 * :: AlphaComponent ::
 * A label indexer that maps a string column of labels to an ML column of label indices.
 * The indices are in [0, numLabels), ordered by label frequencies.
 * So the most frequent label gets index 0.
 */
@AlphaComponent
class StringIndexer extends Estimator[StringIndexerModel] with StringIndexerBase {

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  // TODO: handle unseen labels

  override def fit(dataset: DataFrame, paramMap: ParamMap): StringIndexerModel = {
    val map = this.paramMap ++ paramMap
    val counts = dataset.select(map(inputCol)).map(_.getString(0)).countByValue()
    val labels = counts.toSeq.sortBy(-_._2).map(_._1).toArray
    val model = new StringIndexerModel(this, map, labels)
    Params.inheritValues(map, this, model)
    model
  }

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    validateAndTransformSchema(schema, paramMap)
  }
}

/**
 * :: AlphaComponent ::
 * Model fitted by [[StringIndexer]].
 */
@AlphaComponent
class StringIndexerModel private[ml] (
    override val parent: StringIndexer,
    override val fittingParamMap: ParamMap,
    labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {

  private val labelToIndex: OpenHashMap[String, Double] = {
    val n = labels.length
    val map = new OpenHashMap[String, Double](n)
    var i = 0
    while (i < n) {
      map.update(labels(i), i)
      i += 1
    }
    map
  }

  /** @group setParam */
  def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  def setOutputCol(value: String): this.type = set(outputCol, value)

  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
    val map = this.paramMap ++ paramMap
    val indexer = udf { label: String =>
      if (labelToIndex.contains(label)) {
        labelToIndex(label)
      } else {
        // TODO: handle unseen labels
        throw new SparkException(s"Unseen label: $label.")
      }
    }
    val outputColName = map(outputCol)
    val metadata = NominalAttribute.defaultAttr
      .withName(outputColName).withValues(labels).toStructField().metadata
    dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata))
  }

  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    validateAndTransformSchema(schema, paramMap)
  }
}
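
The heart of fit() is the frequency ordering: labels are counted, sorted by descending count, and assigned indices 0 to numLabels - 1, which labelToIndex then mirrors on the model side. A standalone sketch of that ordering logic in plain Scala (the sample labels are hypothetical; ties are broken arbitrarily, as in the code above):

val labels = Seq("a", "b", "c", "a", "a", "c")

// Count occurrences per label, mirroring dataset.select(...).countByValue() in fit().
val counts = labels.groupBy(identity).map { case (label, occurrences) => label -> occurrences.size }

// Sort by descending count and assign indices, mirroring counts.toSeq.sortBy(-_._2)
// and the while loop that populates labelToIndex.
val ordered = counts.toSeq.sortBy(-_._2).map(_._1)
val labelToIndex = ordered.zipWithIndex.map { case (label, i) => label -> i.toDouble }.toMap

// labelToIndex == Map("a" -> 0.0, "c" -> 1.0, "b" -> 2.0)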
StringIndexerSuite.scala (new file: 52 additions, 0 deletions)

@@ -0,0 +1,52 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.feature

import org.scalatest.FunSuite

import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.SQLContext

class StringIndexerSuite extends FunSuite with MLlibTestSparkContext {
  private var sqlContext: SQLContext = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sqlContext = new SQLContext(sc)
  }

  test("StringIndexer") {
    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
    val df = sqlContext.createDataFrame(data).toDF("id", "label")
    val indexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .fit(df)
    val transformed = indexer.transform(df)
    val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
      .asInstanceOf[NominalAttribute]
    assert(attr.values.get === Array("a", "c", "b"))
    val output = transformed.select("id", "labelIndex").map { r =>
      (r.getInt(0), r.getDouble(1))
    }.collect().toSet
    // a -> 0, b -> 2, c -> 1
    val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
    assert(output === expected)
  }
}
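
The TODO comments in StringIndexer note that unseen labels are not yet handled: the indexing UDF throws a SparkException for any label absent at fit time. A sketch of an additional test exercising that path (hypothetical, not part of this commit; it assumes the same suite setup and uses a label "d" that never appears in the training data):

import org.apache.spark.SparkException

test("StringIndexer throws on unseen labels") {
  val df = sqlContext.createDataFrame(Seq((0, "a"), (1, "b"))).toDF("id", "label")
  val indexer = new StringIndexer()
    .setInputCol("label")
    .setOutputCol("labelIndex")
    .fit(df)
  val dfUnseen = sqlContext.createDataFrame(Seq((2, "d"))).toDF("id", "label")
  // Forcing evaluation of the indexed column should surface the SparkException
  // thrown by the UDF for the unseen label "d".
  intercept[SparkException] {
    indexer.transform(dfUnseen).collect()
  }
}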
