Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.ml.attribute

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.mllib.linalg.VectorUDT
import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField}

Expand All @@ -34,9 +34,10 @@ import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField}
* indices in the array.
*/
@DeveloperApi
@Since("1.4.0")
class AttributeGroup private (
val name: String,
val numAttributes: Option[Int],
@Since("1.4.0") val name: String,
@Since("1.4.0") val numAttributes: Option[Int],
attrs: Option[Array[Attribute]]) extends Serializable {

require(name.nonEmpty, "Cannot have an empty string for name.")
Expand All @@ -47,13 +48,15 @@ class AttributeGroup private (
* Creates an attribute group without attribute info.
* @param name name of the attribute group
*/
@Since("1.4.0")
// Passes None for both the number of attributes and the attribute array.
def this(name: String) = this(name, None, None)

/**
 * Creates an attribute group knowing only the number of attributes.
 * The attribute array itself is left undefined (None).
 * @param name name of the attribute group
 * @param numAttributes number of attributes
 */
@Since("1.4.0")
def this(name: String, numAttributes: Int) = this(name, Some(numAttributes), None)

/**
Expand All @@ -62,11 +65,13 @@ class AttributeGroup private (
* @param attrs array of attributes. Attributes will be copied with their corresponding indices in
* the array.
*/
@Since("1.4.0")
// Passes None for the number of attributes; only the attribute array is forwarded.
def this(name: String, attrs: Array[Attribute]) = this(name, None, Some(attrs))

/**
 * Optional array of attributes. When the constructor received an attribute array, every
 * element is passed through `withIndex` with its position in that array.
 * At most one of `numAttributes` and `attributes` can be defined.
 */
@Since("1.4.0")
val attributes: Option[Array[Attribute]] = attrs.map { group =>
  Array.tabulate[Attribute](group.length)(i => group(i).withIndex(i))
}
Expand All @@ -78,6 +83,7 @@ class AttributeGroup private (
}

/** Size of the attribute group. Returns -1 if the size is unknown. */
@Since("1.4.0")
def size: Int = {
if (numAttributes.isDefined) {
numAttributes.get
Expand All @@ -89,23 +95,29 @@ class AttributeGroup private (
}

/**
 * Test whether this attribute group contains a specific attribute.
 * @param attrName name of the attribute to look for
 * @return true if the group's name-to-index lookup has an entry for `attrName`
 */
@Since("1.4.0")
def hasAttr(attrName: String): Boolean = nameToIndex.contains(attrName)

/**
 * Index of an attribute specified by name.
 * NOTE(review): the behavior for an unknown name depends on `nameToIndex`, which is defined
 * outside this view; use `hasAttr` first if the name may be absent.
 * @param attrName name of the attribute
 * @return the index recorded for `attrName` in the name-to-index lookup
 */
@Since("1.4.0")
def indexOf(attrName: String): Int = nameToIndex(attrName)

/** Gets an attribute by its name: the name is resolved with `indexOf` and then looked up. */
@Since("1.4.0")
def apply(attrName: String): Attribute = {
  // Unwrap the attribute array before resolving the name, as the one-line form did.
  val group = attributes.get
  group(indexOf(attrName))
}

/** Gets an attribute by its name; named equivalent of `apply(attrName)`. */
@Since("1.4.0")
def getAttr(attrName: String): Attribute = apply(attrName)

/** Gets the attribute stored at the given position of `attributes`. */
@Since("1.4.0")
def apply(attrIndex: Int): Attribute = {
  val group = attributes.get
  group(attrIndex)
}

/** Gets an attribute by its index; named equivalent of `apply(attrIndex)`. */
@Since("1.4.0")
def getAttr(attrIndex: Int): Attribute = apply(attrIndex)

/** Converts to metadata without name. */
Expand Down Expand Up @@ -147,6 +159,7 @@ class AttributeGroup private (
}

/** Converts to ML metadata with some existing metadata. */
@Since("1.4.1")
def toMetadata(existingMetadata: Metadata): Metadata = {
new MetadataBuilder()
.withMetadata(existingMetadata)
Expand All @@ -158,13 +171,16 @@ class AttributeGroup private (
// Same as toMetadata(existingMetadata), called with Metadata.empty.
def toMetadata(): Metadata = toMetadata(Metadata.empty)

/**
 * Converts to a StructField with some existing metadata.
 * The field carries this group's name, a [[VectorUDT]] data type, is non-nullable, and uses
 * `toMetadata(existingMetadata)` as its metadata.
 */
@Since("1.4.0")
def toStructField(existingMetadata: Metadata): StructField = {
  StructField(
    name,
    new VectorUDT,
    nullable = false,
    metadata = toMetadata(existingMetadata))
}

/** Converts to a StructField, using `Metadata.empty` as the existing metadata. */
@Since("1.4.0")
def toStructField(): StructField = toStructField(Metadata.empty)

@Since("1.4.0")
override def equals(other: Any): Boolean = {
other match {
case o: AttributeGroup =>
Expand All @@ -176,6 +192,7 @@ class AttributeGroup private (
}
}

@Since("1.4.0")
override def hashCode: Int = {
var sum = 17
sum = 37 * sum + name.hashCode
Expand All @@ -184,6 +201,8 @@ class AttributeGroup private (
sum
}


/** String form of this group: the string rendering of `toMetadata()`. */
@Since("1.6.0")
override def toString: String = toMetadata().toString
}

Expand All @@ -192,6 +211,7 @@ class AttributeGroup private (
* Factory methods to create attribute groups.
*/
@DeveloperApi
@Since("1.4.0")
object AttributeGroup {

import AttributeKeys._
Expand Down Expand Up @@ -240,6 +260,7 @@ object AttributeGroup {
}

/** Creates an attribute group from a [[StructField]] instance. */
@Since("1.4.0")
def fromStructField(field: StructField): AttributeGroup = {
require(field.dataType == new VectorUDT)
if (field.metadata.contains(ML_ATTR)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,44 @@

package org.apache.spark.ml.attribute

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.annotation.{DeveloperApi, Since}

/**
* :: DeveloperApi ::
* An enum-like type for attribute types: [[AttributeType$#Numeric]], [[AttributeType$#Nominal]],
* [[AttributeType$#Binary]], and [[AttributeType$#Unresolved]].
*/
@DeveloperApi
sealed abstract class AttributeType(val name: String)
@Since("1.4.0")
sealed abstract class AttributeType(@Since("1.4.0") val name: String)

@DeveloperApi
@Since("1.4.0")
object AttributeType {

/**
 * Numeric type.
 * The backing case object is declared inside the initializer block, so the value is exposed
 * only as an [[AttributeType]] whose name is "numeric".
 */
@Since("1.4.0")
val Numeric: AttributeType = {
case object Numeric extends AttributeType("numeric")
Numeric
}

/**
 * Nominal type.
 * The backing case object is declared inside the initializer block, so the value is exposed
 * only as an [[AttributeType]] whose name is "nominal".
 */
@Since("1.4.0")
val Nominal: AttributeType = {
case object Nominal extends AttributeType("nominal")
Nominal
}

/**
 * Binary type.
 * The backing case object is declared inside the initializer block, so the value is exposed
 * only as an [[AttributeType]] whose name is "binary".
 */
@Since("1.4.0")
val Binary: AttributeType = {
case object Binary extends AttributeType("binary")
Binary
}

/** Unresolved type. */
@Since("1.5.0")
val Unresolved: AttributeType = {
case object Unresolved extends AttributeType("unresolved")
Unresolved
Expand All @@ -58,6 +64,7 @@ object AttributeType {
* Gets the [[AttributeType]] object from its name.
* @param name attribute type name: "numeric", "nominal", or "binary"
*/
@Since("1.4.0")
def fromName(name: String): AttributeType = {
if (name == Numeric.name) {
Numeric
Expand Down
Loading