Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,6 @@ public static StructType createStructType(StructField[] fields) {
throw new IllegalArgumentException("fields should have distinct names.");
}

return StructType$.MODULE$.apply(fields);
return new StructType(fields);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ object RowEncoder {
"fromString",
inputObject :: Nil)

case t @ ArrayType(et, _) => et match {
case t @ ArrayType(et, _, _) => et match {
case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
// TODO: validate input type for primitive array.
NewInstance(
Expand Down Expand Up @@ -152,7 +152,7 @@ object RowEncoder {
convertedKeys :: convertedValues :: Nil,
dataType = t)

case StructType(fields) =>
case StructType(fields, _) =>
val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>
val fieldValue = serializerFor(
ValidateExternalType(
Expand Down Expand Up @@ -259,7 +259,7 @@ object RowEncoder {
case StringType =>
Invoke(input, "toString", ObjectType(classOf[String]))

case ArrayType(et, nullable) =>
case ArrayType(et, nullable, _) =>
val arrayData =
Invoke(
MapObjects(deserializerFor(_), input, et),
Expand All @@ -284,7 +284,7 @@ object RowEncoder {
"toScalaMap",
keyData :: valueData :: Nil)

case schema @ StructType(fields) =>
case schema @ StructType(fields, _) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) =>
If(
Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ object Cast {
case (TimestampType, _: NumericType) => true
case (_: NumericType, _: NumericType) => true

case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
case (ArrayType(fromType, fn, _), ArrayType(toType, tn, _)) =>
canCast(fromType, toType) &&
resolvableNullability(fn || forceNullable(fromType, toType), tn)

Expand All @@ -72,7 +72,7 @@ object Cast {
canCast(fromValue, toValue) &&
resolvableNullability(fn || forceNullable(fromValue, toValue), tn)

case (StructType(fromFields), StructType(toFields)) =>
case (StructType(fromFields, _), StructType(toFields, _)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields).forall {
case (fromField, toField) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
input: String,
dataType: DataType): ExprCode = dataType match {
case s: StructType => createCodeForStruct(ctx, input, s)
case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
case ArrayType(elementType, _, _) => createCodeForArray(ctx, input, elementType)
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
// UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe.
case StringType => ExprCode("", "false", s"$input.clone()")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
"""

case a @ ArrayType(et, _) =>
case a @ ArrayType(et, _, _) =>
s"""
// Remember the current cursor so that we can calculate how many bytes are
// written later.
Expand Down Expand Up @@ -202,7 +202,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)}
"""

case a @ ArrayType(et, _) =>
case a @ ArrayType(et, _, _) =>
s"""
$arrayWriter.setOffset($index);
${writeArrayToBuffer(ctx, element, et, bufferHolder)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)

override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
case ArrayType(dt, _) if RowOrdering.isOrderable(dt) =>
case ArrayType(dt, _, _) if RowOrdering.isOrderable(dt) =>
TypeCheckResult.TypeCheckSuccess
case ArrayType(dt, _) =>
case ArrayType(dt, _, _) =>
TypeCheckResult.TypeCheckFailure(
s"$prettyName does not support sorting array of type ${dt.simpleString}")
case _ =>
Expand All @@ -123,9 +123,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
@transient
private lazy val lt: Comparator[Any] = {
val ordering = base.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(n: AtomicType, _, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
}

new Comparator[Any]() {
Expand All @@ -146,9 +146,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
@transient
private lazy val gt: Comparator[Any] = {
val ordering = base.dataType match {
case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(n: AtomicType, _, _) => n.ordering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(a: ArrayType, _, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case _ @ ArrayType(s: StructType, _, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
}

new Comparator[Any]() {
Expand Down Expand Up @@ -192,7 +192,7 @@ case class ArrayContains(left: Expression, right: Expression)
override def inputTypes: Seq[AbstractDataType] = right.dataType match {
case NullType => Seq()
case _ => left.dataType match {
case n @ ArrayType(element, _) => Seq(n, element)
case n @ ArrayType(element, _, _) => Seq(n, element)
case _ => Seq()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@ object ExtractValue {
resolver: Resolver): Expression = {

(child.dataType, extraction) match {
case (StructType(fields), NonNullLiteral(v, StringType)) =>
case (StructType(fields, _), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
GetStructField(child, ordinal, Some(fieldName))

case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) =>
case (ArrayType(StructType(fields, _), containsNull, _), NonNullLiteral(v, StringType)) =>
val fieldName = v.toString
val ordinal = findField(fields, fieldName, resolver)
GetArrayStructFields(child, fields(ordinal).copy(name = fieldName),
Expand All @@ -65,7 +65,7 @@ object ExtractValue {

case (otherType, _) =>
val errorMsg = otherType match {
case StructType(_) =>
case StructType(_, _) =>
s"Field name should be String Literal, but it's $extraction"
case other =>
s"Can't extract value from $child"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ abstract class ExplodeBase(child: Expression, position: Boolean)

// hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
override def elementSchema: StructType = child.dataType match {
case ArrayType(et, containsNull) =>
case ArrayType(et, containsNull, _) =>
if (position) {
new StructType()
.add("pos", IntegerType, false)
Expand All @@ -189,7 +189,7 @@ abstract class ExplodeBase(child: Expression, position: Boolean)

override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
child.dataType match {
case ArrayType(et, _) =>
case ArrayType(et, _, _) =>
val inputArray = child.eval(input).asInstanceOf[ArrayData]
if (inputArray == null) {
Nil
Expand Down Expand Up @@ -260,15 +260,15 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with
override def children: Seq[Expression] = child :: Nil

override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
case ArrayType(et, _) if et.isInstanceOf[StructType] =>
case ArrayType(et, _, _) if et.isInstanceOf[StructType] =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName should be array of struct type, not ${child.dataType}")
}

override def elementSchema: StructType = child.dataType match {
case ArrayType(et : StructType, _) => et
case ArrayType(et : StructType, _, _) => et
}

private lazy val numFields = elementSchema.fields.length
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ abstract class HashExpression[E] extends Expression {
val numBytes = s"$input.numBytes()"
s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);"

case ArrayType(et, containsNull) =>
case ArrayType(et, containsNull, _) =>
val index = ctx.freshName("index")
s"""
for (int $index = 0; $index < $input.numElements(); $index++) {
Expand All @@ -337,7 +337,7 @@ abstract class HashExpression[E] extends Expression {
}
"""

case StructType(fields) =>
case StructType(fields, _) =>
fields.zipWithIndex.map { case (field, index) =>
nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
}.mkString("\n")
Expand Down Expand Up @@ -386,7 +386,7 @@ abstract class InterpretedHashFunction {
case array: ArrayData =>
val elementType = dataType match {
case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType
case ArrayType(et, _) => et
case ArrayType(et, _, _) => et
}
var result = seed
var i = 0
Expand Down Expand Up @@ -418,7 +418,7 @@ abstract class InterpretedHashFunction {
val types: Array[DataType] = dataType match {
case udt: UserDefinedType[_] =>
udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray
case StructType(fields) => fields.map(_.dataType)
case StructType(fields, _) => fields.map(_.dataType)
}
var result = seed
var i = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ case class MapObjects private(
s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]"
case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)"
case ArrayType(et, _) =>
case ArrayType(et, _, _) =>
s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex)
case ObjectType(cls) if cls == classOf[Object] =>
s"$seq == null ? $array.length : $seq.size()" ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,16 @@ object ArrayType extends AbstractDataType {
*
* @param elementType The data type of values.
* @param containsNull Indicates if values have `null` values
* @param metadata The metadata of this array type.
*/
@DeveloperApi
case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType {
case class ArrayType(
elementType: DataType,
containsNull: Boolean,
metadata: Metadata = Metadata.empty) extends DataType {

protected def this(elementType: DataType, containsNull: Boolean) =
this(elementType, containsNull, Metadata.empty)

/** No-arg constructor for kryo. */
protected def this() = this(null, false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,12 @@ object DataType {
*/
private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
(left, right) match {
case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
case (ArrayType(leftElementType, _, _), ArrayType(rightElementType, _, _)) =>
equalsIgnoreNullability(leftElementType, rightElementType)
case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) =>
equalsIgnoreNullability(leftKeyType, rightKeyType) &&
equalsIgnoreNullability(leftValueType, rightValueType)
case (StructType(leftFields), StructType(rightFields)) =>
case (StructType(leftFields, _), StructType(rightFields, _)) =>
leftFields.length == rightFields.length &&
leftFields.zip(rightFields).forall { case (l, r) =>
l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType)
Expand All @@ -226,15 +226,15 @@ object DataType {
*/
private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = {
(from, to) match {
case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) =>
case (ArrayType(fromElement, fn, _), ArrayType(toElement, tn, _)) =>
(tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement)

case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
(tn || !fn) &&
equalsIgnoreCompatibleNullability(fromKey, toKey) &&
equalsIgnoreCompatibleNullability(fromValue, toValue)

case (StructType(fromFields), StructType(toFields)) =>
case (StructType(fromFields, _), StructType(toFields, _)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields).forall { case (fromField, toField) =>
fromField.name == toField.name &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,11 @@ import org.apache.spark.util.Utils
* }}}
*/
@DeveloperApi
case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] {
case class StructType(
fields: Array[StructField],
metadata: Metadata = Metadata.empty) extends DataType with Seq[StructField] {

def this(fields: Array[StructField]) = this(fields, Metadata.empty)

/** No-arg constructor for kryo. */
def this() = this(Array.empty[StructField])
Expand All @@ -106,7 +110,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru

override def equals(that: Any): Boolean = {
that match {
case StructType(otherFields) =>
case StructType(otherFields, _) =>
java.util.Arrays.equals(
fields.asInstanceOf[Array[AnyRef]], otherFields.asInstanceOf[Array[AnyRef]])
case _ => false
Expand Down Expand Up @@ -417,19 +421,19 @@ object StructType extends AbstractDataType {
}
}

def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray, Metadata.empty)

def apply(fields: java.util.List[StructField]): StructType = {
import scala.collection.JavaConverters._
StructType(fields.asScala)
apply(fields.asScala)
}

private[sql] def fromAttributes(attributes: Seq[Attribute]): StructType =
StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))

private[sql] def removeMetadata(key: String, dt: DataType): DataType =
dt match {
case StructType(fields) =>
case StructType(fields, _) =>
val newFields = fields.map { f =>
val mb = new MetadataBuilder()
f.copy(dataType = removeMetadata(key, f.dataType),
Expand All @@ -441,8 +445,8 @@ object StructType extends AbstractDataType {

private[sql] def merge(left: DataType, right: DataType): DataType =
(left, right) match {
case (ArrayType(leftElementType, leftContainsNull),
ArrayType(rightElementType, rightContainsNull)) =>
case (ArrayType(leftElementType, leftContainsNull, _),
ArrayType(rightElementType, rightContainsNull, _)) =>
ArrayType(
merge(leftElementType, rightElementType),
leftContainsNull || rightContainsNull)
Expand All @@ -454,7 +458,7 @@ object StructType extends AbstractDataType {
merge(leftValueType, rightValueType),
leftContainsNull || rightContainsNull)

case (StructType(leftFields), StructType(rightFields)) =>
case (StructType(leftFields, _), StructType(rightFields, _)) =>
val newFields = ArrayBuffer.empty[StructField]
// This metadata will record the fields that only exist in one of two StructTypes
val optionalMeta = new MetadataBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ object RandomDataGenerator {
case ShortType => randomNumeric[Short](
rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort))
case NullType => Some(() => null)
case ArrayType(elementType, containsNull) =>
case ArrayType(elementType, containsNull, _) =>
forType(elementType, nullable = containsNull, rand).map {
elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator())
}
Expand All @@ -220,7 +220,7 @@ object RandomDataGenerator {
keys.zip(values).toMap
}
}
case StructType(fields) =>
case StructType(fields, _) =>
val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field =>
forType(field.dataType, nullable = field.nullable, rand)
}
Expand Down Expand Up @@ -269,7 +269,7 @@ object RandomDataGenerator {
val fields = mutable.ArrayBuffer.empty[Any]
schema.fields.foreach { f =>
f.dataType match {
case ArrayType(childType, nullable) =>
case ArrayType(childType, nullable, _) =>
val data = if (f.nullable && rand.nextFloat() <= PROBABILITY_OF_NULL) {
null
} else {
Expand All @@ -286,7 +286,7 @@ object RandomDataGenerator {
arr
}
fields += data
case StructType(children) =>
case StructType(children, _) =>
fields += randomRow(rand, StructType(children))
case _ =>
val generator = RandomDataGenerator.forType(f.dataType, f.nullable, rand)
Expand Down
Loading