diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/Attribute.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/Attribute.scala new file mode 100644 index 0000000000000..afb03b488ac56 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/Attribute.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.ml.attribute
+
+import org.apache.spark.sql.types.{Metadata, MetadataBuilder}
+
+/**
+ * Base description of a single feature: its index, an optional name and its dimension.
+ * Concrete subclasses report a [[FeatureType]] and add their own fields in [[toMetadata]].
+ *
+ * @param index index of the feature; must be non-negative
+ * @param name optional name of the feature
+ * @param dimension dimension of the feature; must be at least 1
+ */
+abstract class Attribute(
+    val index: Int,
+    val name: Option[String],
+    val dimension: Int) {
+
+  require(index >= 0, s"Attribute index must be non-negative but was $index")
+  require(dimension >= 1, s"Attribute dimension must be at least 1 but was $dimension")
+
+  /** The feature type of this attribute. */
+  def featureType: FeatureType
+
+  /** Converts this attribute, including subclass-specific fields, to a [[Metadata]]. */
+  def toMetadata(): Metadata
+
+  /**
+   * Starts a [[MetadataBuilder]] holding the fields shared by every attribute, for subclasses
+   * to extend in [[toMetadata]].
+   *
+   * Always writes "index" and "type"; writes "name" only when it is defined and "dimension"
+   * only when it is greater than 1, matching the defaults that
+   * [[Attribute.parseBaseMetadata]] applies when those entries are absent.
+   */
+  private[attribute] def toBaseMetadata(): MetadataBuilder = {
+    val builder = new MetadataBuilder()
+    builder.putLong("index", index)
+    // Attribute.fromMetadata dispatches on "type", so it has to be written for a round trip.
+    builder.putString("type", Attribute.typeName(featureType))
+    name.foreach(n => builder.putString("name", n))
+    if (dimension > 1) {
+      builder.putLong("dimension", dimension)
+    }
+    builder
+  }
+
+}
+
+object Attribute {
+
+  /**
+   * Rebuilds an [[Attribute]] from metadata, choosing the subclass from the "type" entry.
+   *
+   * @param metadata metadata of one feature; must contain a string "type" entry
+   * @throws IllegalArgumentException if the feature type is neither Categorical nor Continuous
+   */
+  def fromMetadata(metadata: Metadata): Attribute = {
+    FeatureTypes.withName(metadata.getString("type")) match {
+      case Categorical => CategoricalAttribute.fromMetadata(metadata)
+      case Continuous => ContinuousAttribute.fromMetadata(metadata)
+      // The remaining feature types (Discrete, Binary) are not handled here; fail with a
+      // descriptive error instead of a MatchError.
+      case other =>
+        throw new IllegalArgumentException(s"Unsupported attribute feature type: $other")
+    }
+  }
+
+  /**
+   * Reads the fields shared by every attribute.
+   *
+   * @return (index, name, dimension); name is None when absent, dimension is 1 when absent
+   */
+  private[attribute] def parseBaseMetadata(metadata: Metadata): (Int, Option[String], Int) = {
+    val index = metadata.getLong("index").toInt
+    val name = if (metadata.contains("name")) Some(metadata.getString("name")) else None
+    val dimension = if (metadata.contains("dimension")) metadata.getLong("dimension").toInt else 1
+    (index, name, dimension)
+  }
+
+  /** Name written under "type" for a feature type; the inverse of [[FeatureTypes.withName]]. */
+  private[attribute] def typeName(featureType: FeatureType): String = featureType match {
+    case Continuous => "CONTINUOUS"
+    case Categorical => "CATEGORICAL"
+    case Discrete => "DISCRETE"
+    case Binary => "BINARY"
+  }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/CategoricalAttribute.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/CategoricalAttribute.scala
new file mode 100644
index 0000000000000..4c54839863211
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/CategoricalAttribute.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.attribute
+
+import org.apache.spark.sql.types.Metadata
+
+/**
+ * An [[Attribute]] whose feature type is [[Categorical]]. Instances are created through
+ * `CategoricalAttribute.fromMetadata`; the constructor is private.
+ *
+ * @param index index of the feature; must be non-negative
+ * @param name optional name of the feature
+ * @param dimension dimension of the feature; must be at least 1
+ * @param categories optional category values; when defined the array must be non-empty
+ * @param cardinality optional number of categories; when defined it must be positive
+ */
+class CategoricalAttribute private (
+    override val index: Int,
+    override val name: Option[String],
+    override val dimension: Int,
+    val categories: Option[Array[String]],
+    val cardinality: Option[Int]) extends Attribute(index, name, dimension) {
+
+  require(categories.forall(_.nonEmpty), "categories must be non-empty when defined")
+  require(cardinality.forall(_ > 0), s"cardinality must be positive but was $cardinality")
+
+  override def featureType: FeatureType = Categorical
+
+  /** Writes the base fields plus "categories" and "cardinality" when they are defined. */
+  override def toMetadata(): Metadata = {
+    val builder = toBaseMetadata()
+    categories.foreach(values => builder.putStringArray("categories", values))
+    cardinality.foreach(count => builder.putLong("cardinality", count.toLong))
+    builder.build()
+  }
+
+}
+
+private[attribute] object CategoricalAttribute {
+
+  /**
+   * Rebuilds a [[CategoricalAttribute]] from metadata.
+   *
+   * "cardinality" and "categories" are both optional. When only "categories" is present the
+   * cardinality is taken to be the number of categories; when both are present the number of
+   * categories must not exceed the declared cardinality.
+   *
+   * @throws IllegalArgumentException if more categories are listed than the declared
+   *                                  cardinality, or a constructor requirement fails
+   */
+  def fromMetadata(metadata: Metadata): CategoricalAttribute = {
+    val (index, name, dimension) = Attribute.parseBaseMetadata(metadata)
+
+    val declaredCardinality: Option[Int] =
+      if (metadata.contains("cardinality")) {
+        Some(metadata.getLong("cardinality").toInt)
+      } else {
+        None
+      }
+
+    val categories: Option[Array[String]] =
+      if (metadata.contains("categories")) {
+        Some(metadata.getStringArray("categories"))
+      } else {
+        None
+      }
+
+    // Only checked when both entries are present.
+    for (values <- categories; declared <- declaredCardinality) {
+      require(values.length <= declared,
+        s"Found ${values.length} categories but the declared cardinality is $declared")
+    }
+
+    // A declared cardinality wins; otherwise fall back to the number of listed categories.
+    val cardinality = declaredCardinality.orElse(categories.map(_.length))
+
+    new CategoricalAttribute(index, name, dimension, categories, cardinality)
+  }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/ContinuousAttribute.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/ContinuousAttribute.scala
new file mode 100644
index 0000000000000..137cc62921863
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/ContinuousAttribute.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.attribute
+
+import org.apache.spark.sql.types.Metadata
+
+/**
+ * An [[Attribute]] whose feature type is [[Continuous]], with optional value bounds.
+ * Instances are created through `ContinuousAttribute.fromMetadata`; the constructor is private.
+ *
+ * @param index index of the feature
+ * @param name optional name of the feature
+ * @param dimension dimension of the feature
+ * @param min optional lower bound
+ * @param max optional upper bound; when both bounds are defined, min must not exceed max
+ */
+class ContinuousAttribute private (
+    override val index: Int,
+    override val name: Option[String],
+    override val dimension: Int,
+    val min: Option[Double],
+    val max: Option[Double]) extends Attribute(index, name, dimension) {
+
+  // The bounds are compared only when both of them are present.
+  for (lower <- min; upper <- max) {
+    require(lower <= upper)
+  }
+
+  // NOTE(review): declared with `()` although Attribute.featureType and
+  // CategoricalAttribute.featureType are declared without; kept as-is for source compatibility.
+  override def featureType(): FeatureType = Continuous
+
+  /** Writes the base fields plus "min" and "max" when they are defined. */
+  override def toMetadata(): Metadata = {
+    val builder = toBaseMetadata()
+    min.foreach(lower => builder.putDouble("min", lower))
+    max.foreach(upper => builder.putDouble("max", upper))
+    builder.build()
+  }
+
+}
+
+private[attribute] object ContinuousAttribute {
+
+  /** Rebuilds a [[ContinuousAttribute]] from metadata; "min" and "max" are optional entries. */
+  def fromMetadata(metadata: Metadata): ContinuousAttribute = {
+    def optionalDouble(key: String): Option[Double] =
+      if (metadata.contains(key)) Some(metadata.getDouble(key)) else None
+
+    val (index, name, dimension) = Attribute.parseBaseMetadata(metadata)
+    val lowerBound = optionalDouble("min")
+    val upperBound = optionalDouble("max")
+    new ContinuousAttribute(index, name, dimension, lowerBound, upperBound)
+  }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/FeatureAttributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/FeatureAttributes.scala
new file mode 100644
index 0000000000000..fe69a6dc27203
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/FeatureAttributes.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.attribute
+
+import org.apache.spark.sql.types.{Metadata, MetadataBuilder}
+
+/**
+ * Representation of specialized information in a [[Metadata]] concerning
+ * data as machine learning features, with methods to access their associated attributes, like:
+ *
+ *  - type (continuous, categorical, etc.) as [[FeatureType]]
+ *  - optional feature name
+ *  - for categorical features, the category values
+ *  - for continuous values, maximum and minimum value
+ *  - dimension for vector-valued features
+ *
+ * This information is stored as a [[Metadata]] under key "features", and contains an array of
+ * [[Metadata]] inside that for each feature for which metadata is defined. An optional
+ * "producer" string is stored next to it. Example:
+ *
+ * {{{
+ * {
+ *   ...
+ *   "features" : [
+ *     {
+ *       "index": 0,
+ *       "name": "age",
+ *       "type": "CONTINUOUS",
+ *       "min": 0
+ *     },
+ *     {
+ *       "index": 5,
+ *       "name": "gender",
+ *       "type": "CATEGORICAL",
+ *       "categories" : [ "male", "female" ]
+ *     },
+ *     {
+ *       "index": 6,
+ *       "name": "customerType",
+ *       "type": "CATEGORICAL",
+ *       "cardinality": 10
+ *     },
+ *     {
+ *       "index": 7,
+ *       "name": "percentAllocations",
+ *       "type": "CONTINUOUS",
+ *       "dimension": 10,
+ *       "min": 0,
+ *       "max": 1
+ *     }
+ *   ],
+ *   "producer": "..."
+ *   ...
+ * }
+ * }}}
+ *
+ * Lookups by name and by index are backed by maps built from `attributes`; if several
+ * attributes share a name or an index, the one appearing last in the array is the one found.
+ *
+ * @param attributes the per-feature attributes
+ * @param producer optional value stored under the "producer" key
+ */
+class FeatureAttributes private (
+    val attributes: Array[Attribute],
+    val producer: Option[String]) {
+
+  // Only named attributes can be looked up by name.
+  private val nameToIndex: Map[String, Int] =
+    attributes.flatMap(att => att.name.map(n => n -> att.index)).toMap
+  private val indexToAttribute: Map[Int, Attribute] =
+    attributes.map(att => (att.index, att)).toMap
+  private val categoricalIndices: Array[Int] =
+    attributes.filter(_.featureType match {
+      case _: CategoricalFeatureType => true
+      case _ => false
+    }).map(_.index)
+
+  /** Returns the attribute with the given feature index, if there is one. */
+  def getFeatureAttribute(index: Int): Option[Attribute] = indexToAttribute.get(index)
+
+  /** Returns the index of the feature with the given name, if there is one. */
+  def getFeatureIndex(featureName: String): Option[Int] = nameToIndex.get(featureName)
+
+  /** Indices of attributes whose feature type is a [[CategoricalFeatureType]]. */
+  def categoricalFeatureIndices(): Array[Int] = categoricalIndices
+
+  /** Writes all attributes under "features" and, when defined, the producer under "producer". */
+  def toMetadata(): Metadata = {
+    val builder = new MetadataBuilder()
+    builder.putMetadataArray("features", attributes.map(_.toMetadata()))
+    producer.foreach(p => builder.putString("producer", p))
+    builder.build()
+  }
+
+}
+
+object FeatureAttributes {
+
+  /**
+   * Rebuilds a [[FeatureAttributes]] from metadata.
+   *
+   * @param metadata must contain a metadata array under "features"; "producer" is optional
+   */
+  def fromMetadata(metadata: Metadata): FeatureAttributes = {
+    val attributes = metadata.getMetadataArray("features").map(Attribute.fromMetadata(_))
+    val producer =
+      if (metadata.contains("producer")) Some(metadata.getString("producer")) else None
+    new FeatureAttributes(attributes, producer)
+  }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/FeatureType.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/FeatureType.scala
new file mode 100644
index 0000000000000..1ec9599be4696
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/FeatureType.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.attribute
+
+/** Root of the sealed hierarchy of feature types. */
+sealed trait FeatureType
+
+sealed trait ContinuousFeatureType extends FeatureType
+sealed trait CategoricalFeatureType extends FeatureType
+/** Discrete types are modelled as a refinement of continuous types. */
+sealed trait DiscreteFeatureType extends ContinuousFeatureType
+
+case object Continuous extends ContinuousFeatureType
+case object Categorical extends CategoricalFeatureType
+case object Discrete extends DiscreteFeatureType
+/** Binary is both a discrete and a categorical type. */
+case object Binary extends DiscreteFeatureType with CategoricalFeatureType
+
+object FeatureTypes {
+
+  /**
+   * Looks up a [[FeatureType]] by its upper-case name.
+   *
+   * @param name one of "CONTINUOUS", "CATEGORICAL", "DISCRETE" or "BINARY"
+   * @throws IllegalArgumentException if `name` is not one of the names above
+   */
+  def withName(name: String): FeatureType = name match {
+    case "CONTINUOUS" => Continuous
+    case "CATEGORICAL" => Categorical
+    case "DISCRETE" => Discrete
+    case "BINARY" => Binary
+    // Fail with a message that lists the valid names rather than a bare MatchError.
+    case other =>
+      throw new IllegalArgumentException(
+        s"Unknown feature type name: $other " +
+          "(expected one of CONTINUOUS, CATEGORICAL, DISCRETE, BINARY)")
+  }
+}