diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java index 24adeadf9567..bab118609c66 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java @@ -214,6 +214,6 @@ public static StructType createStructType(StructField[] fields) { throw new IllegalArgumentException("fields should have distinct names."); } - return StructType$.MODULE$.apply(fields); + return new StructType(fields); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 67fca153b551..a1bbc7821dc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -119,7 +119,7 @@ object RowEncoder { "fromString", inputObject :: Nil) - case t @ ArrayType(et, _) => et match { + case t @ ArrayType(et, _, _) => et match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => // TODO: validate input type for primitive array. NewInstance( @@ -152,7 +152,7 @@ object RowEncoder { convertedKeys :: convertedValues :: Nil, dataType = t) - case StructType(fields) => + case StructType(fields, _) => val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) => val fieldValue = serializerFor( ValidateExternalType( @@ -259,7 +259,7 @@ object RowEncoder { case StringType => Invoke(input, "toString", ObjectType(classOf[String])) - case ArrayType(et, nullable) => + case ArrayType(et, nullable, _) => val arrayData = Invoke( MapObjects(deserializerFor(_), input, et), @@ -284,7 +284,7 @@ object RowEncoder { "toScalaMap", keyData :: valueData :: Nil) - case schema @ StructType(fields) => + case schema @ StructType(fields, _) => val convertedFields = fields.zipWithIndex.map { case (f, i) => If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b1e89b5de833..7ed47792fd26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -62,7 +62,7 @@ object Cast { case (TimestampType, _: NumericType) => true case (_: NumericType, _: NumericType) => true - case (ArrayType(fromType, fn), ArrayType(toType, tn)) => + case (ArrayType(fromType, fn, _), ArrayType(toType, tn, _)) => canCast(fromType, toType) && resolvableNullability(fn || forceNullable(fromType, toType), tn) @@ -72,7 +72,7 @@ object Cast { canCast(fromValue, toValue) && resolvableNullability(fn || forceNullable(fromValue, toValue), tn) - case (StructType(fromFields), StructType(toFields)) => + case (StructType(fromFields, _), StructType(toFields, _)) => fromFields.length == toFields.length && fromFields.zip(toFields).forall { case (fromField, toField) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index b891f9467375..a39fb2cb9d71 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -129,7 +129,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] input: String, dataType: DataType): ExprCode = dataType match { case s: StructType => createCodeForStruct(ctx, input, s) - case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) + case ArrayType(elementType, _, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. case StringType => ExprCode("", "false", s"$input.clone()") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 5efba4b3a608..5414c532ba92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -117,7 +117,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ - case a @ ArrayType(et, _) => + case a @ ArrayType(et, _, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. @@ -202,7 +202,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)} """ - case a @ ArrayType(et, _) => + case a @ ArrayType(et, _, _) => s""" $arrayWriter.setOffset($index); ${writeArrayToBuffer(ctx, element, et, bufferHolder)} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2e8ea1107cee..e17dd4a6f3f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -111,9 +111,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) override def checkInputDataTypes(): TypeCheckResult = base.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + case ArrayType(dt, _, _) if RowOrdering.isOrderable(dt) => TypeCheckResult.TypeCheckSuccess - case ArrayType(dt, _) => + case ArrayType(dt, _, _) => TypeCheckResult.TypeCheckFailure( s"$prettyName does not support sorting array of type ${dt.simpleString}") case _ => @@ -123,9 +123,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) @transient private lazy val lt: Comparator[Any] = { val ordering = base.dataType match { - case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(n: AtomicType, _, _) => 
n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } new Comparator[Any]() { @@ -146,9 +146,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) @transient private lazy val gt: Comparator[Any] = { val ordering = base.dataType match { - case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(n: AtomicType, _, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } new Comparator[Any]() { @@ -192,7 +192,7 @@ case class ArrayContains(left: Expression, right: Expression) override def inputTypes: Seq[AbstractDataType] = right.dataType match { case NullType => Seq() case _ => left.dataType match { - case n @ ArrayType(element, _) => Seq(n, element) + case n @ ArrayType(element, _, _) => Seq(n, element) case _ => Seq() } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3b4468f55ca7..1a134906e5b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -48,12 +48,12 @@ object ExtractValue { resolver: Resolver): Expression = { (child.dataType, extraction) match { - case (StructType(fields), NonNullLiteral(v, StringType)) => + case (StructType(fields, _), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) GetStructField(child, ordinal, Some(fieldName)) - case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => + case (ArrayType(StructType(fields, _), containsNull, _), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), @@ -65,7 +65,7 @@ object ExtractValue { case (otherType, _) => val errorMsg = otherType match { - case StructType(_) => + case StructType(_, _) => s"Field name should be String Literal, but it's $extraction" case other => s"Can't extract value from $child" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 9d5c856a23e2..e26270f4d963 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -165,7 +165,7 @@ abstract class ExplodeBase(child: Expression, position: Boolean) // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) override def elementSchema: StructType = child.dataType match { - case ArrayType(et, containsNull) => + case ArrayType(et, containsNull, _) => if (position) { new StructType() .add("pos", IntegerType, false) @@ 
-189,7 +189,7 @@ abstract class ExplodeBase(child: Expression, position: Boolean) override def eval(input: InternalRow): TraversableOnce[InternalRow] = { child.dataType match { - case ArrayType(et, _) => + case ArrayType(et, _, _) => val inputArray = child.eval(input).asInstanceOf[ArrayData] if (inputArray == null) { Nil @@ -260,7 +260,7 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with override def children: Seq[Expression] = child :: Nil override def checkInputDataTypes(): TypeCheckResult = child.dataType match { - case ArrayType(et, _) if et.isInstanceOf[StructType] => + case ArrayType(et, _, _) if et.isInstanceOf[StructType] => TypeCheckResult.TypeCheckSuccess case _ => TypeCheckResult.TypeCheckFailure( @@ -268,7 +268,7 @@ case class Inline(child: Expression) extends UnaryExpression with Generator with } override def elementSchema: StructType = child.dataType match { - case ArrayType(et : StructType, _) => et + case ArrayType(et : StructType, _, _) => et } private lazy val numFields = elementSchema.fields.length diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index d2c94ec1df4d..4dc5803b157e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -316,7 +316,7 @@ abstract class HashExpression[E] extends Expression { val numBytes = s"$input.numBytes()" s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" - case ArrayType(et, containsNull) => + case ArrayType(et, containsNull, _) => val index = ctx.freshName("index") s""" for (int $index = 0; $index < $input.numElements(); $index++) { @@ -337,7 +337,7 @@ abstract class HashExpression[E] extends Expression { } """ - case StructType(fields) => + case StructType(fields, _) => fields.zipWithIndex.map { case (field, index) => nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) }.mkString("\n") @@ -386,7 +386,7 @@ abstract class InterpretedHashFunction { case array: ArrayData => val elementType = dataType match { case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType - case ArrayType(et, _) => et + case ArrayType(et, _, _) => et } var result = seed var i = 0 @@ -418,7 +418,7 @@ abstract class InterpretedHashFunction { val types: Array[DataType] = dataType match { case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray - case StructType(fields) => fields.map(_.dataType) + case StructType(fields, _) => fields.map(_.dataType) } var result = seed var i = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index ea4dee174e74..c2ae7fc14903 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -441,7 +441,7 @@ case class MapObjects private( s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]" case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)" - case ArrayType(et, _) => + case ArrayType(et, _, _) => 
s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex) case ObjectType(cls) if cls == classOf[Object] => s"$seq == null ? $array.length : $seq.size()" -> diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 520e34436162..e51f07d28c09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -51,9 +51,16 @@ object ArrayType extends AbstractDataType { * * @param elementType The data type of values. * @param containsNull Indicates if values have `null` values + * @param metadata The metadata of this array type. */ @DeveloperApi -case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { +case class ArrayType( + elementType: DataType, + containsNull: Boolean, + metadata: Metadata = Metadata.empty) extends DataType { + + protected def this(elementType: DataType, containsNull: Boolean) = + this(elementType, containsNull, Metadata.empty) /** No-arg constructor for kryo. */ protected def this() = this(null, false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 4fc65cbce15b..3e5ad36229e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -196,12 +196,12 @@ object DataType { */ private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { (left, right) match { - case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => + case (ArrayType(leftElementType, _, _), ArrayType(rightElementType, _, _)) => equalsIgnoreNullability(leftElementType, rightElementType) case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) => equalsIgnoreNullability(leftKeyType, rightKeyType) && equalsIgnoreNullability(leftValueType, rightValueType) - case (StructType(leftFields), StructType(rightFields)) => + case (StructType(leftFields, _), StructType(rightFields, _)) => leftFields.length == rightFields.length && leftFields.zip(rightFields).forall { case (l, r) => l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType) @@ -226,7 +226,7 @@ object DataType { */ private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = { (from, to) match { - case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) => + case (ArrayType(fromElement, fn, _), ArrayType(toElement, tn, _)) => (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement) case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => @@ -234,7 +234,7 @@ object DataType { equalsIgnoreCompatibleNullability(fromKey, toKey) && equalsIgnoreCompatibleNullability(fromValue, toValue) - case (StructType(fromFields), StructType(toFields)) => + case (StructType(fromFields, _), StructType(toFields, _)) => fromFields.length == toFields.length && fromFields.zip(toFields).forall { case (fromField, toField) => fromField.name == toField.name && diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index dd4c88c4c43b..73c0ad738254 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -92,7 +92,11 @@ import org.apache.spark.util.Utils * }}} */ @DeveloperApi -case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { +case class StructType( + fields: Array[StructField], + metadata: Metadata = Metadata.empty) extends DataType with Seq[StructField] { + + def this(fields: Array[StructField]) = this(fields, Metadata.empty) /** No-arg constructor for kryo. */ def this() = this(Array.empty[StructField]) @@ -106,7 +110,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override def equals(that: Any): Boolean = { that match { - case StructType(otherFields) => + case StructType(otherFields, _) => java.util.Arrays.equals( fields.asInstanceOf[Array[AnyRef]], otherFields.asInstanceOf[Array[AnyRef]]) case _ => false @@ -417,11 +421,11 @@ object StructType extends AbstractDataType { } } - def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) + def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray, Metadata.empty) def apply(fields: java.util.List[StructField]): StructType = { import scala.collection.JavaConverters._ - StructType(fields.asScala) + apply(fields.asScala) } private[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = @@ -429,7 +433,7 @@ object StructType extends AbstractDataType { private[sql] def removeMetadata(key: String, dt: DataType): DataType = dt match { - case StructType(fields) => + case StructType(fields, _) => val newFields = fields.map { f => val mb = new MetadataBuilder() f.copy(dataType = removeMetadata(key, f.dataType), @@ -441,8 +445,8 @@ object StructType extends AbstractDataType { private[sql] def merge(left: DataType, right: DataType): DataType = (left, right) match { - case (ArrayType(leftElementType, leftContainsNull), - ArrayType(rightElementType, rightContainsNull)) => + case (ArrayType(leftElementType, leftContainsNull, _), + ArrayType(rightElementType, rightContainsNull, _)) => ArrayType( merge(leftElementType, rightElementType), leftContainsNull || rightContainsNull) @@ -454,7 +458,7 @@ object StructType extends AbstractDataType { merge(leftValueType, rightValueType), leftContainsNull || rightContainsNull) - case (StructType(leftFields), StructType(rightFields)) => + case (StructType(leftFields, _), StructType(rightFields, _)) => val newFields = ArrayBuffer.empty[StructField] // This metadata will record the fields that only exist in one of two StructTypes val optionalMeta = new MetadataBuilder() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 850869799507..a02bcbad3c2c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -196,7 +196,7 @@ object RandomDataGenerator { case ShortType => randomNumeric[Short]( rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort)) case NullType => Some(() => null) - case ArrayType(elementType, containsNull) => + case ArrayType(elementType, containsNull, _) => forType(elementType, nullable = containsNull, rand).map { elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) } @@ -220,7 +220,7 @@ object RandomDataGenerator { keys.zip(values).toMap } } - case StructType(fields) => + case StructType(fields, _) => val maybeFieldGenerators: 
Seq[Option[() => Any]] = fields.map { field => forType(field.dataType, nullable = field.nullable, rand) } @@ -269,7 +269,7 @@ object RandomDataGenerator { val fields = mutable.ArrayBuffer.empty[Any] schema.fields.foreach { f => f.dataType match { - case ArrayType(childType, nullable) => + case ArrayType(childType, nullable, _) => val data = if (f.nullable && rand.nextFloat() <= PROBABILITY_OF_NULL) { null } else { @@ -286,7 +286,7 @@ object RandomDataGenerator { arr } fields += data - case StructType(children) => + case StructType(children, _) => fields += randomRow(rand, StructType(children)) case _ => val generator = RandomDataGenerator.forType(f.dataType, f.nullable, rand) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index a1f9259f139e..8bcb41d7bf25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -376,7 +376,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { val encodedData = try { row.toSeq(encoder.schema).zip(schema).map { - case (a: ArrayData, AttributeReference(_, ArrayType(et, _), _, _)) => + case (a: ArrayData, AttributeReference(_, ArrayType(et, _, _), _, _)) => a.toArray[Any](et).toSeq case (other, _) => other diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index ec7be4d4b849..c991185dff1b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -77,7 +77,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { def getStructField(expr: Expression, fieldName: String): GetStructField = { expr.dataType match { - case StructType(fields) => + case StructType(fields, _) => val index = fields.indexWhere(_.name == fieldName) GetStructField(expr, index) } @@ -107,7 +107,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { def getArrayStructFields(expr: Expression, fieldName: String): GetArrayStructFields = { expr.dataType match { - case ArrayType(StructType(fields), containsNull) => + case ArrayType(StructType(fields, _), containsNull, _) => val field = fields.find(_.name == fieldName).get GetArrayStructFields(expr, field, fields.indexOf(field), fields.length, containsNull) } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index a18b881c78a0..08091dd8ab39 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -18,6 +18,8 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.ColumnDescriptor; @@ -29,6 +31,8 @@ import org.apache.parquet.schema.PrimitiveType; import 
org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -66,6 +70,11 @@ public class VectorizedColumnReader { */ private final int maxDefLevel; + /** + * Maximum repetition level for this column. + */ + private final int maxRepLevel; + /** * Repetition/Definition/Value readers. */ @@ -77,6 +86,9 @@ public class VectorizedColumnReader { // with `definitionLevelColumn`. private VectorizedRleValuesReader defColumn; + // Only set when reading a complex column. + private VectorizedRleValuesReader defColumnCopy; + /** * Total number of values in this column (in this row group). */ @@ -95,6 +107,7 @@ public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader this.descriptor = descriptor; this.pageReader = pageReader; this.maxDefLevel = descriptor.getMaxDefinitionLevel(); + this.maxRepLevel = descriptor.getMaxRepetitionLevel(); DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); if (dictionaryPage != null) { @@ -113,40 +126,85 @@ public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader throw new IOException("totalValueCount == 0"); } } + /** + * Whether this column is an element of a complex column. + */ + boolean asComplexColElement; /** - * Advances to the next value. Returns true if the value is non-null. + * The flag used in constructing nested records. When it is true, the previous status + * will be reset. */ - private boolean next() throws IOException { - if (valuesRead >= endOfPageValueCount) { - if (valuesRead >= totalValueCount) { - // How do we get here? Throw end of stream exception? - return false; - } - readPage(); - } - ++valuesRead; - // TODO: Don't read for flat schemas - //repetitionLevel = repetitionLevelColumn.nextInt(); - return definitionLevelColumn.nextInt() == maxDefLevel; - } + boolean resetNestedRecord = true; /** * Reads `total` values from this columnReader into column. */ - void readBatch(int total, ColumnVector column) throws IOException { + public void readBatch(int total, ColumnVector column) throws IOException { + asComplexColElement = column.getParentColumn() != null; + boolean isRepeatedColumn = maxRepLevel > 0; int rowId = 0; - while (total > 0) { + int repeatedRowId = 0; + int remaining = total; + + // The number of values to read. + int num = 0; + + // Stores row ids and offsets while constructing nested records. + int[] rowIds = new int[maxRepLevel + 2]; + int[] offsets = new int[maxRepLevel + 2]; + + // Keeps repetition levels and corresponding repetition counts. + int[] repetitions = new int[maxRepLevel + 2]; + + while (true) { // Compute the number of values we want to read in this page. int leftInPage = (int) (endOfPageValueCount - valuesRead); + + // Stop condition: + // If we are going to read data in a repeated column, the stop condition is that we have + // read `total` repeated records. E.g., if we want to read 5 records of an array-of-int column, + // we can't just read 5 integers. Instead, we have to read integers until 5 arrays have been put + // into this array column. + if (isRepeatedColumn) { + if (repeatedRowId == total) break; + } else { + if (remaining == 0) break; + } + + // Reaching the end of the current page. 
if (leftInPage == 0) { - readPage(); + boolean pageExists = readPage(); + if (!pageExists) { + if (!resetNestedRecord) { + insertRepeatedArray(column, rowIds, offsets, repetitions, total, 0); + resetNestedRecord = true; + repeatedRowId = rowIds[1]; + if (repeatedRowId == total) break; + } + // Should not reach here. + throw new IOException("Failed to read page. No page exists anymore!"); + } leftInPage = (int) (endOfPageValueCount - valuesRead); } - int num = Math.min(total, leftInPage); + + // Determine the number of values to read for this column in the current page. + if (asComplexColElement) { + // Using repetition and definition level encodings to construct nested/repeated records. + // When constructing nested/repeated records, we returns the number of values to read in + // this page for this column. + num = constructComplexRecords(column, repetitions, rowIds, offsets, leftInPage, total); + repeatedRowId = rowIds[1]; + } else { + // If this column is not a repeated/nested column, just read minimum of remaining values + // and all values left in the current page. + num = Math.min(remaining, leftInPage); + } + if (useDictionary) { // Read and decode dictionary ids. - ColumnVector dictionaryIds = column.reserveDictionaryIds(total); + int dictionaryCapacity = Math.max(remaining, rowId + num); + ColumnVector dictionaryIds = column.reserveDictionaryIds(dictionaryCapacity); defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); @@ -166,10 +224,13 @@ void readBatch(int total, ColumnVector column) throws IOException { } else { if (column.hasDictionary() && rowId != 0) { // This batch already has dictionary encoded values but this new page is not. The batch - // does not support a mix of dictionary and not so we will decode the dictionary. + // does not support a mix of dictionary and not, so we will decode the dictionary. decodeDictionaryIds(0, rowId, column, column.getDictionaryIds()); } column.setDictionary(null); + if (isRepeatedColumn) { + column.reserve(Math.max(remaining, rowId + num)); + } switch (descriptor.getType()) { case BOOLEAN: readBooleanBatch(rowId, num, column); @@ -199,10 +260,9 @@ void readBatch(int total, ColumnVector column) throws IOException { throw new IOException("Unsupported type: " + descriptor.getType()); } } - valuesRead += num; rowId += num; - total -= num; + remaining -= num; } } @@ -420,30 +480,35 @@ private void readFixedLenByteArrayBatch(int rowId, int num, } } - private void readPage() throws IOException { + private boolean readPage() throws IOException { DataPage page = pageReader.readPage(); - // TODO: Why is this a visitor? - page.accept(new DataPage.Visitor() { - @Override - public Void visit(DataPageV1 dataPageV1) { - try { - readPageV1(dataPageV1); - return null; - } catch (IOException e) { - throw new RuntimeException(e); + if (page == null) { + return false; + } else { + // TODO: Why is this a visitor? 
+ page.accept(new DataPage.Visitor() { + @Override + public Void visit(DataPageV1 dataPageV1) { + try { + readPageV1(dataPageV1); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } } - } - @Override - public Void visit(DataPageV2 dataPageV2) { - try { - readPageV2(dataPageV2); - return null; - } catch (IOException e) { - throw new RuntimeException(e); + @Override + public Void visit(DataPageV2 dataPageV2) { + try { + readPageV2(dataPageV2); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } } - } - }); + }); + return true; + } } private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) throws IOException { @@ -477,6 +542,292 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) thr } } + /** + * Inserts records into parent columns of a column. These parent columns are repeated columns. As + * the real data are read into the column, we only need to insert array into its repeated columns. + * @param column The ColumnVector which the data in the page are read into. + * @param rowIds Mapping between repetition levels and their current row ids for constructing. + * @param offsets The beginning offsets in columns which we use to construct nested records. + * @param repetitions Mapping between repetition levels and their corresponding counts. + * @param total The total number of rows to construct. + * @param repLevel The current repetition level. + */ + private void insertRepeatedArray( + ColumnVector column, + int[] rowIds, + int[] offsets, + int[] repetitions, + int total, + int repLevel) throws IOException { + ColumnVector parentRepeatedColumn = column; + int curRepLevel = maxRepLevel; + while (true) { + parentRepeatedColumn = parentRepeatedColumn.getNearestParentArrayColumn(); + if (parentRepeatedColumn != null) { + int parentColRepLevel = parentRepeatedColumn.getRepLevel(); + // The current repetition level means the beginning level of the current value. Thus, + // we only need to insert array into the parent columns whose repetition levels are + // equal to or more than the given repetition level. + if (parentColRepLevel >= repLevel) { + parentRepeatedColumn.reserve(rowIds[curRepLevel] + 1); + parentRepeatedColumn.putArray(rowIds[curRepLevel], + offsets[curRepLevel], repetitions[curRepLevel]); + + offsets[curRepLevel] += repetitions[curRepLevel]; + repetitions[curRepLevel] = 0; + rowIds[curRepLevel]++; + + // Increase the repetition count for parent repetition level as we add a new record. + if (curRepLevel > 1) { + repetitions[curRepLevel - 1]++; + } + + // In vectorization, the most outside repeated element is at the repetition 1. + if (curRepLevel == 1 && rowIds[curRepLevel] == total) { + return; + } + curRepLevel--; + } else { + break; + } + } else { + break; + } + } + } + + /** + * Finds the outside element of an inner element which is defined as Catalyst DataType, + * with the specified definition level. + * @param column The column as the beginning level for looking up the inner element. + * @param defLevel The specified definition level. + * @return the column which is the outside group element of the inner element. 
+ */ + private ColumnVector findInnerElementWithDefLevel(ColumnVector column, int defLevel) { + while (true) { + if (column == null) { + return null; + } + ColumnVector parent = column.getParentColumn(); + if (parent != null && parent.getDefLevel() == defLevel) { + ColumnVector outside = parent.getParentColumn(); + if (outside == null || outside.getDefLevel() < defLevel) { + return column; + } + } + column = parent; + } + } + + /** + * Finds the outside element of the inner element which is not defined as a Catalyst DataType, + * with the specified definition level. + * @param column The column as the beginning level for looking up the inner element. + * @param defLevel The specified definition level. + * @return the column which is the outside group element of the inner element. + */ + private ColumnVector findHiddenInnerElementWithDefLevel(ColumnVector column, int defLevel) { + while (true) { + if (column == null) { + return null; + } + ColumnVector parent = column.getParentColumn(); + if (parent != null && parent.getDefLevel() <= defLevel) { + ColumnVector outside = parent.getParentColumn(); + if (outside == null || outside.getDefLevel() < defLevel) { + return column; + } + } + column = parent; + } + } + + /** + * Checks if the given column is a legacy array in the Parquet schema. + * @param column The column we want to check if it is a legacy array. + * @return whether the given column is a legacy array in the Parquet schema. + */ + private boolean isLegacyArray(ColumnVector column) { + ColumnVector parent = column.getNearestParentArrayColumn(); + if (parent == null) { + return false; + } else if (parent.getRepLevel() <= maxRepLevel && parent.getDefLevel() < maxDefLevel) { + return true; + } + return false; + } + + /** + * Inserts a null record at the specified column. + * @param column The ColumnVector which the data in the page are read into. + * @param rowIds Mapping between repetition levels and their current row ids for constructing. + * @param repetitions Mapping between repetition levels and their corresponding counts. + */ + private void insertNullRecord( + ColumnVector column, + int[] rowIds, + int[] repetitions) { + int repLevel = column.getRepLevel(); + + if (repLevel == 0) { + repLevel = 1; + } + + rowIds[repLevel] += repetitions[repLevel]; + repetitions[repLevel] = 0; + + column.reserve(rowIds[repLevel] + 1); + column.putNull(rowIds[repLevel]); + rowIds[repLevel]++; + } + + /** + * Returns the array of repetition level values. + */ + private int[] getRepetitionLevels() throws IOException { + int[] repetitions = new int[this.pageValueCount]; + for (int i = 0; i < this.pageValueCount; i++) { + repetitions[i] = this.repetitionLevelColumn.nextInt(); + } + return repetitions; + } + + /** + * Iterates the values of definition and repetition levels for the values read in the page, + * and constructs complex records accordingly. + * @param column The ColumnVector which the data in the page are read into. + * @param repetitions Mapping between repetition levels and their counts. + * @param rowIds Mapping between repetition levels and their current row ids for constructing. + * @param offsets The beginning offsets in columns which we use to construct nested records. + * @param leftInPage The number of values that can be read in the current page. + * @param total The total number of rows to construct. + * @return the number of values that need to be read in the current page. 
+ */ + private int constructComplexRecords( + ColumnVector column, + int[] repetitions, + int[] rowIds, + int[] offsets, + int leftInPage, + int total) throws IOException { + for (int i = 0; i < leftInPage; i++) { + int repLevel = repetitionLevelColumn.nextInt(); + int defLevel = definitionLevelColumn.nextInt(); + + // If there are previous values and counts that need to be considered. + if (!resetNestedRecord) { + // When a new record begins at a lower repetition level, + // we insert an array into the repeated column. + if (repLevel < maxRepLevel) { + insertRepeatedArray(column, rowIds, offsets, repetitions, total, repLevel); + } + } + resetNestedRecord = false; + + // When the definition level is less than the max definition level, + // there is a null value. + if (defLevel < maxDefLevel) { + int offset = offsets[maxRepLevel]; + + // The null value is defined at the root level. + // Insert a null record. + if (repLevel == 0 && defLevel == 0) { + ColumnVector parent = column.getParentColumn(); + if (parent != null && parent.getDefLevel() == maxDefLevel + && parent.getRepLevel() == maxRepLevel) { + // A repeated element at the root level. + // E.g., the repeatedPrimitive in the following schema. + // Going to insert an empty record. + // messageType: message spark_schema { + // optional int32 optionalPrimitive; + // required int32 requiredPrimitive; + // + // repeated int32 repeatedPrimitive; + // + // optional group optionalMessage { + // optional int32 someId; + // } + // required group requiredMessage { + // optional int32 someId; + // } + // repeated group repeatedMessage { + // optional int32 someId; + // } + // } + insertRepeatedArray(column, rowIds, offsets, repetitions, total, repLevel); + } else { + // Obtain the outermost column. + ColumnVector topColumn = column.getParentColumn(); + while (topColumn.getParentColumn() != null) { + topColumn = topColumn.getParentColumn(); + } + + insertNullRecord(topColumn, rowIds, repetitions); + } + // Move to the next offset at the max repetition level as we processed the current value. + offsets[maxRepLevel]++; + resetNestedRecord = true; + } else if (isLegacyArray(column) && + column.getNearestParentArrayColumn().getDefLevel() == defLevel) { + // For a legacy array, if a null is defined at the repeated group column, it actually + // means an element with a null value. + + repetitions[maxRepLevel]++; + } else if (!column.getParentColumn().isArray() && + column.getParentColumn().getDefLevel() == defLevel) { + // A null element defined in the wrapping non-repeated group. + rowIds[1]++; + } else { + // An empty element defined in an outside group. + // E.g., the element in the following schema. + // messageType: message spark_schema { + // required int32 index; + // optional group col { + // optional float f1; + // optional group f2 (LIST) { + // repeated group list { + // optional boolean element; + // } + // } + // } + // } + ColumnVector parent = findInnerElementWithDefLevel(column, defLevel); + if (parent != null) { + // Found the group with the same definition level. + // Insert a null record at the definition level. + // E.g., R=0, D=1 for the above schema. + insertNullRecord(parent, rowIds, repetitions); + offsets[maxRepLevel]++; + resetNestedRecord = true; + } else { + // Found the group with a lower definition level. + // Insert an empty record. + // E.g., R=0, D=2 for the above schema. 
+ parent = findHiddenInnerElementWithDefLevel(column, defLevel); + insertRepeatedArray(column, rowIds, offsets, repetitions, total, repLevel); + offsets[maxRepLevel]++; + resetNestedRecord = true; + } + } + } else { + // Determine the repetition level of non-null values. + // A new record begins with non-null value. + if (maxRepLevel == 0) { + // A required record at root level. + repetitions[1]++; + insertRepeatedArray(column, rowIds, offsets, repetitions, total, maxRepLevel - 1); + } else { + // Repeated values. We increase repetition count. + repetitions[maxRepLevel]++; + } + } + // If we have constructed `total` records, return the number of values to read. + if (rowIds[1] == total) return i + 1; + } + // All `leftInPage` values in the current page are needed to read. + return leftInPage; + } + private void readPageV1(DataPageV1 page) throws IOException { this.pageValueCount = page.getValueCount(); ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); @@ -490,12 +841,20 @@ private void readPageV1(DataPageV1 page) throws IOException { this.defColumn = new VectorizedRleValuesReader(bitWidth); dlReader = this.defColumn; this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); - this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); try { byte[] bytes = page.getBytes().toByteArray(); rlReader.initFromPage(pageValueCount, bytes, 0); int next = rlReader.getNextOffset(); dlReader.initFromPage(pageValueCount, bytes, next); + + if (asComplexColElement) { + ValuesReader dlReaderCopy; + this.defColumnCopy = new VectorizedRleValuesReader(bitWidth); + dlReaderCopy = this.defColumnCopy; + this.definitionLevelColumn = new ValuesReaderIntIterator(dlReaderCopy); + dlReaderCopy.initFromPage(pageValueCount, bytes, next); + } + next = dlReader.getNextOffset(); initDataReader(page.getValueEncoding(), bytes, next); } catch (IOException e) { @@ -510,9 +869,16 @@ private void readPageV2(DataPageV2 page) throws IOException { int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); this.defColumn = new VectorizedRleValuesReader(bitWidth); - this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn); this.defColumn.initFromBuffer( this.pageValueCount, page.getDefinitionLevels().toByteArray()); + + if (asComplexColElement) { + this.defColumnCopy = new VectorizedRleValuesReader(bitWidth); + this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumnCopy); + this.defColumnCopy.initFromBuffer( + this.pageValueCount, page.getDefinitionLevels().toByteArray()); + } + try { initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0); } catch (IOException e) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 51bdf0f0f229..e7b2551989ef 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; import org.apache.spark.sql.execution.vectorized.ColumnarBatch; +import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.types.StructField; import 
org.apache.spark.sql.types.StructType; @@ -178,6 +179,7 @@ public void initBatch(MemoryMode memMode, StructType partitionColumns, } } + // Allocate ColumnVectors in ColumnarBatch columnarBatch = ColumnarBatch.allocate(batchSchema, memMode); if (partitionColumns != null) { int partitionIdx = sparkSchema.fields().length; @@ -188,12 +190,30 @@ public void initBatch(MemoryMode memMode, StructType partitionColumns, } // Initialize missing columns with nulls. - for (int i = 0; i < missingColumns.length; i++) { - if (missingColumns[i]) { - columnarBatch.column(i).putNulls(0, columnarBatch.capacity()); - columnarBatch.column(i).setIsConstant(); + int missingColumnIdx = 0; + int partitionIdxBase = missingColumns.length; + if (partitionColumns != null) { + partitionIdxBase = sparkSchema.fields().length; + } + for (int i = 0; i < columnarBatch.numFields(); i++) { + if (i < partitionIdxBase) { + missingColumnIdx = initColumnWithNulls(columnarBatch.column(i), missingColumnIdx); + } + } + } + + private int initColumnWithNulls(ColumnVector column, int missingColumnIdx) { + if (column.isComplex()) { + for (int j = 0; j < column.getChildColumnNums(); j++) { + missingColumnIdx = initColumnWithNulls(column.getChildColumn(j), missingColumnIdx); + } + } else { + if (missingColumns[missingColumnIdx++]) { + column.putNulls(0, columnarBatch.capacity()); + column.setIsConstant(); } } + return missingColumnIdx; } public void initBatch() { @@ -225,10 +245,10 @@ public boolean nextBatch() throws IOException { checkEndOfRowGroup(); int num = (int) Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned); - for (int i = 0; i < columnReaders.length; ++i) { - if (columnReaders[i] == null) continue; - columnReaders[i].readBatch(num, columnarBatch.column(i)); + for (int i = 0; i < columnarBatch.numFields(); i++) { + readBatchOnColumnVector(columnarBatch.column(i), num); } + rowsReturned += num; columnarBatch.setNumRows(num); numBatched = num; @@ -236,17 +256,23 @@ public boolean nextBatch() throws IOException { return true; } + private void readBatchOnColumnVector(ColumnVector column, int num) throws IOException { + if (column.hasColumnReader()) { + column.readBatch(num); + } else { + for (int j = 0; j < column.getChildColumnNums(); j++) { + readBatchOnColumnVector(column.getChildColumn(j), num); + } + } + } + private void initializeInternal() throws IOException, UnsupportedOperationException { /** * Check that the requested schema is supported. */ - missingColumns = new boolean[requestedSchema.getFieldCount()]; - for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { - Type t = requestedSchema.getFields().get(i); - if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { - throw new UnsupportedOperationException("Complex types not supported."); - } - + missingColumns = new boolean[requestedSchema.getColumns().size()]; + // For loop on each physical columns. 
+ for (int i = 0; i < requestedSchema.getColumns().size(); ++i) { String[] colPath = requestedSchema.getPaths().get(i); if (fileSchema.containsPath(colPath)) { ColumnDescriptor fd = fileSchema.getColumnDescription(colPath); @@ -265,6 +291,25 @@ private void initializeInternal() throws IOException, UnsupportedOperationExcept } } + private int setupColumnReader( + ColumnVector column, + VectorizedColumnReader[] columnReaders, + int readerIdx) { + if (column.isComplex()) { + column.setColumnReader(null); + for (int j = 0; j < column.getChildColumnNums(); j++) { + readerIdx = setupColumnReader(column.getChildColumn(j), columnReaders, readerIdx); + } + } else { + if (!missingColumns[readerIdx]) { + column.setColumnReader(columnReaders[readerIdx++]); + } else { + readerIdx++; + } + } + return readerIdx; + } + private void checkEndOfRowGroup() throws IOException { if (rowsReturned != totalCountLoadedSoFar) return; PageReadStore pages = reader.readNextRowGroup(); @@ -272,13 +317,25 @@ private void checkEndOfRowGroup() throws IOException { throw new IOException("expecting more rows but reached last block. Read " + rowsReturned + " out of " + totalRowCount); } + // Return physical columns stored in Parquet file. Not logical fields. + // For example, a nested StructType field in requestedSchema might have many columns. + // A column is always a primitive type column. List columns = requestedSchema.getColumns(); columnReaders = new VectorizedColumnReader[columns.size()]; + for (int i = 0; i < columns.size(); ++i) { if (missingColumns[i]) continue; columnReaders[i] = new VectorizedColumnReader(columns.get(i), pages.getPageReader(columns.get(i))); } + + // Associate ColumnReaders to ColumnVectors in ColumnarBatch. + int readerIdx = 0; + int partitionIdx = sparkSchema.fields().length; + for (int i = 0; i < columnarBatch.numFields(); i++) { + if (i >= partitionIdx) break; + readerIdx = setupColumnReader(columnarBatch.column(i), columnReaders, readerIdx); + } totalCountLoadedSoFar += pages.getRowCount(); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index bbbb796aca0d..ac5bc30786b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.io.IOException; import java.math.BigDecimal; import java.math.BigInteger; @@ -27,6 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.execution.datasources.parquet.VectorizedColumnReader; import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; @@ -179,9 +181,7 @@ public Object[] array() { public boolean isNullAt(int ordinal) { return data.isNullAt(offset + ordinal); } @Override - public boolean getBoolean(int ordinal) { - throw new UnsupportedOperationException(); - } + public boolean getBoolean(int ordinal) { return data.getBoolean(offset + ordinal); } @Override public byte getByte(int ordinal) { return data.getByte(offset + ordinal); } @@ -198,9 +198,7 @@ public short getShort(int ordinal) { public long getLong(int ordinal) { return data.getLong(offset + ordinal); } @Override - public float getFloat(int 
ordinal) { - throw new UnsupportedOperationException(); - } + public float getFloat(int ordinal) { return data.getFloat(offset + ordinal); } @Override public double getDouble(int ordinal) { return data.getDouble(offset + ordinal); } @@ -298,6 +296,11 @@ public void reserve(int requiredCapacity) { */ protected abstract void reserveInternal(int capacity); + /** + * Ensures that there is enough storage to store null information. + */ + protected abstract void reserveNulls(int capacity); + /** * Returns the number of nulls in this column. */ @@ -495,7 +498,8 @@ public void reserve(int requiredCapacity) { public abstract double getDouble(int rowId); /** - * Puts a byte array that already exists in this column. + * Puts an array that already exists in this column. + * This method only updates array length and offset data in this column. */ public abstract void putArray(int rowId, int offset, int length); @@ -841,6 +845,28 @@ public final int appendStruct(boolean isNull) { */ public final ColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + /** + * Returns the number of childColumns. + */ + public final int getChildColumnNums() { + if (childColumns == null) { + return 0; + } else { + return childColumns.length; + } + } + + /** + * Returns whether this ColumnVector represents complex types such as Array, Map, or Struct. + */ + public final boolean isComplex() { + if (type instanceof ArrayType || type instanceof StructType || type instanceof MapType) { + return true; + } else { + return false; + } + } + /** * Returns the elements appended. */ @@ -856,6 +882,16 @@ public final int appendStruct(boolean isNull) { */ public final void setIsConstant() { isConstant = true; } + /** + * Returns the definition level for this column. This value is valid only if isComplex() returns true. + */ + public final int getDefLevel() { return defLevel; } + + /** + * Returns the repetition level for this column. This value is valid only if isComplex() returns true. + */ + public final int getRepLevel() { return repLevel; } + /** * Maximum number of rows that can be stored in this column. */ @@ -889,6 +925,16 @@ public final int appendStruct(boolean isNull) { */ protected boolean isConstant; + /** + * Max definition level of this column. This value is valid only if this is a nested column. + */ + protected int defLevel; + + /** + * Max repetition level of this column. This value is valid only if this is a nested column. + */ + protected int repLevel; + /** * Default size of each array length value. This grows as necessary. */ @@ -926,6 +972,83 @@ public final int appendStruct(boolean isNull) { */ protected ColumnVector dictionaryIds; + /** + * Associated VectorizedColumnReader which is used to load data into this ColumnVector. + * If this is a complex type such as array or struct, the VectorizedColumnReader will be + * null. + */ + protected VectorizedColumnReader columnReader; + + /** + * The parent ColumnVector of this column. If this column is not an element of a nested column, + * then this is null. + */ + protected ColumnVector parentColumn; + + /** + * Sets the columnReader for this column. + */ + public void setColumnReader(VectorizedColumnReader columnReader) { + this.columnReader = columnReader; + } + + /** + * Sets the parent column for this column. + */ + public void setParentColumn(ColumnVector column) { + this.parentColumn = column; + } + + /** + * Returns the parent column for this column. 
+ */ + public ColumnVector getParentColumn() { + return this.parentColumn; + } + + /** + * The flag shows if the nearest parent column is initialized. + */ + private boolean isNearestParentArrayColumnInited = false; + + /** + * The nearest parent column which is an Array column. + */ + private ColumnVector nearestParentArrayColumn; + + /** + * Returns the nearest parent column which is an Array column. + */ + public ColumnVector getNearestParentArrayColumn() { + if (!isNearestParentArrayColumnInited) { + nearestParentArrayColumn = this.parentColumn; + while (nearestParentArrayColumn != null && !nearestParentArrayColumn.isArray()) { + nearestParentArrayColumn = nearestParentArrayColumn.parentColumn; + } + isNearestParentArrayColumnInited = true; + } + return nearestParentArrayColumn; + } + + /** + * Returns if this ColumnVector has initialized VectorizedColumnReader. + */ + public boolean hasColumnReader() { + return this.columnReader != null; + } + + /** + * Reads `total` values from associated columnReader into this column. + */ + public void readBatch(int total) throws IOException { + if (this.columnReader != null) { + this.columnReader.readBatch(total, this); + } else { + throw new RuntimeException("The reader of this ColumnVector is not initialized yet. " + + "Failed to call readBatch()."); + } + } + /** * Update the dictionary. */ @@ -949,6 +1072,7 @@ public ColumnVector reserveDictionaryIds(int capacity) { dictionaryIds.reset(); dictionaryIds.reserve(capacity); } + reserveNulls(capacity); return dictionaryIds; } @@ -972,27 +1096,47 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { DataType childType; int childCapacity = capacity; if (type instanceof ArrayType) { + ArrayType arrayType = (ArrayType)type; + if (arrayType.metadata().contains("defLevel")) { + this.defLevel = (int)arrayType.metadata().getLong("defLevel"); + this.repLevel = (int)arrayType.metadata().getLong("repLevel"); + } childType = ((ArrayType)type).elementType(); + this.defLevel = defLevel; } else { childType = DataTypes.ByteType; childCapacity *= DEFAULT_ARRAY_LENGTH; } this.childColumns = new ColumnVector[1]; this.childColumns[0] = ColumnVector.allocate(childCapacity, childType, memMode); + this.childColumns[0].setParentColumn(this); this.resultArray = new Array(this.childColumns[0]); this.resultStruct = null; } else if (type instanceof StructType) { + StructType structType = (StructType)type; + if (structType.metadata().contains("defLevel")) { + this.defLevel = (int)structType.metadata().getLong("defLevel"); + this.repLevel = (int)structType.metadata().getLong("repLevel"); + } StructType st = (StructType)type; this.childColumns = new ColumnVector[st.fields().length]; for (int i = 0; i < childColumns.length; ++i) { - this.childColumns[i] = ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode); + int fieldDefLevel = 0; + if (st.fields()[i].metadata().contains("defLevel")) { + fieldDefLevel = (int)st.fields()[i].metadata().getLong("defLevel"); + } + + this.childColumns[i] = + ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode); + this.childColumns[i].setParentColumn(this); } this.resultArray = null; this.resultStruct = new ColumnarBatch.Row(this.childColumns); } else if (type instanceof CalendarIntervalType) { // Two columns. Months as int. Microseconds as Long. 
this.childColumns = new ColumnVector[2]; - this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode); + this.childColumns[0] = + ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode); this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode); this.resultArray = null; this.resultStruct = new ColumnarBatch.Row(this.childColumns); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index f3afa8f938f8..13d618419445 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -77,6 +77,8 @@ public static ColumnarBatch allocate(StructType schema, MemoryMode memMode, int return new ColumnarBatch(schema, maxRows, memMode); } + public int numFields() { return columns.length; } + /** * Called to close all the columns in this batch. It is not valid to access the data after * calling this. This must be called at the end to clean up memory allocations. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 913a05a0aa0e..f708f8c9e17d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -445,8 +445,13 @@ protected void reserveInternal(int newCapacity) { } else { throw new RuntimeException("Unhandled " + type); } - this.nulls = Platform.reallocateMemory(nulls, elementsAppended, newCapacity); - Platform.setMemory(nulls + elementsAppended, (byte)0, newCapacity - elementsAppended); + reserveNulls(newCapacity); capacity = newCapacity; } + + @Override + protected void reserveNulls(int capacity) { + this.nulls = Platform.reallocateMemory(nulls, elementsAppended, capacity); + Platform.setMemory(nulls + elementsAppended, (byte)0, capacity - elementsAppended); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 85067df4ebf9..48ac775c409a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -399,53 +399,53 @@ protected void reserveInternal(int newCapacity) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { - System.arraycopy(this.arrayLengths, 0, newLengths, 0, elementsAppended); - System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, elementsAppended); + System.arraycopy(this.arrayLengths, 0, newLengths, 0, this.arrayLengths.length); + System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, this.arrayOffsets.length); } arrayLengths = newLengths; arrayOffsets = newOffsets; } else if (type instanceof BooleanType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, byteData.length); byteData = newData; } } else if (type instanceof ByteType) { if (byteData == 
null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, byteData.length); byteData = newData; } } else if (type instanceof ShortType) { if (shortData == null || shortData.length < newCapacity) { short[] newData = new short[newCapacity]; - if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); + if (shortData != null) System.arraycopy(shortData, 0, newData, 0, shortData.length); shortData = newData; } } else if (type instanceof IntegerType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { if (intData == null || intData.length < newCapacity) { int[] newData = new int[newCapacity]; - if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); + if (intData != null) System.arraycopy(intData, 0, newData, 0, intData.length); intData = newData; } } else if (type instanceof LongType || type instanceof TimestampType || DecimalType.is64BitDecimalType(type)) { if (longData == null || longData.length < newCapacity) { long[] newData = new long[newCapacity]; - if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); + if (longData != null) System.arraycopy(longData, 0, newData, 0, longData.length); longData = newData; } } else if (type instanceof FloatType) { if (floatData == null || floatData.length < newCapacity) { float[] newData = new float[newCapacity]; - if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended); + if (floatData != null) System.arraycopy(floatData, 0, newData, 0, floatData.length); floatData = newData; } } else if (type instanceof DoubleType) { if (doubleData == null || doubleData.length < newCapacity) { double[] newData = new double[newCapacity]; - if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); + if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, doubleData.length); doubleData = newData; } } else if (resultStruct != null) { @@ -454,10 +454,17 @@ protected void reserveInternal(int newCapacity) { throw new RuntimeException("Unhandled " + type); } - byte[] newNulls = new byte[newCapacity]; - if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, elementsAppended); - nulls = newNulls; + reserveNulls(newCapacity); capacity = newCapacity; } + + @Override + protected void reserveNulls(int capacity) { + if (nulls == null || nulls.length < capacity) { + byte[] newNulls = new byte[capacity]; + if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, nulls.length); + nulls = newNulls; + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 5b9af26dfc4f..739eb8f8c1b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -167,11 +167,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { /** Hive outputs fields of structs slightly differently than top level attributes. 
*/ def toHiveStructString(a: (Any, DataType)): String = a match { - case (struct: Row, StructType(fields)) => + case (struct: Row, StructType(fields, _)) => struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => + case (seq: Seq[_], ArrayType(typ, _, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_, _], MapType(kType, vType, _)) => map.map { @@ -185,11 +185,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { } a match { - case (struct: Row, StructType(fields)) => + case (struct: Row, StructType(fields, _)) => struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => + case (seq: Seq[_], ArrayType(typ, _, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_, _], MapType(kType, vType, _)) => map.map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 24e2c1a5fd2f..41e4405ffb2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -358,7 +358,7 @@ private[sql] class JDBCRDD( case StringType => StringConversion case TimestampType => TimestampConversion case BinaryType => BinaryConversion - case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata)) + case ArrayType(et, _, _) => ArrayConversion(getConversions(et, metadata)) case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index d3e1efc56277..830d946df249 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -202,7 +202,7 @@ object JdbcUtils extends Logging { case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) - case ArrayType(et, _) => + case ArrayType(et, _, _) => // remove type length parameters from end of type name val typeName = getJdbcType(et, dialect).databaseTypeDefinition .toLowerCase.split("\\(")(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 579b036417d2..9f0ee3ba7db0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -177,14 +177,14 @@ private[sql] object InferSchema { * Convert NullType to StringType and remove StructTypes with no fields */ private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match { - case at @ ArrayType(elementType, _) => + case at @ ArrayType(elementType, _, _) => for { canonicalType <- canonicalizeType(elementType) } yield { 
at.copy(canonicalType) } - case StructType(fields) => + case StructType(fields, _) => val canonicalFields: Array[StructField] = for { field <- fields if field.name.length > 0 @@ -229,9 +229,9 @@ private[sql] object InferSchema { shouldHandleCorruptRecord: Boolean): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. - case (ArrayType(ty1, _), ty2) => + case (ArrayType(ty1, _, _), ty2) => compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) - case (ty1, ArrayType(ty2, _)) => + case (ty1, ArrayType(ty2, _, _)) => compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) // If we see any other data type at the root level, we get records that cannot be // parsed. So, we use the struct as the data type and add the corrupt field to the schema. @@ -270,7 +270,7 @@ private[sql] object InferSchema { DecimalType(range + scale, scale) } - case (StructType(fields1), StructType(fields2)) => + case (StructType(fields1, _), StructType(fields2, _)) => // Both fields1 and fields2 should be sorted by name, since inferField performs sorting. // Therefore, we can take advantage of the fact that we're merging sorted lists and skip // building a hash map or performing additional sorting. @@ -309,8 +309,8 @@ private[sql] object InferSchema { } StructType(newFields.toArray(emptyStructFieldArray)) - case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) + case (ArrayType(eltType1, containsNull1, _), ArrayType(eltType2, containsNull2, _)) => + ArrayType(compatibleType(eltType1, eltType2), containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in // `findTightestCommonTypeOfTwo`. 
Both cases below will be executed only when diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index 733fcbfea101..1ff7152147a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -64,7 +64,7 @@ object JacksonParser extends Logging { // in such an array as a row convertArray(factory, parser, st) - case (START_OBJECT, ArrayType(st, _)) => + case (START_OBJECT, ArrayType(st, _, _)) => // the business end of SPARK-3308: // when an object is found but an array is requested just wrap it in a list convertField(factory, parser, st) :: Nil @@ -181,7 +181,7 @@ object JacksonParser extends Logging { case (START_OBJECT, st: StructType) => convertObject(factory, parser, st) - case (START_ARRAY, ArrayType(st, _)) => + case (START_ARRAY, ArrayType(st, _, _)) => convertArray(factory, parser, st) case (START_OBJECT, MapType(StringType, kt, _)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 772e031ea77d..1c970ed5c729 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -258,7 +258,7 @@ private[sql] class ParquetFileFormat val conf = sparkSession.sessionState.conf conf.parquetVectorizedReaderEnabled && conf.wholeStageEnabled && schema.length <= conf.wholeStageMaxNumFields && - schema.forall(_.dataType.isInstanceOf[AtomicType]) + !schema.existsRecursively(_.isInstanceOf[MapType]) } override def isSplitable( @@ -334,7 +334,8 @@ private[sql] class ParquetFileFormat val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields) val enableVectorizedReader: Boolean = sparkSession.sessionState.conf.parquetVectorizedReaderEnabled && - resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) + !resultSchema.existsRecursively(_.isInstanceOf[MapType]) + // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 426263fa445a..87a9dfb46106 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -154,7 +154,7 @@ private[sql] object ParquetFilters { * Here we filter out such fields. */ private def getFieldMap(dataType: DataType): Map[String, DataType] = dataType match { - case StructType(fields) => + case StructType(fields, _) => // Here we don't flatten the fields in the nested schema but just look up through // root fields. Currently, accessing to nested fields does not push down filters // and it does not support to create filters for them. 
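Note on the metadata convention used above: the new ColumnVector constructor looks up the keys "defLevel" and "repLevel" in the metadata of ArrayType/StructType (the metadata-carrying variants this patch introduces), which the ParquetSchemaConverter changes below fill in with the Parquet max definition/repetition levels. A minimal sketch of reading those values back from a converted schema; levelsOf and schema are illustrative names, not part of the patch:

  import org.apache.spark.sql.types._

  // Returns the (defLevel, repLevel) pair recorded for a nested type, if present.
  def levelsOf(dt: DataType): Option[(Long, Long)] = {
    val metadata = dt match {
      case a: ArrayType => a.metadata    // metadata field added to ArrayType by this patch
      case s: StructType => s.metadata   // metadata field added to StructType by this patch
      case _ => Metadata.empty
    }
    if (metadata.contains("defLevel")) {
      Some((metadata.getLong("defLevel"), metadata.getLong("repLevel")))
    } else {
      None
    }
  }

  // Example: inspect the top-level fields of a schema produced by ParquetSchemaConverter.convert().
  // schema.fields.foreach(f => println(s"${f.name} -> ${levelsOf(f.dataType)}"))
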
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 9dad59647e0d..1f57a650a52e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -488,7 +488,7 @@ private[parquet] class ParquetRowConverter( case (t: GroupType, _) if t.getFieldCount > 1 => true case (t: GroupType, _) if t.getFieldCount == 1 && t.getName == "array" => true case (t: GroupType, _) if t.getFieldCount == 1 && t.getName == parentName + "_tuple" => true - case (t: GroupType, StructType(Array(f))) if f.name == t.getFieldName(0) => true + case (t: GroupType, StructType(Array(f), _)) if f.name == t.getFieldName(0) => true case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index bcf535d45521..dc8481c59f2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -71,35 +71,60 @@ private[parquet] class ParquetSchemaConverter( /** * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. */ - def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType()) + def convert(parquetSchema: MessageType): StructType = { + convert(parquetSchema.asGroupType(), parquetSchema, Seq.empty[String]) + } - private def convert(parquetSchema: GroupType): StructType = { + private def convert( + parquetSchema: GroupType, + messageType: MessageType, + path: Seq[String]): StructType = { val fields = parquetSchema.getFields.asScala.map { field => field.getRepetition match { case OPTIONAL => - StructField(field.getName, convertField(field), nullable = true) + StructField(field.getName, convertField(field, messageType, path), nullable = true) case REQUIRED => - StructField(field.getName, convertField(field), nullable = false) + StructField(field.getName, convertField(field, messageType, path), nullable = false) case REPEATED => + val builder = new MetadataBuilder() + val curPath = path ++ Seq(field.getName) + val defLevel = messageType.getMaxDefinitionLevel(curPath: _*) + val repLevel = messageType.getMaxRepetitionLevel(curPath: _*) + val metadata = builder.putLong("defLevel", defLevel).putLong("repLevel", repLevel).build() + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor // annotated by `LIST` or `MAP` should be interpreted as a required list of required // elements where the element type is the type of the field. 
- val arrayType = ArrayType(convertField(field), containsNull = false) - StructField(field.getName, arrayType, nullable = false) + val arrayType = ArrayType(convertField(field, messageType, path), containsNull = false, + metadata = metadata) + StructField(field.getName, arrayType, nullable = false, metadata = metadata) } } - StructType(fields) + if (path.isEmpty) { + StructType(fields) + } else { + val builder = new MetadataBuilder() + val defLevel = messageType.getMaxDefinitionLevel(path: _*) + val repLevel = messageType.getMaxRepetitionLevel(path: _*) + val metadata = builder.putLong("defLevel", defLevel).putLong("repLevel", repLevel).build() + StructType(fields.toArray, metadata) + } } /** * Converts a Parquet [[Type]] to a Spark SQL [[DataType]]. */ - def convertField(parquetType: Type): DataType = parquetType match { + def convertField( + parquetType: Type, + messageType: MessageType = null, + path: Seq[String] = Seq.empty[String]): DataType = parquetType match { case t: PrimitiveType => convertPrimitiveField(t) - case t: GroupType => convertGroupField(t.asGroupType()) + case t: GroupType => + val curPath = path ++ Seq(t.getName()) + convertGroupField(t.asGroupType(), messageType, curPath) } private def convertPrimitiveField(field: PrimitiveType): DataType = { @@ -190,8 +215,11 @@ private[parquet] class ParquetSchemaConverter( } } - private def convertGroupField(field: GroupType): DataType = { - Option(field.getOriginalType).fold(convert(field): DataType) { + private def convertGroupField( + field: GroupType, + messageType: MessageType, + path: Seq[String]): DataType = { + Option(field.getOriginalType).fold(convert(field, messageType, path): DataType) { // A Parquet list is represented as a 3-level structure: // // group (LIST) { @@ -212,13 +240,22 @@ private[parquet] class ParquetSchemaConverter( val repeatedType = field.getType(0) ParquetSchemaConverter.checkConversionRequirement( repeatedType.isRepetition(REPEATED), s"Invalid list type $field") - + val builder = new MetadataBuilder() if (isElementType(repeatedType, field.getName)) { - ArrayType(convertField(repeatedType), containsNull = false) + val defLevel = messageType.getMaxDefinitionLevel(path: _*) + val repLevel = messageType.getMaxRepetitionLevel(path: _*) + val metadata = builder.putLong("defLevel", defLevel).putLong("repLevel", repLevel).build() + ArrayType(convertField(repeatedType, messageType, path), containsNull = false, + metadata = metadata) } else { val elementType = repeatedType.asGroupType().getType(0) val optional = elementType.isRepetition(OPTIONAL) - ArrayType(convertField(elementType), containsNull = optional) + val curPath = path ++ Seq(repeatedType.getName) + val defLevel = messageType.getMaxDefinitionLevel(curPath: _*) + val repLevel = messageType.getMaxRepetitionLevel(path: _*) + val metadata = builder.putLong("defLevel", defLevel).putLong("repLevel", repLevel).build() + ArrayType(convertField(elementType, messageType, curPath), containsNull = optional, + metadata = metadata) } // scalastyle:off @@ -242,9 +279,11 @@ private[parquet] class ParquetSchemaConverter( val valueType = keyValueType.getType(1) val valueOptional = valueType.isRepetition(OPTIONAL) + val keyPath = path ++ Seq(keyValueType.getName) + val valuePath = path ++ Seq(keyValueType.getName) MapType( - convertField(keyType), - convertField(valueType), + convertField(keyType, messageType, keyPath), + convertField(valueType, messageType, valuePath), valueContainsNull = valueOptional) case _ => @@ -439,7 +478,7 @@ private[parquet] class 
ParquetSchemaConverter( // `LIST` structure. This behavior is somewhat a hybrid of parquet-hive and parquet-avro // (1.6.0rc3): the 3-level structure is similar to parquet-hive while the 3rd level element // field name "array" is borrowed from parquet-avro. - case ArrayType(elementType, nullable @ true) if writeLegacyParquetFormat => + case ArrayType(elementType, nullable @ true, _) if writeLegacyParquetFormat => // group (LIST) { // optional group bag { // repeated array; @@ -457,7 +496,7 @@ private[parquet] class ParquetSchemaConverter( // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level // LIST structure. This behavior mimics parquet-avro (1.6.0rc3). Note that this case is // covered by the backwards-compatibility rules implemented in `isElementType()`. - case ArrayType(elementType, nullable @ false) if writeLegacyParquetFormat => + case ArrayType(elementType, nullable @ false, _) if writeLegacyParquetFormat => // group (LIST) { // repeated element; // } @@ -486,7 +525,7 @@ private[parquet] class ParquetSchemaConverter( // ArrayType and MapType (standard mode) // ===================================== - case ArrayType(elementType, containsNull) if !writeLegacyParquetFormat => + case ArrayType(elementType, containsNull, _) if !writeLegacyParquetFormat => // group (LIST) { // repeated group list { // element; @@ -521,7 +560,7 @@ private[parquet] class ParquetSchemaConverter( // Other types // =========== - case StructType(fields) => + case StructType(fields, _) => fields.foldLeft(Types.buildGroup(repetition)) { (builder, field) => builder.addField(convertField(field)) }.named(field.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index cf68ed4ec36a..f3d026257d47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -49,7 +49,7 @@ object EvaluatePython { case DateType | TimestampType => true case _: StructType => true case _: UserDefinedType[_] => true - case ArrayType(elementType, _) => needConversionInPython(elementType) + case ArrayType(elementType, _, _) => needConversionInPython(elementType) case MapType(keyType, valueType, _) => needConversionInPython(keyType) || needConversionInPython(valueType) case _ => false @@ -129,10 +129,10 @@ object EvaluatePython { case (c: String, BinaryType) => c.getBytes(StandardCharsets.UTF_8) case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c - case (c: java.util.List[_], ArrayType(elementType, _)) => + case (c: java.util.List[_], ArrayType(elementType, _, _)) => new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray) - case (c, ArrayType(elementType, _)) if c.getClass.isArray => + case (c, ArrayType(elementType, _, _)) if c.getClass.isArray => new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => @@ -141,7 +141,7 @@ object EvaluatePython { val values = keyValues.map(kv => fromJava(kv._2, valueType)).toArray ArrayBasedMapData(keys, values) - case (c, StructType(fields)) if c.getClass.isArray => + case (c, StructType(fields, _)) if c.getClass.isArray => val array = c.asInstanceOf[Array[_]] if (array.length != fields.length) { throw new IllegalStateException( diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 6baf1b6f16cd..d990261f2473 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -68,7 +68,7 @@ private object PostgresDialect extends JdbcDialect { case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE)) case t: DecimalType => Some( JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC)) - case ArrayType(et, _) if et.isInstanceOf[AtomicType] => + case ArrayType(et, _, _) if et.isInstanceOf[AtomicType] => getJDBCType(et).map(_.databaseTypeDefinition) .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition)) .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 8a980a7eb538..b50f4be8a3f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -63,7 +63,7 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) val expected = sqlSchema assert( - actual === expected, + DataType.equalsIgnoreCompatibleNullability(actual, expected), s"""Schema mismatch. |Expected schema: ${expected.json} |Actual schema: ${actual.json} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 100cc4daca87..b4628721b5d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -617,7 +617,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(r1.getString(ordinal) == r2.getString(ordinal), "Seed = " + seed) case CalendarIntervalType => assert(r1.getInterval(ordinal) === r2.get(ordinal).asInstanceOf[CalendarInterval]) - case ArrayType(childType, n) => + case ArrayType(childType, n, _) => val a1 = r1.getArray(ordinal).array val a2 = r2.getList(ordinal).toArray assert(a1.length == a2.length, "Seed = " + seed) @@ -649,7 +649,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } case _ => assert(a1 === a2, "Seed = " + seed) } - case StructType(childFields) => + case StructType(childFields, _) => compareStruct(childFields, r1.getStruct(ordinal, fields.length), r2.getStruct(ordinal), seed) case _ => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index bf5cc17a68f5..bdef0a06fe04 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -782,7 +782,7 @@ private[hive] trait HiveInspectors { * We can easily map to the Hive built-in object inspector according to the data type. 
*/ def toInspector(dataType: DataType): ObjectInspector = dataType match { - case ArrayType(tpe, _) => + case ArrayType(tpe, _, _) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) case MapType(keyType, valueType, _) => ObjectInspectorFactory.getStandardMapObjectInspector( @@ -801,7 +801,7 @@ private[hive] trait HiveInspectors { case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector // TODO decimal precision? case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector - case StructType(fields) => + case StructType(fields, _) => ObjectInspectorFactory.getStandardStructObjectInspector( java.util.Arrays.asList(fields.map(f => f.name) : _*), java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) : _*)) @@ -841,7 +841,7 @@ private[hive] trait HiveInspectors { getDecimalWritableConstantObjectInspector(value) case Literal(_, NullType) => getPrimitiveNullWritableConstantObjectInspector - case Literal(value, ArrayType(dt, _)) => + case Literal(value, ArrayType(dt, _, _)) => val listObjectInspector = toInspector(dt) if (value == null) { ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) @@ -1045,9 +1045,9 @@ private[hive] trait HiveInspectors { } def toTypeInfo: TypeInfo = dt match { - case ArrayType(elemType, _) => + case ArrayType(elemType, _, _) => getListTypeInfo(elemType.toTypeInfo) - case StructType(fields) => + case StructType(fields, _) => getStructTypeInfo( java.util.Arrays.asList(fields.map(_.name) : _*), java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) : _*)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index bdec611453b2..3842f469ffb5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -402,11 +402,11 @@ private[spark] object HiveUtils extends Logging { ShortType, DateType, TimestampType, BinaryType) protected[sql] def toHiveString(a: (Any, DataType)): String = a match { - case (struct: Row, StructType(fields)) => + case (struct: Row, StructType(fields, _)) => struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => + case (seq: Seq[_], ArrayType(typ, _, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_, _], MapType(kType, vType, _)) => map.map { @@ -425,11 +425,11 @@ private[spark] object HiveUtils extends Logging { /** Hive outputs fields of structs slightly differently than top level attributes. 
*/ protected def toHiveStructString(a: (Any, DataType)): String = a match { - case (struct: Row, StructType(fields)) => + case (struct: Row, StructType(fields, _)) => struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => + case (seq: Seq[_], ArrayType(typ, _, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_, _], MapType(kType, vType, _)) => map.map { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index bc51bcb07ec2..10f4caf28ac5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -104,7 +104,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { val dataTypes = data.map(_.dataType) def toWritableInspector(dataType: DataType): ObjectInspector = dataType match { - case ArrayType(tpe, _) => + case ArrayType(tpe, _, _) => ObjectInspectorFactory.getStandardListObjectInspector(toWritableInspector(tpe)) case MapType(keyType, valueType, _) => ObjectInspectorFactory.getStandardMapObjectInspector( @@ -122,7 +122,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { case DateType => PrimitiveObjectInspectorFactory.writableDateObjectInspector case TimestampType => PrimitiveObjectInspectorFactory.writableTimestampObjectInspector case DecimalType() => PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector - case StructType(fields) => + case StructType(fields, _) => ObjectInspectorFactory.getStandardStructObjectInspector( java.util.Arrays.asList(fields.map(f => f.name) : _*), java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)) : _*))
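
For context on where the recorded levels come from: ParquetSchemaConverter obtains them directly from the Parquet MessageType via getMaxDefinitionLevel / getMaxRepetitionLevel, the same calls used in the diff above. A self-contained sketch; the 3-level list schema below is purely illustrative and not taken from the test suite:

  import org.apache.parquet.schema.MessageTypeParser

  val messageType = MessageTypeParser.parseMessageType(
    """
      |message spark_schema {
      |  optional group f (LIST) {
      |    repeated group list {
      |      optional int32 element;
      |    }
      |  }
      |}
    """.stripMargin)

  // Each optional or repeated ancestor adds one to the max definition level; only
  // repeated ancestors add to the max repetition level.
  val defLevel = messageType.getMaxDefinitionLevel("f", "list", "element")  // 3
  val repLevel = messageType.getMaxRepetitionLevel("f", "list", "element")  // 1

These are the values that end up in the "defLevel"/"repLevel" metadata and, ultimately, in ColumnVector's defLevel/repLevel fields.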