diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LabelBinarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LabelBinarizer.scala new file mode 100644 index 000000000000..2c97397335bd --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LabelBinarizer.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} +import org.apache.spark.mllib.linalg.{Matrices, MatrixUDT} +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * :: Experimental :: + * Binarize a column of comma-separated label strings into one 0/1 indicator matrix per value. 
+ */
+@Experimental
+@Since("2.0.0")
+final class LabelBinarizer @Since("2.0.0")(override val uid: String)
+  extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable {
+
+  /** Creates an instance whose `uid` is generated with the "labelBinarizer" prefix. */
+  @Since("2.0.0")
+  def this() = this(Identifiable.randomUID("labelBinarizer"))
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /**
+   * Appends the output column holding, for every input string, a dense indicator matrix.
+   *
+   * The input string is split on ","; entry `k` becomes row `k` of the matrix. Columns are the
+   * distinct entries in ascending string order, and a cell is 1.0 exactly when the row's entry
+   * equals the column's label (0.0 otherwise).
+   *
+   * @param dataset input data; its schema is validated by `transformSchema` first
+   * @return `dataset` with the additional output column
+   */
+  @Since("2.0.0")
+  override def transform(dataset: DataFrame): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
+
+    val binarize = udf { (labels: String) =>
+      val entries = labels.split(",")
+      val numRows = entries.length
+      // Column order is the ascending order of the distinct labels.
+      val classes = entries.distinct.sorted
+      val columnOf = classes.zipWithIndex.toMap
+      val values = new Array[Double](numRows * classes.length)
+      // Single pass over the entries: the buffer is column-major, so the cell for
+      // (row, column) lives at column * numRows + row.
+      var row = 0
+      while (row < numRows) {
+        values(columnOf(entries(row)) * numRows + row) = 1.0
+        row += 1
+      }
+      Matrices.dense(numRows, classes.length, values)
+    }
+    dataset.withColumn($(outputCol), binarize(col($(inputCol))))
+  }
+
+  /**
+   * Validates the input schema and derives the output schema.
+   *
+   * @param schema schema of the dataset to be transformed
+   * @return `schema` with a non-nullable matrix field named after the output column appended
+   * @throws IllegalArgumentException if the input column is not of string type or the output
+   *                                  column already exists
+   */
+  @Since("2.0.0")
+  override def transformSchema(schema: StructType): StructType = {
+    val outputColName = $(outputCol)
+    val inputType = schema($(inputCol)).dataType
+
+    if (inputType.typeName != "string") {
+      throw new IllegalArgumentException(s"Data type $inputType is not supported.")
+    }
+    if (schema.fieldNames.contains(outputColName)) {
+      throw new IllegalArgumentException(s"Output column $outputColName already exists.")
+    }
+    StructType(schema.fields :+ StructField(outputColName, new MatrixUDT, nullable = false))
+  }
+
+  /** Returns a copy of this transformer with `extra` applied on top of the current params. */
+  @Since("2.0.0")
+  override def copy(extra: ParamMap): LabelBinarizer = defaultCopy(extra)
+}
+
+/** Companion providing the default params-based reader for persisted instances. */
+@Since("2.0.0")
+object LabelBinarizer extends DefaultParamsReadable[LabelBinarizer] {
+
+  /** Loads a [[LabelBinarizer]] previously saved at `path`. */
+  @Since("2.0.0")
+  override def load(path: String): LabelBinarizer = super.load(path)
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LabelBinarizerSuiter.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/LabelBinarizerSuiter.scala new file mode 100644 index 000000000000..32b7faa1199a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LabelBinarizerSuiter.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.linalg.{Matrices, Matrix}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql._
+
+/** Tests for [[LabelBinarizer]]: indicator-matrix output and params persistence. */
+class LabelBinarizerSuiter
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  /**
+   * Runs a [[LabelBinarizer]] over `data` and asserts that every produced matrix equals the
+   * entry of `expected` at the same position.
+   *
+   * @param data comma-separated label strings, one per row
+   * @param expected expected indicator matrix for each element of `data`
+   */
+  private def checkBinarization(data: Array[String], expected: Array[Matrix]): Unit = {
+    val dataFrame: DataFrame = sqlContext.createDataFrame(
+      data.zip(expected)).toDF("feature", "expected")
+
+    val lBinarizer: LabelBinarizer = new LabelBinarizer()
+      .setInputCol("feature")
+      .setOutputCol("binarized_feature")
+
+    lBinarizer.transform(dataFrame).select("binarized_feature", "expected").collect().foreach {
+      case Row(x: Matrix, y: Matrix) =>
+        assert(x === y, "The feature value is not correct after binarization.")
+    }
+  }
+
+  /** Expected result for "pos,neg,neg,pos": column 0 marks "neg", column 1 marks "pos". */
+  private def twoClassMatrix: Matrix =
+    Matrices.dense(4, 2, Array(0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0))
+
+  /** Expected result for "yellow,green,red,green,0"; columns are "0", green, red, yellow. */
+  private def multiClassMatrix: Matrix =
+    Matrices.dense(5, 4,
+      Array(0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0,
+        0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0))
+
+  test("Label binarize one-class case") {
+    checkBinarization(
+      Array("pos,pos,pos,pos"),
+      Array(Matrices.dense(4, 1, Array(1.0, 1.0, 1.0, 1.0))))
+  }
+
+  test("Label binarize two-class case") {
+    checkBinarization(Array("pos,neg,neg,pos"), Array(twoClassMatrix))
+  }
+
+  test("Label binarize multi-class case") {
+    checkBinarization(Array("yellow,green,red,green,0"), Array(multiClassMatrix))
+  }
+
+  test("Label binarize combination case") {
+    checkBinarization(
+      Array("yellow,green,red,green,0", "pos,neg,neg,pos"),
+      Array(multiClassMatrix, twoClassMatrix))
+  }
+
+  test("read/write") {
+    // Persist the transformer under test; the original built an unrelated Binarizer here.
+    val t = new LabelBinarizer()
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+    testDefaultReadWrite(t)
+  }
+}