From 38be47eed7987ac46fea9a17eda2400c704a1df3 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Wed, 22 Jun 2016 22:05:11 +0800
Subject: [PATCH 01/17] Support ArrayType and StructType in vectorization
 parquet reader.

---
 .../org/apache/spark/sql/types/DataTypes.java | 2 +-
 .../sql/catalyst/encoders/RowEncoder.scala | 8 +-
 .../spark/sql/catalyst/expressions/Cast.scala | 4 +-
 .../codegen/GenerateSafeProjection.scala | 2 +-
 .../codegen/GenerateUnsafeProjection.scala | 4 +-
 .../expressions/collectionOperations.scala | 18 +-
 .../expressions/complexTypeExtractors.scala | 6 +-
 .../sql/catalyst/expressions/generators.scala | 4 +-
 .../spark/sql/catalyst/expressions/misc.scala | 8 +-
 .../expressions/objects/objects.scala | 2 +-
 .../sql/catalyst/optimizer/Optimizer.scala | 2 +-
 .../apache/spark/sql/types/ArrayType.scala | 9 +-
 .../org/apache/spark/sql/types/DataType.scala | 8 +-
 .../apache/spark/sql/types/StructType.scala | 20 +-
 .../spark/sql/RandomDataGenerator.scala | 8 +-
 .../encoders/ExpressionEncoderSuite.scala | 2 +-
 .../expressions/ComplexTypeSuite.scala | 4 +-
 .../parquet/VectorizedColumnReader.java | 271 ++++++++++++++++--
 .../VectorizedParquetRecordReader.java | 85 +++++-
 .../parquet/VectorizedRleValuesReader.java | 2 +-
 .../execution/vectorized/ColumnVector.java | 135 ++++++++-
 .../execution/vectorized/ColumnarBatch.java | 2 +
 .../vectorized/OnHeapColumnVector.java | 20 +-
 .../spark/sql/execution/QueryExecution.scala | 8 +-
 .../execution/datasources/jdbc/JDBCRDD.scala | 2 +-
 .../datasources/jdbc/JdbcUtils.scala | 2 +-
 .../datasources/json/InferSchema.scala | 14 +-
 .../datasources/json/JacksonGenerator.scala | 4 +-
 .../datasources/json/JacksonParser.scala | 4 +-
 .../parquet/ParquetFileFormat.scala | 6 +-
 .../datasources/parquet/ParquetFilters.scala | 2 +-
 .../parquet/ParquetRowConverter.scala | 2 +-
 .../parquet/ParquetSchemaConverter.scala | 79 +++--
 .../sql/execution/python/EvaluatePython.scala | 8 +-
 .../spark/sql/jdbc/PostgresDialect.scala | 2 +-
 .../parquet/ParquetSchemaSuite.scala | 2 +-
 .../vectorized/ColumnarBatchSuite.scala | 4 +-
 .../spark/sql/hive/HiveInspectors.scala | 10 +-
 .../org/apache/spark/sql/hive/HiveUtils.scala | 8 +-
 .../spark/sql/hive/HiveInspectorSuite.scala | 4 +-
 40 files changed, 622 insertions(+), 165 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
index 24adeadf9567..bab118609c66 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/DataTypes.java
@@ -214,6 +214,6 @@ public static StructType createStructType(StructField[] fields) {
       throw new IllegalArgumentException("fields should have distinct names.");
     }
 
-    return StructType$.MODULE$.apply(fields);
+    return new StructType(fields);
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 67fca153b551..a1bbc7821dc7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -119,7 +119,7 @@ object RowEncoder {
       "fromString",
       inputObject :: Nil)
 
-    case t @ ArrayType(et, _) => et match {
+    case t @ ArrayType(et, _, _) => et match {
       case BooleanType | ByteType | ShortType | IntegerType | LongType
            | FloatType | DoubleType =>
         // TODO: validate input
type for primitive array. NewInstance( @@ -152,7 +152,7 @@ object RowEncoder { convertedKeys :: convertedValues :: Nil, dataType = t) - case StructType(fields) => + case StructType(fields, _) => val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) => val fieldValue = serializerFor( ValidateExternalType( @@ -259,7 +259,7 @@ object RowEncoder { case StringType => Invoke(input, "toString", ObjectType(classOf[String])) - case ArrayType(et, nullable) => + case ArrayType(et, nullable, _) => val arrayData = Invoke( MapObjects(deserializerFor(_), input, et), @@ -284,7 +284,7 @@ object RowEncoder { "toScalaMap", keyData :: valueData :: Nil) - case schema @ StructType(fields) => + case schema @ StructType(fields, _) => val convertedFields = fields.zipWithIndex.map { case (f, i) => If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index b1e89b5de833..7ed47792fd26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -62,7 +62,7 @@ object Cast { case (TimestampType, _: NumericType) => true case (_: NumericType, _: NumericType) => true - case (ArrayType(fromType, fn), ArrayType(toType, tn)) => + case (ArrayType(fromType, fn, _), ArrayType(toType, tn, _)) => canCast(fromType, toType) && resolvableNullability(fn || forceNullable(fromType, toType), tn) @@ -72,7 +72,7 @@ object Cast { canCast(fromValue, toValue) && resolvableNullability(fn || forceNullable(fromValue, toValue), tn) - case (StructType(fromFields), StructType(toFields)) => + case (StructType(fromFields, _), StructType(toFields, _)) => fromFields.length == toFields.length && fromFields.zip(toFields).forall { case (fromField, toField) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index b891f9467375..a39fb2cb9d71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -129,7 +129,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] input: String, dataType: DataType): ExprCode = dataType match { case s: StructType => createCodeForStruct(ctx, input, s) - case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) + case ArrayType(elementType, _, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. 
case StringType => ExprCode("", "false", s"$input.clone()") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 5efba4b3a608..5414c532ba92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -117,7 +117,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor); """ - case a @ ArrayType(et, _) => + case a @ ArrayType(et, _, _) => s""" // Remember the current cursor so that we can calculate how many bytes are // written later. @@ -202,7 +202,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)} """ - case a @ ArrayType(et, _) => + case a @ ArrayType(et, _, _) => s""" $arrayWriter.setOffset($index); ${writeArrayToBuffer(ctx, element, et, bufferHolder)} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c71cb73d65bf..a585dc3cef1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -63,9 +63,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType) override def checkInputDataTypes(): TypeCheckResult = base.dataType match { - case ArrayType(dt, _) if RowOrdering.isOrderable(dt) => + case ArrayType(dt, _, _) if RowOrdering.isOrderable(dt) => TypeCheckResult.TypeCheckSuccess - case ArrayType(dt, _) => + case ArrayType(dt, _, _) => TypeCheckResult.TypeCheckFailure( s"$prettyName does not support sorting array of type ${dt.simpleString}") case _ => @@ -75,9 +75,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) @transient private lazy val lt: Comparator[Any] = { val ordering = base.dataType match { - case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(n: AtomicType, _, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: ArrayType, _, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } new Comparator[Any]() { @@ -98,9 +98,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) @transient private lazy val gt: Comparator[Any] = { val ordering = base.dataType match { - case _ @ ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] - case _ @ ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(n: AtomicType, _, _) => n.ordering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(a: 
ArrayType, _, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case _ @ ArrayType(s: StructType, _, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]] } new Comparator[Any]() { @@ -144,7 +144,7 @@ case class ArrayContains(left: Expression, right: Expression) override def inputTypes: Seq[AbstractDataType] = right.dataType match { case NullType => Seq() case _ => left.dataType match { - case n @ ArrayType(element, _) => Seq(n, element) + case n @ ArrayType(element, _, _) => Seq(n, element) case _ => Seq() } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3b4468f55ca7..1a134906e5b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -48,12 +48,12 @@ object ExtractValue { resolver: Resolver): Expression = { (child.dataType, extraction) match { - case (StructType(fields), NonNullLiteral(v, StringType)) => + case (StructType(fields, _), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) GetStructField(child, ordinal, Some(fieldName)) - case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => + case (ArrayType(StructType(fields, _), containsNull, _), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), @@ -65,7 +65,7 @@ object ExtractValue { case (otherType, _) => val errorMsg = otherType match { - case StructType(_) => + case StructType(_, _) => s"Field name should be String Literal, but it's $extraction" case other => s"Can't extract value from $child" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 12c35644e564..0950935a1418 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -115,14 +115,14 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit // hive-compatible default alias for explode function ("col" for array, "key", "value" for map) override def elementSchema: StructType = child.dataType match { - case ArrayType(et, containsNull) => new StructType().add("col", et, containsNull) + case ArrayType(et, containsNull, _) => new StructType().add("col", et, containsNull) case MapType(kt, vt, valueContainsNull) => new StructType().add("key", kt, false).add("value", vt, valueContainsNull) } override def eval(input: InternalRow): TraversableOnce[InternalRow] = { child.dataType match { - case ArrayType(et, _) => + case ArrayType(et, _, _) => val inputArray = child.eval(input).asInstanceOf[ArrayData] if (inputArray == null) { Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 1c0787bf9227..fa9aa3468847 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -315,7 
+315,7 @@ abstract class HashExpression[E] extends Expression { val numBytes = s"$input.numBytes()" s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" - case ArrayType(et, containsNull) => + case ArrayType(et, containsNull, _) => val index = ctx.freshName("index") s""" for (int $index = 0; $index < $input.numElements(); $index++) { @@ -336,7 +336,7 @@ abstract class HashExpression[E] extends Expression { } """ - case StructType(fields) => + case StructType(fields, _) => fields.zipWithIndex.map { case (field, index) => nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) }.mkString("\n") @@ -385,7 +385,7 @@ abstract class InterpretedHashFunction { case array: ArrayData => val elementType = dataType match { case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType - case ArrayType(et, _) => et + case ArrayType(et, _, _) => et } var result = seed var i = 0 @@ -417,7 +417,7 @@ abstract class InterpretedHashFunction { val types: Array[DataType] = dataType match { case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray - case StructType(fields) => fields.map(_.dataType) + case StructType(fields, _) => fields.map(_.dataType) } var result = seed var i = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index c597a2a70944..e24a58a73e8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -435,7 +435,7 @@ case class MapObjects private( s"${genInputData.value}.length" -> s"${genInputData.value}[$loopIndex]" case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => s"${genInputData.value}.size()" -> s"${genInputData.value}.get($loopIndex)" - case ArrayType(et, _) => + case ArrayType(et, _, _) => s"${genInputData.value}.numElements()" -> ctx.getValue(genInputData.value, et, loopIndex) case ObjectType(cls) if cls == classOf[Object] => s"$seq == null ? $array.length : $seq.size()" -> diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6e78ad0e7765..a713565e9164 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1683,7 +1683,7 @@ object EmbedSerializerInFilter extends Rule[LogicalPlan] { def samePrimitiveType(lhs: StructType, rhs: StructType): Boolean = { (lhs, rhs) match { - case (StructType(Array(f1)), StructType(Array(f2))) => f1.dataType == f2.dataType + case (StructType(Array(f1), _), StructType(Array(f2), _)) => f1.dataType == f2.dataType case _ => false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 520e34436162..e51f07d28c09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -51,9 +51,16 @@ object ArrayType extends AbstractDataType { * * @param elementType The data type of values. 
* @param containsNull Indicates if values have `null` values + * @param metadata The metadata of this array type. */ @DeveloperApi -case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { +case class ArrayType( + elementType: DataType, + containsNull: Boolean, + metadata: Metadata = Metadata.empty) extends DataType { + + protected def this(elementType: DataType, containsNull: Boolean) = + this(elementType, containsNull, Metadata.empty) /** No-arg constructor for kryo. */ protected def this() = this(null, false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 4fc65cbce15b..3e5ad36229e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -196,12 +196,12 @@ object DataType { */ private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { (left, right) match { - case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => + case (ArrayType(leftElementType, _, _), ArrayType(rightElementType, _, _)) => equalsIgnoreNullability(leftElementType, rightElementType) case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) => equalsIgnoreNullability(leftKeyType, rightKeyType) && equalsIgnoreNullability(leftValueType, rightValueType) - case (StructType(leftFields), StructType(rightFields)) => + case (StructType(leftFields, _), StructType(rightFields, _)) => leftFields.length == rightFields.length && leftFields.zip(rightFields).forall { case (l, r) => l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType) @@ -226,7 +226,7 @@ object DataType { */ private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = { (from, to) match { - case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) => + case (ArrayType(fromElement, fn, _), ArrayType(toElement, tn, _)) => (tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement) case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => @@ -234,7 +234,7 @@ object DataType { equalsIgnoreCompatibleNullability(fromKey, toKey) && equalsIgnoreCompatibleNullability(fromValue, toValue) - case (StructType(fromFields), StructType(toFields)) => + case (StructType(fromFields, _), StructType(toFields, _)) => fromFields.length == toFields.length && fromFields.zip(toFields).forall { case (fromField, toField) => fromField.name == toField.name && diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 436512ff6933..606f1b984c75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -92,7 +92,11 @@ import org.apache.spark.util.Utils * }}} */ @DeveloperApi -case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { +case class StructType( + fields: Array[StructField], + metadata: Metadata = Metadata.empty) extends DataType with Seq[StructField] { + + def this(fields: Array[StructField]) = this(fields, Metadata.empty) /** No-arg constructor for kryo. 
*/ def this() = this(Array.empty[StructField]) @@ -106,7 +110,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override def equals(that: Any): Boolean = { that match { - case StructType(otherFields) => + case StructType(otherFields, _) => java.util.Arrays.equals( fields.asInstanceOf[Array[AnyRef]], otherFields.asInstanceOf[Array[AnyRef]]) case _ => false @@ -371,11 +375,11 @@ object StructType extends AbstractDataType { } } - def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) + def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray, Metadata.empty) def apply(fields: java.util.List[StructField]): StructType = { import scala.collection.JavaConverters._ - StructType(fields.asScala) + apply(fields.asScala) } protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = @@ -383,7 +387,7 @@ object StructType extends AbstractDataType { def removeMetadata(key: String, dt: DataType): DataType = dt match { - case StructType(fields) => + case StructType(fields, _) => val newFields = fields.map { f => val mb = new MetadataBuilder() f.copy(dataType = removeMetadata(key, f.dataType), @@ -395,8 +399,8 @@ object StructType extends AbstractDataType { private[sql] def merge(left: DataType, right: DataType): DataType = (left, right) match { - case (ArrayType(leftElementType, leftContainsNull), - ArrayType(rightElementType, rightContainsNull)) => + case (ArrayType(leftElementType, leftContainsNull, _), + ArrayType(rightElementType, rightContainsNull, _)) => ArrayType( merge(leftElementType, rightElementType), leftContainsNull || rightContainsNull) @@ -408,7 +412,7 @@ object StructType extends AbstractDataType { merge(leftValueType, rightValueType), leftContainsNull || rightContainsNull) - case (StructType(leftFields), StructType(rightFields)) => + case (StructType(leftFields, _), StructType(rightFields, _)) => val newFields = ArrayBuffer.empty[StructField] // This metadata will record the fields that only exist in one of two StructTypes val optionalMeta = new MetadataBuilder() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 850869799507..a02bcbad3c2c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -196,7 +196,7 @@ object RandomDataGenerator { case ShortType => randomNumeric[Short]( rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort)) case NullType => Some(() => null) - case ArrayType(elementType, containsNull) => + case ArrayType(elementType, containsNull, _) => forType(elementType, nullable = containsNull, rand).map { elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) } @@ -220,7 +220,7 @@ object RandomDataGenerator { keys.zip(values).toMap } } - case StructType(fields) => + case StructType(fields, _) => val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field => forType(field.dataType, nullable = field.nullable, rand) } @@ -269,7 +269,7 @@ object RandomDataGenerator { val fields = mutable.ArrayBuffer.empty[Any] schema.fields.foreach { f => f.dataType match { - case ArrayType(childType, nullable) => + case ArrayType(childType, nullable, _) => val data = if (f.nullable && rand.nextFloat() <= PROBABILITY_OF_NULL) { null } else { @@ -286,7 +286,7 @@ object RandomDataGenerator { arr } fields += data - case 
StructType(children) => + case StructType(children, _) => fields += randomRow(rand, StructType(children)) case _ => val generator = RandomDataGenerator.forType(f.dataType, f.nullable, rand) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index a1f9259f139e..8bcb41d7bf25 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -376,7 +376,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { val encodedData = try { row.toSeq(encoder.schema).zip(schema).map { - case (a: ArrayData, AttributeReference(_, ArrayType(et, _), _, _)) => + case (a: ArrayData, AttributeReference(_, ArrayType(et, _, _), _, _)) => a.toArray[Any](et).toSeq case (other, _) => other diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index ec7be4d4b849..c991185dff1b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -77,7 +77,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { def getStructField(expr: Expression, fieldName: String): GetStructField = { expr.dataType match { - case StructType(fields) => + case StructType(fields, _) => val index = fields.indexWhere(_.name == fieldName) GetStructField(expr, index) } @@ -107,7 +107,7 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { def getArrayStructFields(expr: Expression, fieldName: String): GetArrayStructFields = { expr.dataType match { - case ArrayType(StructType(fields), containsNull) => + case ArrayType(StructType(fields, _), containsNull, _) => val field = fields.find(_.name == fieldName).get GetArrayStructFields(expr, field, fields.indexOf(field), fields.length, containsNull) } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 662a03d3b56a..2625e3c99ca7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -18,6 +18,8 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import org.apache.commons.lang.NotImplementedException; import org.apache.parquet.bytes.BytesUtils; @@ -30,6 +32,8 @@ import org.apache.parquet.schema.PrimitiveType; import org.apache.spark.sql.execution.vectorized.ColumnVector; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.DecimalType; @@ -67,6 +71,11 @@ public class VectorizedColumnReader { */ private final int maxDefLevel; + /** + * Maximum repetition level for this column. + */ + private final int maxRepLevel; + /** * Repetition/Definition/Value readers. 
*/ @@ -77,7 +86,8 @@ public class VectorizedColumnReader { // Only set if vectorized decoding is true. This is used instead of the row by row decoding // with `definitionLevelColumn`. private VectorizedRleValuesReader defColumn; - + private VectorizedRleValuesReader defColumnCopy; + /** * Total number of values in this column (in this row group). */ @@ -96,6 +106,7 @@ public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader this.descriptor = descriptor; this.pageReader = pageReader; this.maxDefLevel = descriptor.getMaxDefinitionLevel(); + this.maxRepLevel = descriptor.getMaxRepetitionLevel(); DictionaryPage dictionaryPage = pageReader.readDictionaryPage(); if (dictionaryPage != null) { @@ -114,32 +125,45 @@ public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader throw new IOException("totalValueCount == 0"); } } - - /** - * Advances to the next value. Returns true if the value is non-null. - */ - private boolean next() throws IOException { - if (valuesRead >= endOfPageValueCount) { - if (valuesRead >= totalValueCount) { - // How do we get here? Throw end of stream exception? - return false; - } - readPage(); - } - ++valuesRead; - // TODO: Don't read for flat schemas - //repetitionLevel = repetitionLevelColumn.nextInt(); - return definitionLevelColumn.nextInt() == maxDefLevel; - } - + /** * Reads `total` values from this columnReader into column. */ - void readBatch(int total, ColumnVector column) throws IOException { + public void readBatch(int total, ColumnVector column) throws IOException { + boolean isNestedColumn = column.getParentColumn() != null; + boolean isRepeatedColumn = maxRepLevel > 0; int rowId = 0; - while (total > 0) { + int valuesReadInPage = 0; + int repeatedRowId = 0; + + Map rowIds = new HashMap(); + Map offsets = new HashMap(); + + while (true) { // Compute the number of values we want to read in this page. int leftInPage = (int) (endOfPageValueCount - valuesRead); + // When we reach the end of this page, we update repetition info of this column + // and then read next page. + if (leftInPage == 0) { + // Update repetition info for this column. + if (valuesReadInPage > 0 && isNestedColumn) { + updateReptitionInfo(column, rowIds, offsets, valuesReadInPage, total); + if (rowIds.containsKey(1)) { + repeatedRowId = rowIds.get(1); + } + valuesReadInPage = 0; + } + } + // Stop condition: + // If we are going to read data in repeated column, the stop condition is that we + // read `total` repeated columns. Eg., if we want to read 5 records of an array of int column. + // we can't just read 5 integers. Instead, we have to read the integers until 5 arrays are put + // into this array column. + if (isRepeatedColumn) { + if (repeatedRowId >= total) break; + } else { + if (total <= 0) break; + } if (leftInPage == 0) { readPage(); leftInPage = (int) (endOfPageValueCount - valuesRead); @@ -147,7 +171,8 @@ void readBatch(int total, ColumnVector column) throws IOException { int num = Math.min(total, leftInPage); if (useDictionary) { // Read and decode dictionary ids. 
- ColumnVector dictionaryIds = column.reserveDictionaryIds(total); + int dictionaryCapacity = Math.max(total, rowId + num); + ColumnVector dictionaryIds = column.reserveDictionaryIds(dictionaryCapacity); defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); @@ -167,7 +192,7 @@ void readBatch(int total, ColumnVector column) throws IOException { } else { if (column.hasDictionary() && rowId != 0) { // This batch already has dictionary encoded values but this new page is not. The batch - // does not support a mix of dictionary and not so we will decode the dictionary. + // does not support a mix of dictionary and not, so we will decode the dictionary. decodeDictionaryIds(0, rowId, column, column.getDictionaryIds()); } column.setDictionary(null); @@ -201,9 +226,12 @@ void readBatch(int total, ColumnVector column) throws IOException { } } + valuesReadInPage += num; valuesRead += num; rowId += num; - total -= num; + if (!isRepeatedColumn) { + total -= num; + } } } @@ -478,10 +506,195 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) thr } } + /** + * Inserts arrays into parent repeated columns. + */ + private void insertRepeatedArray( + ColumnVector column, + Map rowIds, + Map offsets, + Map reptitionMap, + int total, + int repLevel) throws IOException { + ColumnVector parentRepeatedColumn = column; + int curRepLevel = maxRepLevel; + while (true) { + parentRepeatedColumn = parentRepeatedColumn.getNearestParentArrayColumn(); + if (parentRepeatedColumn != null) { + int parentColRepLevel = parentRepeatedColumn.getRepLevel(); + // Only process the parent columns whose repetition levels are equal to or more than + // the given repetition level (less than or equal to max repetition level). + // E.g., when the current repetition level is 1 and max repetition level us 2, + // we only add arrays into the column whose repetition level is 1. + if (parentColRepLevel >= repLevel) { + // Current row id at this column. + int rowId = 0; + if (rowIds.containsKey(curRepLevel)) { + rowId = rowIds.get(curRepLevel); + } + // Repetition count. + int repCount = 0; + if (reptitionMap.containsKey(curRepLevel)) { + repCount = reptitionMap.get(curRepLevel); + } + // Offset of values. + int offset = 0; + if (offsets.containsKey(curRepLevel)) { + offset = offsets.get(curRepLevel); + } + + parentRepeatedColumn.putArray(rowId, offset, repCount); + + offset += repCount; + repCount = 0; + rowId++; + + offsets.put(curRepLevel, offset); + reptitionMap.put(curRepLevel, repCount); + rowIds.put(curRepLevel, rowId); + + // Increase the repetition count for parent repetition level as we add a new record. + if (curRepLevel > 1) { + int nextRepCount = 0; + if (reptitionMap.containsKey(curRepLevel - 1)) { + nextRepCount = reptitionMap.get(curRepLevel - 1); + } + reptitionMap.put(curRepLevel - 1, nextRepCount + 1); + } + + if (curRepLevel == 1 && rowId == total) { + return; + } + curRepLevel--; + } else { + break; + } + } else { + break; + } + } + } + + /** + * Reads repetition level for each value and updates length and offset info for above columns, + * recursively. + */ + private void updateReptitionInfo( + ColumnVector column, + Map rowIds, + Map offsets, + int valuesReadInPage, + int total) throws IOException { + // Keeps repetition levels and corresponding repetition counts. 
+ Map reptitionMap = new HashMap(); + + if (column.getParentColumn() != null) { + int prevRepLevel = -1; + + for (int i = 0; i < valuesReadInPage; i++) { + int repLevel = repetitionLevelColumn.nextInt(); + int defLevel = definitionLevelColumn.nextInt(); + + if (prevRepLevel >= 0) { + // When a new record begins at lower repetition level, + // we insert array into repeated column. + if (repLevel < maxRepLevel) { + insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, repLevel); + } + } + prevRepLevel = repLevel; + + // When definition level is less than max definition level, + // there is a null value. + if (defLevel < maxDefLevel) { + int offset = 0; + if (offsets.containsKey(maxRepLevel)) { + offset = offsets.get(maxRepLevel); + } + + if (column.getParentColumn().getDefLevel() == maxDefLevel) { + insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, repLevel); + offsets.put(maxRepLevel, offset + 1); + } else if (defLevel == 0) { + // A null record at root level. + // Obtain most-top column (repetition level 1). + ColumnVector topColumn = column.getParentColumn(); + while (topColumn.getParentColumn() != null) { + topColumn = topColumn.getParentColumn(); + } + // Get its current row id. + int rowId = 0; + if (rowIds.containsKey(1)) { + rowId = rowIds.get(1); + } + // Insert null record and increase row id. + topColumn.putNull(rowId); + rowIds.put(1, rowId + 1); + + // Increse row id at later repetition levels. + for (int j = 2; j <= maxRepLevel; j++) { + rowId = 0; + if (rowIds.containsKey(j)) { + rowId = rowIds.get(j); + } + rowIds.put(j, rowId + 1); + } + + // Move to next offset in max repetition level as we processed the current value. + offsets.put(maxRepLevel, offset + 1); + + prevRepLevel = -1; + } else if (repLevel == maxRepLevel) { + // A null value at max repetition level. + // This null value is repeated in a wrapping group. Simply increase repetition count. + int repCount = 0; + if (reptitionMap.containsKey(repLevel)) { + repCount = reptitionMap.get(repLevel); + } + reptitionMap.put(repLevel, repCount + 1); + } else { + // Null value at definition level > 0. + int repCount = 0; + if (reptitionMap.containsKey(maxRepLevel)) { + repCount = reptitionMap.get(maxRepLevel); + } + reptitionMap.put(maxRepLevel, repCount + 1); + } + } else { + // Determine the repetition level of non-null values. + + // A new record begins with non-null value. + if (maxRepLevel == 0) { + // A required record at root level. + int repCount = 0; + if (reptitionMap.containsKey(1)) { + repCount = reptitionMap.get(1); + } + reptitionMap.put(1, repCount + 1); + // insertArrayForRepetition(column, rowIds, offsets, reptitionMap, total, + // 0, 1); + insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, maxRepLevel - 1); + } else { + // Repeated values. We increase repetition count. + if (reptitionMap.containsKey(maxRepLevel)) { + reptitionMap.put(maxRepLevel, reptitionMap.get(maxRepLevel) + 1); + } else { + reptitionMap.put(maxRepLevel, 1); + } + } + } + } + if (prevRepLevel >= 0) { + insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, 0); + } + } + } + private void readPageV1(DataPageV1 page) throws IOException { this.pageValueCount = page.getValueCount(); ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); ValuesReader dlReader; + ValuesReader dlReaderCopy; // Initialize the decoders. 
if (page.getDlEncoding() != Encoding.RLE && descriptor.getMaxDefinitionLevel() != 0) { @@ -489,14 +702,17 @@ private void readPageV1(DataPageV1 page) throws IOException { } int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); this.defColumn = new VectorizedRleValuesReader(bitWidth); + this.defColumnCopy = new VectorizedRleValuesReader(bitWidth); dlReader = this.defColumn; + dlReaderCopy = this.defColumnCopy; this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); - this.definitionLevelColumn = new ValuesReaderIntIterator(dlReader); + this.definitionLevelColumn = new ValuesReaderIntIterator(dlReaderCopy); try { byte[] bytes = page.getBytes().toByteArray(); rlReader.initFromPage(pageValueCount, bytes, 0); int next = rlReader.getNextOffset(); dlReader.initFromPage(pageValueCount, bytes, next); + dlReaderCopy.initFromPage(pageValueCount, bytes, next); next = dlReader.getNextOffset(); initDataReader(page.getValueEncoding(), bytes, next); } catch (IOException e) { @@ -511,9 +727,12 @@ private void readPageV2(DataPageV2 page) throws IOException { int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); this.defColumn = new VectorizedRleValuesReader(bitWidth); - this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumn); + this.defColumnCopy = new VectorizedRleValuesReader(bitWidth); + this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumnCopy); this.defColumn.initFromBuffer( this.pageValueCount, page.getDefinitionLevels().toByteArray()); + this.defColumnCopy.initFromBuffer( + this.pageValueCount, page.getDefinitionLevels().toByteArray()); try { initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0); } catch (IOException e) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 51bdf0f0f229..e7b2551989ef 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; import org.apache.spark.sql.execution.vectorized.ColumnarBatch; +import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -178,6 +179,7 @@ public void initBatch(MemoryMode memMode, StructType partitionColumns, } } + // Allocate ColumnVectors in ColumnarBatch columnarBatch = ColumnarBatch.allocate(batchSchema, memMode); if (partitionColumns != null) { int partitionIdx = sparkSchema.fields().length; @@ -188,12 +190,30 @@ public void initBatch(MemoryMode memMode, StructType partitionColumns, } // Initialize missing columns with nulls. 
- for (int i = 0; i < missingColumns.length; i++) { - if (missingColumns[i]) { - columnarBatch.column(i).putNulls(0, columnarBatch.capacity()); - columnarBatch.column(i).setIsConstant(); + int missingColumnIdx = 0; + int partitionIdxBase = missingColumns.length; + if (partitionColumns != null) { + partitionIdxBase = sparkSchema.fields().length; + } + for (int i = 0; i < columnarBatch.numFields(); i++) { + if (i < partitionIdxBase) { + missingColumnIdx = initColumnWithNulls(columnarBatch.column(i), missingColumnIdx); + } + } + } + + private int initColumnWithNulls(ColumnVector column, int missingColumnIdx) { + if (column.isComplex()) { + for (int j = 0; j < column.getChildColumnNums(); j++) { + missingColumnIdx = initColumnWithNulls(column.getChildColumn(j), missingColumnIdx); + } + } else { + if (missingColumns[missingColumnIdx++]) { + column.putNulls(0, columnarBatch.capacity()); + column.setIsConstant(); } } + return missingColumnIdx; } public void initBatch() { @@ -225,10 +245,10 @@ public boolean nextBatch() throws IOException { checkEndOfRowGroup(); int num = (int) Math.min((long) columnarBatch.capacity(), totalCountLoadedSoFar - rowsReturned); - for (int i = 0; i < columnReaders.length; ++i) { - if (columnReaders[i] == null) continue; - columnReaders[i].readBatch(num, columnarBatch.column(i)); + for (int i = 0; i < columnarBatch.numFields(); i++) { + readBatchOnColumnVector(columnarBatch.column(i), num); } + rowsReturned += num; columnarBatch.setNumRows(num); numBatched = num; @@ -236,17 +256,23 @@ public boolean nextBatch() throws IOException { return true; } + private void readBatchOnColumnVector(ColumnVector column, int num) throws IOException { + if (column.hasColumnReader()) { + column.readBatch(num); + } else { + for (int j = 0; j < column.getChildColumnNums(); j++) { + readBatchOnColumnVector(column.getChildColumn(j), num); + } + } + } + private void initializeInternal() throws IOException, UnsupportedOperationException { /** * Check that the requested schema is supported. */ - missingColumns = new boolean[requestedSchema.getFieldCount()]; - for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { - Type t = requestedSchema.getFields().get(i); - if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { - throw new UnsupportedOperationException("Complex types not supported."); - } - + missingColumns = new boolean[requestedSchema.getColumns().size()]; + // For loop on each physical columns. 
+ for (int i = 0; i < requestedSchema.getColumns().size(); ++i) { String[] colPath = requestedSchema.getPaths().get(i); if (fileSchema.containsPath(colPath)) { ColumnDescriptor fd = fileSchema.getColumnDescription(colPath); @@ -265,6 +291,25 @@ private void initializeInternal() throws IOException, UnsupportedOperationExcept } } + private int setupColumnReader( + ColumnVector column, + VectorizedColumnReader[] columnReaders, + int readerIdx) { + if (column.isComplex()) { + column.setColumnReader(null); + for (int j = 0; j < column.getChildColumnNums(); j++) { + readerIdx = setupColumnReader(column.getChildColumn(j), columnReaders, readerIdx); + } + } else { + if (!missingColumns[readerIdx]) { + column.setColumnReader(columnReaders[readerIdx++]); + } else { + readerIdx++; + } + } + return readerIdx; + } + private void checkEndOfRowGroup() throws IOException { if (rowsReturned != totalCountLoadedSoFar) return; PageReadStore pages = reader.readNextRowGroup(); @@ -272,13 +317,25 @@ private void checkEndOfRowGroup() throws IOException { throw new IOException("expecting more rows but reached last block. Read " + rowsReturned + " out of " + totalRowCount); } + // Return physical columns stored in Parquet file. Not logical fields. + // For example, a nested StructType field in requestedSchema might have many columns. + // A column is always a primitive type column. List columns = requestedSchema.getColumns(); columnReaders = new VectorizedColumnReader[columns.size()]; + for (int i = 0; i < columns.size(); ++i) { if (missingColumns[i]) continue; columnReaders[i] = new VectorizedColumnReader(columns.get(i), pages.getPageReader(columns.get(i))); } + + // Associate ColumnReaders to ColumnVectors in ColumnarBatch. + int readerIdx = 0; + int partitionIdx = sparkSchema.fields().length; + for (int i = 0; i < columnarBatch.numFields(); i++) { + if (i >= partitionIdx) break; + readerIdx = setupColumnReader(columnarBatch.column(i), columnReaders, readerIdx); + } totalCountLoadedSoFar += pages.getRowCount(); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 62157389013b..6413e4e595c8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -498,7 +498,7 @@ public void readLongs(int total, ColumnVector c, int rowId) { public void readBinary(int total, ColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } - + @Override public void readBooleans(int total, ColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 3f9425525669..c6daf3bcb5a5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.io.IOException; import java.math.BigDecimal; import java.math.BigInteger; @@ -27,6 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; 
import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.execution.datasources.parquet.VectorizedColumnReader; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -480,7 +482,8 @@ public void reset() { public abstract double getDouble(int rowId); /** - * Puts a byte array that already exists in this column. + * Puts an array that already exists in this column. + * This method only updates array length and offset data in this column. */ public abstract void putArray(int rowId, int offset, int length); @@ -826,6 +829,28 @@ public final int appendStruct(boolean isNull) { */ public final ColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + /** + * Returns the number of childColumns. + */ + public final int getChildColumnNums() { + if (childColumns == null) { + return 0; + } else { + return childColumns.length; + } + } + + /** + * Returns whether this ColumnVector represents complex types such as Array, Map, Struct. + */ + public final boolean isComplex() { + if (type instanceof ArrayType || type instanceof StructType || type instanceof MapType) { + return true; + } else { + return false; + } + } + /** * Returns the elements appended. */ @@ -841,6 +866,16 @@ public final int appendStruct(boolean isNull) { */ public final void setIsConstant() { isConstant = true; } + /** + * Returns definition level for this column. This value is valid only if isComplex() return true. + */ + public final int getDefLevel() { return defLevel; } + + /** + * Returns repetition level for this column. This value is valid only if isComplex() return true. + */ + public final int getRepLevel() { return repLevel; } + /** * Maximum number of rows that can be stored in this column. */ @@ -868,6 +903,16 @@ public final int appendStruct(boolean isNull) { */ protected boolean isConstant; + /** + * Max definition level of this column. This value is valid only if this is a nested column. + */ + protected int defLevel; + + /** + * Max repetition level of this column. This value is valid only if this is a nested column. + */ + protected int repLevel; + /** * Default size of each array length value. This grows as necessary. */ @@ -905,6 +950,70 @@ public final int appendStruct(boolean isNull) { */ protected ColumnVector dictionaryIds; + /** + * Associated VectorizedColumnReader which is used to load data into this ColumnVector. + * If this is a complex type such as array or struct, the VectorizedColumnReader will be + * null. + */ + protected VectorizedColumnReader columnReader; + + /** + * The parent ColumnVector of this column. If this column is not an element of nested column, + * then this is null. + */ + protected ColumnVector parentColumn; + + /** + * Sets the columnReader for this column. + */ + public void setColumnReader(VectorizedColumnReader columnReader) { + this.columnReader = columnReader; + } + + /** + * Sets the parent column for this column. + */ + public void setParentColumn(ColumnVector column) { + this.parentColumn = column; + } + + /** + * Returns the parent column for this column. + */ + public ColumnVector getParentColumn() { + return this.parentColumn; + } + + /** + * Returns the nearest parent column which is an Array column. 
+ */ + public ColumnVector getNearestParentArrayColumn() { + ColumnVector parentCol = this.parentColumn; + while (parentCol != null && !parentCol.isArray()) { + parentCol = parentCol.parentColumn; + } + return parentCol; + } + + /** + * Returns if this ColumnVector has initialized VectorizedColumnReader. + */ + public boolean hasColumnReader() { + return this.columnReader != null; + } + + /** + * Reads `total` values from associated columnReader into this column. + */ + public void readBatch(int total) throws IOException { + if (this.columnReader != null) { + this.columnReader.readBatch(total, this); + } else { + throw new RuntimeException("The reader of this ColumnVector is not initialized yet. " + + "Failed to call readBatch()."); + } + } + /** * Update the dictionary. */ @@ -951,27 +1060,47 @@ protected ColumnVector(int capacity, DataType type, MemoryMode memMode) { DataType childType; int childCapacity = capacity; if (type instanceof ArrayType) { + ArrayType arrayType = (ArrayType)type; + if (arrayType.metadata().contains("defLevel")) { + this.defLevel = (int)arrayType.metadata().getLong("defLevel"); + this.repLevel = (int)arrayType.metadata().getLong("repLevel"); + } childType = ((ArrayType)type).elementType(); + this.defLevel = defLevel; } else { childType = DataTypes.ByteType; childCapacity *= DEFAULT_ARRAY_LENGTH; } this.childColumns = new ColumnVector[1]; this.childColumns[0] = ColumnVector.allocate(childCapacity, childType, memMode); + this.childColumns[0].setParentColumn(this); this.resultArray = new Array(this.childColumns[0]); this.resultStruct = null; } else if (type instanceof StructType) { + StructType structType = (StructType)type; + if (structType.metadata().contains("defLevel")) { + this.defLevel = (int)structType.metadata().getLong("defLevel"); + this.repLevel = (int)structType.metadata().getLong("repLevel"); + } StructType st = (StructType)type; this.childColumns = new ColumnVector[st.fields().length]; for (int i = 0; i < childColumns.length; ++i) { - this.childColumns[i] = ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode); + int fieldDefLevel = 0; + if (st.fields()[i].metadata().contains("defLevel")) { + fieldDefLevel = (int)st.fields()[i].metadata().getLong("defLevel"); + } + + this.childColumns[i] = + ColumnVector.allocate(capacity, st.fields()[i].dataType(), memMode); + this.childColumns[i].setParentColumn(this); } this.resultArray = null; this.resultStruct = new ColumnarBatch.Row(this.childColumns); } else if (type instanceof CalendarIntervalType) { // Two columns. Months as int. Microseconds as Long. 
this.childColumns = new ColumnVector[2]; - this.childColumns[0] = ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode); + this.childColumns[0] = + ColumnVector.allocate(capacity, DataTypes.IntegerType, memMode); this.childColumns[1] = ColumnVector.allocate(capacity, DataTypes.LongType, memMode); this.resultArray = null; this.resultStruct = new ColumnarBatch.Row(this.childColumns); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 8cece73faa4b..360c7fb05fe7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -79,6 +79,8 @@ public static ColumnarBatch allocate(StructType schema, MemoryMode memMode, int return new ColumnarBatch(schema, maxRows, memMode); } + public int numFields() { return columns.length; } + /** * Called to close all the columns in this batch. It is not valid to access the data after * calling this. This must be called at the end to clean up memory allocations. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 7fb7617050f2..89998b0b4cd2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -403,53 +403,53 @@ private void reserveInternal(int newCapacity) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { - System.arraycopy(this.arrayLengths, 0, newLengths, 0, elementsAppended); - System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, elementsAppended); + System.arraycopy(this.arrayLengths, 0, newLengths, 0, this.arrayLengths.length); + System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, this.arrayOffsets.length); } arrayLengths = newLengths; arrayOffsets = newOffsets; } else if (type instanceof BooleanType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, byteData.length); byteData = newData; } } else if (type instanceof ByteType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, byteData.length); byteData = newData; } } else if (type instanceof ShortType) { if (shortData == null || shortData.length < newCapacity) { short[] newData = new short[newCapacity]; - if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); + if (shortData != null) System.arraycopy(shortData, 0, newData, 0, shortData.length); shortData = newData; } } else if (type instanceof IntegerType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { if (intData == null || intData.length < newCapacity) { int[] newData = new int[newCapacity]; - if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); + if (intData != null) System.arraycopy(intData, 0, newData, 0, intData.length); intData = newData; } } else if 
(type instanceof LongType || type instanceof TimestampType || DecimalType.is64BitDecimalType(type)) { if (longData == null || longData.length < newCapacity) { long[] newData = new long[newCapacity]; - if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); + if (longData != null) System.arraycopy(longData, 0, newData, 0, longData.length); longData = newData; } } else if (type instanceof FloatType) { if (floatData == null || floatData.length < newCapacity) { float[] newData = new float[newCapacity]; - if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended); + if (floatData != null) System.arraycopy(floatData, 0, newData, 0, floatData.length); floatData = newData; } } else if (type instanceof DoubleType) { if (doubleData == null || doubleData.length < newCapacity) { double[] newData = new double[newCapacity]; - if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); + if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, doubleData.length); doubleData = newData; } } else if (resultStruct != null) { @@ -459,7 +459,7 @@ private void reserveInternal(int newCapacity) { } byte[] newNulls = new byte[newCapacity]; - if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, elementsAppended); + if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, nulls.length); nulls = newNulls; capacity = newCapacity; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 5b9af26dfc4f..739eb8f8c1b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -167,11 +167,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { /** Hive outputs fields of structs slightly differently than top level attributes. 
*/ def toHiveStructString(a: (Any, DataType)): String = a match { - case (struct: Row, StructType(fields)) => + case (struct: Row, StructType(fields, _)) => struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => + case (seq: Seq[_], ArrayType(typ, _, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_, _], MapType(kType, vType, _)) => map.map { @@ -185,11 +185,11 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { } a match { - case (struct: Row, StructType(fields)) => + case (struct: Row, StructType(fields, _)) => struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => + case (seq: Seq[_], ArrayType(typ, _, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_, _], MapType(kType, vType, _)) => map.map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 44cfbb9fbd81..2636afd31353 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -358,7 +358,7 @@ private[sql] class JDBCRDD( case StringType => StringConversion case TimestampType => TimestampConversion case BinaryType => BinaryConversion - case ArrayType(et, _) => ArrayConversion(getConversions(et, metadata)) + case ArrayType(et, _, _) => ArrayConversion(getConversions(et, metadata)) case _ => throw new IllegalArgumentException(s"Unsupported type ${dt.simpleString}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 065c8572b06a..6160ece02385 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -193,7 +193,7 @@ object JdbcUtils extends Logging { case TimestampType => stmt.setTimestamp(i + 1, row.getAs[java.sql.Timestamp](i)) case DateType => stmt.setDate(i + 1, row.getAs[java.sql.Date](i)) case t: DecimalType => stmt.setBigDecimal(i + 1, row.getDecimal(i)) - case ArrayType(et, _) => + case ArrayType(et, _, _) => // remove type length parameters from end of type name val typeName = getJdbcType(et, dialect).databaseTypeDefinition .toLowerCase.split("\\(")(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 579b036417d2..9f0ee3ba7db0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -177,14 +177,14 @@ private[sql] object InferSchema { * Convert NullType to StringType and remove StructTypes with no fields */ private def canonicalizeType(tpe: DataType): Option[DataType] = tpe match { - case at @ ArrayType(elementType, _) => + case at @ ArrayType(elementType, _, _) => for { canonicalType <- canonicalizeType(elementType) } yield { 
at.copy(canonicalType) } - case StructType(fields) => + case StructType(fields, _) => val canonicalFields: Array[StructField] = for { field <- fields if field.name.length > 0 @@ -229,9 +229,9 @@ private[sql] object InferSchema { shouldHandleCorruptRecord: Boolean): (DataType, DataType) => DataType = { // Since we support array of json objects at the top level, // we need to check the element type and find the root level data type. - case (ArrayType(ty1, _), ty2) => + case (ArrayType(ty1, _, _), ty2) => compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) - case (ty1, ArrayType(ty2, _)) => + case (ty1, ArrayType(ty2, _, _)) => compatibleRootType(columnNameOfCorruptRecords, shouldHandleCorruptRecord)(ty1, ty2) // If we see any other data type at the root level, we get records that cannot be // parsed. So, we use the struct as the data type and add the corrupt field to the schema. @@ -270,7 +270,7 @@ private[sql] object InferSchema { DecimalType(range + scale, scale) } - case (StructType(fields1), StructType(fields2)) => + case (StructType(fields1, _), StructType(fields2, _)) => // Both fields1 and fields2 should be sorted by name, since inferField performs sorting. // Therefore, we can take advantage of the fact that we're merging sorted lists and skip // building a hash map or performing additional sorting. @@ -309,8 +309,8 @@ private[sql] object InferSchema { } StructType(newFields.toArray(emptyStructFieldArray)) - case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => - ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) + case (ArrayType(eltType1, containsNull1, _), ArrayType(eltType2, containsNull2, _)) => + ArrayType(compatibleType(eltType1, eltType2), containsNull1 || containsNull2) // The case that given `DecimalType` is capable of given `IntegralType` is handled in // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 8b920ecafaee..fadc3fc0c339 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -53,7 +53,7 @@ private[sql] object JacksonGenerator { // an ArrayData at here, instead of a Vector. 
case (udt: UserDefinedType[_], v) => valWriter(udt.sqlType, v) - case (ArrayType(ty, _), v: ArrayData) => + case (ArrayType(ty, _, _), v: ArrayData) => gen.writeStartArray() v.foreach(ty, (_, value) => valWriter(ty, value)) gen.writeEndArray() @@ -66,7 +66,7 @@ private[sql] object JacksonGenerator { }) gen.writeEndObject() - case (StructType(ty), v: InternalRow) => + case (StructType(ty, _), v: InternalRow) => gen.writeStartObject() var i = 0 while (i < ty.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index 733fcbfea101..1ff7152147a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -64,7 +64,7 @@ object JacksonParser extends Logging { // in such an array as a row convertArray(factory, parser, st) - case (START_OBJECT, ArrayType(st, _)) => + case (START_OBJECT, ArrayType(st, _, _)) => // the business end of SPARK-3308: // when an object is found but an array is requested just wrap it in a list convertField(factory, parser, st) :: Nil @@ -181,7 +181,7 @@ object JacksonParser extends Logging { case (START_OBJECT, st: StructType) => convertObject(factory, parser, st) - case (START_ARRAY, ArrayType(st, _)) => + case (START_ARRAY, ArrayType(st, _, _)) => convertArray(factory, parser, st) case (START_OBJECT, MapType(StringType, kt, _)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 2cce3db9a692..bed8a35bd5f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -255,8 +255,7 @@ private[sql] class ParquetFileFormat override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { val conf = sparkSession.sessionState.conf conf.parquetVectorizedReaderEnabled && conf.wholeStageEnabled && - schema.length <= conf.wholeStageMaxNumFields && - schema.forall(_.dataType.isInstanceOf[AtomicType]) + schema.length <= conf.wholeStageMaxNumFields } override def isSplitable( @@ -332,7 +331,8 @@ private[sql] class ParquetFileFormat val resultSchema = StructType(partitionSchema.fields ++ requiredSchema.fields) val enableVectorizedReader: Boolean = sparkSession.sessionState.conf.parquetVectorizedReaderEnabled && - resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) + !resultSchema.existsRecursively(_.isInstanceOf[MapType]) + // Whole stage codegen (PhysicalRDD) is able to deal with batches directly val returningBatch = supportBatch(sparkSession, resultSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 624081250113..3c18ed4d56f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -192,7 +192,7 @@ private[sql] object ParquetFilters { * fields. 
*/ private def getFieldMap(dataType: DataType): Array[(String, DataType)] = dataType match { - case StructType(fields) => + case StructType(fields, _) => fields.filter { f => !f.metadata.contains(StructType.metadataKeyForOptionalField) || !f.metadata.getBoolean(StructType.metadataKeyForOptionalField) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 9dad59647e0d..1f57a650a52e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -488,7 +488,7 @@ private[parquet] class ParquetRowConverter( case (t: GroupType, _) if t.getFieldCount > 1 => true case (t: GroupType, _) if t.getFieldCount == 1 && t.getName == "array" => true case (t: GroupType, _) if t.getFieldCount == 1 && t.getName == parentName + "_tuple" => true - case (t: GroupType, StructType(Array(f))) if f.name == t.getFieldName(0) => true + case (t: GroupType, StructType(Array(f), _)) if f.name == t.getFieldName(0) => true case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index bcf535d45521..78f885c6d434 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -71,35 +71,60 @@ private[parquet] class ParquetSchemaConverter( /** * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. */ - def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType()) + def convert(parquetSchema: MessageType): StructType = { + convert(parquetSchema.asGroupType(), parquetSchema, Seq.empty[String]) + } - private def convert(parquetSchema: GroupType): StructType = { + private def convert( + parquetSchema: GroupType, + messageType: MessageType, + path: Seq[String]): StructType = { val fields = parquetSchema.getFields.asScala.map { field => field.getRepetition match { case OPTIONAL => - StructField(field.getName, convertField(field), nullable = true) + StructField(field.getName, convertField(field, messageType, path), nullable = true) case REQUIRED => - StructField(field.getName, convertField(field), nullable = false) + StructField(field.getName, convertField(field, messageType, path), nullable = false) case REPEATED => + val builder = new MetadataBuilder() + val curPath = path ++ Seq(field.getName) + val defLevel = messageType.getMaxDefinitionLevel(curPath: _*) + val repLevel = messageType.getMaxRepetitionLevel(curPath: _*) + val metadata = builder.putLong("defLevel", defLevel).putLong("repLevel", repLevel).build() + // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor // annotated by `LIST` or `MAP` should be interpreted as a required list of required // elements where the element type is the type of the field. 
- val arrayType = ArrayType(convertField(field), containsNull = false) - StructField(field.getName, arrayType, nullable = false) + val arrayType = ArrayType(convertField(field, messageType, path), containsNull = false, + metadata = metadata) + StructField(field.getName, arrayType, nullable = false, metadata = metadata) } } - StructType(fields) + if (path.isEmpty) { + StructType(fields) + } else { + val builder = new MetadataBuilder() + val defLevel = messageType.getMaxDefinitionLevel(path: _*) + val repLevel = messageType.getMaxRepetitionLevel(path: _*) + val metadata = builder.putLong("defLevel", defLevel).putLong("repLevel", repLevel).build() + StructType(fields.toArray, metadata) + } } /** * Converts a Parquet [[Type]] to a Spark SQL [[DataType]]. */ - def convertField(parquetType: Type): DataType = parquetType match { + def convertField( + parquetType: Type, + messageType: MessageType = null, + path: Seq[String] = Seq.empty[String]): DataType = parquetType match { case t: PrimitiveType => convertPrimitiveField(t) - case t: GroupType => convertGroupField(t.asGroupType()) + case t: GroupType => + val curPath = path ++ Seq(t.getName()) + convertGroupField(t.asGroupType(), messageType, curPath) } private def convertPrimitiveField(field: PrimitiveType): DataType = { @@ -190,8 +215,11 @@ private[parquet] class ParquetSchemaConverter( } } - private def convertGroupField(field: GroupType): DataType = { - Option(field.getOriginalType).fold(convert(field): DataType) { + private def convertGroupField( + field: GroupType, + messageType: MessageType, + path: Seq[String]): DataType = { + Option(field.getOriginalType).fold(convert(field, messageType, path): DataType) { // A Parquet list is represented as a 3-level structure: // // group (LIST) { @@ -212,13 +240,22 @@ private[parquet] class ParquetSchemaConverter( val repeatedType = field.getType(0) ParquetSchemaConverter.checkConversionRequirement( repeatedType.isRepetition(REPEATED), s"Invalid list type $field") - + val builder = new MetadataBuilder() if (isElementType(repeatedType, field.getName)) { - ArrayType(convertField(repeatedType), containsNull = false) + val defLevel = messageType.getMaxDefinitionLevel(path: _*) + val repLevel = messageType.getMaxRepetitionLevel(path: _*) + val metadata = builder.putLong("defLevel", defLevel).putLong("repLevel", repLevel).build() + ArrayType(convertField(repeatedType, messageType, path), containsNull = false, + metadata = metadata) } else { val elementType = repeatedType.asGroupType().getType(0) val optional = elementType.isRepetition(OPTIONAL) - ArrayType(convertField(elementType), containsNull = optional) + val curPath = path ++ Seq(repeatedType.getName) + val defLevel = messageType.getMaxDefinitionLevel(curPath: _*) + val repLevel = messageType.getMaxRepetitionLevel(curPath: _*) + val metadata = builder.putLong("defLevel", defLevel).putLong("repLevel", repLevel).build() + ArrayType(convertField(elementType, messageType, curPath), containsNull = optional, + metadata = metadata) } // scalastyle:off @@ -242,9 +279,11 @@ private[parquet] class ParquetSchemaConverter( val valueType = keyValueType.getType(1) val valueOptional = valueType.isRepetition(OPTIONAL) + val keyPath = path ++ Seq(keyValueType.getName) + val valuePath = path ++ Seq(keyValueType.getName) MapType( - convertField(keyType), - convertField(valueType), + convertField(keyType, messageType, keyPath), + convertField(valueType, messageType, valuePath), valueContainsNull = valueOptional) case _ => @@ -439,7 +478,7 @@ private[parquet] class 
ParquetSchemaConverter( // `LIST` structure. This behavior is somewhat a hybrid of parquet-hive and parquet-avro // (1.6.0rc3): the 3-level structure is similar to parquet-hive while the 3rd level element // field name "array" is borrowed from parquet-avro. - case ArrayType(elementType, nullable @ true) if writeLegacyParquetFormat => + case ArrayType(elementType, nullable @ true, _) if writeLegacyParquetFormat => // group (LIST) { // optional group bag { // repeated array; @@ -457,7 +496,7 @@ private[parquet] class ParquetSchemaConverter( // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level // LIST structure. This behavior mimics parquet-avro (1.6.0rc3). Note that this case is // covered by the backwards-compatibility rules implemented in `isElementType()`. - case ArrayType(elementType, nullable @ false) if writeLegacyParquetFormat => + case ArrayType(elementType, nullable @ false, _) if writeLegacyParquetFormat => // group (LIST) { // repeated element; // } @@ -486,7 +525,7 @@ private[parquet] class ParquetSchemaConverter( // ArrayType and MapType (standard mode) // ===================================== - case ArrayType(elementType, containsNull) if !writeLegacyParquetFormat => + case ArrayType(elementType, containsNull, _) if !writeLegacyParquetFormat => // group (LIST) { // repeated group list { // element; @@ -521,7 +560,7 @@ private[parquet] class ParquetSchemaConverter( // Other types // =========== - case StructType(fields) => + case StructType(fields, _) => fields.foldLeft(Types.buildGroup(repetition)) { (builder, field) => builder.addField(convertField(field)) }.named(field.name) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index cf68ed4ec36a..f3d026257d47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -49,7 +49,7 @@ object EvaluatePython { case DateType | TimestampType => true case _: StructType => true case _: UserDefinedType[_] => true - case ArrayType(elementType, _) => needConversionInPython(elementType) + case ArrayType(elementType, _, _) => needConversionInPython(elementType) case MapType(keyType, valueType, _) => needConversionInPython(keyType) || needConversionInPython(valueType) case _ => false @@ -129,10 +129,10 @@ object EvaluatePython { case (c: String, BinaryType) => c.getBytes(StandardCharsets.UTF_8) case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c - case (c: java.util.List[_], ArrayType(elementType, _)) => + case (c: java.util.List[_], ArrayType(elementType, _, _)) => new GenericArrayData(c.asScala.map { e => fromJava(e, elementType)}.toArray) - case (c, ArrayType(elementType, _)) if c.getClass.isArray => + case (c, ArrayType(elementType, _, _)) if c.getClass.isArray => new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType))) case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => @@ -141,7 +141,7 @@ object EvaluatePython { val values = keyValues.map(kv => fromJava(kv._2, valueType)).toArray ArrayBasedMapData(keys, values) - case (c, StructType(fields)) if c.getClass.isArray => + case (c, StructType(fields, _)) if c.getClass.isArray => val array = c.asInstanceOf[Array[_]] if (array.length != fields.length) { throw new IllegalStateException( diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 2d6c3974a833..c78b5c393a19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -68,7 +68,7 @@ private object PostgresDialect extends JdbcDialect { case DoubleType => Some(JdbcType("FLOAT8", Types.DOUBLE)) case t: DecimalType => Some( JdbcType(s"NUMERIC(${t.precision},${t.scale})", java.sql.Types.NUMERIC)) - case ArrayType(et, _) if et.isInstanceOf[AtomicType] => + case ArrayType(et, _, _) if et.isInstanceOf[AtomicType] => getJDBCType(et).map(_.databaseTypeDefinition) .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition)) .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 8a980a7eb538..b50f4be8a3f7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -63,7 +63,7 @@ abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext { val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) val expected = sqlSchema assert( - actual === expected, + DataType.equalsIgnoreCompatibleNullability(actual, expected), s"""Schema mismatch. |Expected schema: ${expected.json} |Actual schema: ${actual.json} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 7e576a865799..9f36d61bdec7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -617,7 +617,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(r1.getString(ordinal) == r2.getString(ordinal), "Seed = " + seed) case CalendarIntervalType => assert(r1.getInterval(ordinal) === r2.get(ordinal).asInstanceOf[CalendarInterval]) - case ArrayType(childType, n) => + case ArrayType(childType, n, _) => val a1 = r1.getArray(ordinal).array val a2 = r2.getList(ordinal).toArray assert(a1.length == a2.length, "Seed = " + seed) @@ -649,7 +649,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } case _ => assert(a1 === a2, "Seed = " + seed) } - case StructType(childFields) => + case StructType(childFields, _) => compareStruct(childFields, r1.getStruct(ordinal, fields.length), r2.getStruct(ordinal), seed) case _ => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 585befe37825..e1daeaaefb91 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -636,7 +636,7 @@ private[hive] trait HiveInspectors { * We can easily map to the Hive built-in object inspector according to the data type. 
*/ def toInspector(dataType: DataType): ObjectInspector = dataType match { - case ArrayType(tpe, _) => + case ArrayType(tpe, _, _) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) case MapType(keyType, valueType, _) => ObjectInspectorFactory.getStandardMapObjectInspector( @@ -655,7 +655,7 @@ private[hive] trait HiveInspectors { case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector // TODO decimal precision? case DecimalType() => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector - case StructType(fields) => + case StructType(fields, _) => ObjectInspectorFactory.getStandardStructObjectInspector( java.util.Arrays.asList(fields.map(f => f.name) : _*), java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) : _*)) @@ -695,7 +695,7 @@ private[hive] trait HiveInspectors { getDecimalWritableConstantObjectInspector(value) case Literal(_, NullType) => getPrimitiveNullWritableConstantObjectInspector - case Literal(value, ArrayType(dt, _)) => + case Literal(value, ArrayType(dt, _, _)) => val listObjectInspector = toInspector(dt) if (value == null) { ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) @@ -899,9 +899,9 @@ private[hive] trait HiveInspectors { } def toTypeInfo: TypeInfo = dt match { - case ArrayType(elemType, _) => + case ArrayType(elemType, _, _) => getListTypeInfo(elemType.toTypeInfo) - case StructType(fields) => + case StructType(fields, _) => getStructTypeInfo( java.util.Arrays.asList(fields.map(_.name) : _*), java.util.Arrays.asList(fields.map(_.dataType.toTypeInfo) : _*)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index 9ed357c587c3..836523ff35c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -401,11 +401,11 @@ private[spark] object HiveUtils extends Logging { ShortType, DateType, TimestampType, BinaryType) protected[sql] def toHiveString(a: (Any, DataType)): String = a match { - case (struct: Row, StructType(fields)) => + case (struct: Row, StructType(fields, _)) => struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => + case (seq: Seq[_], ArrayType(typ, _, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_, _], MapType(kType, vType, _)) => map.map { @@ -424,11 +424,11 @@ private[spark] object HiveUtils extends Logging { /** Hive outputs fields of structs slightly differently than top level attributes. 
*/ protected def toHiveStructString(a: (Any, DataType)): String = a match { - case (struct: Row, StructType(fields)) => + case (struct: Row, StructType(fields, _)) => struct.toSeq.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ, _)) => + case (seq: Seq[_], ArrayType(typ, _, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_, _], MapType(kType, vType, _)) => map.map { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 3b867bbfa181..fb8c0e0aa4b7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -98,7 +98,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { val dataTypes = data.map(_.dataType) def toWritableInspector(dataType: DataType): ObjectInspector = dataType match { - case ArrayType(tpe, _) => + case ArrayType(tpe, _, _) => ObjectInspectorFactory.getStandardListObjectInspector(toWritableInspector(tpe)) case MapType(keyType, valueType, _) => ObjectInspectorFactory.getStandardMapObjectInspector( @@ -116,7 +116,7 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { case DateType => PrimitiveObjectInspectorFactory.writableDateObjectInspector case TimestampType => PrimitiveObjectInspectorFactory.writableTimestampObjectInspector case DecimalType() => PrimitiveObjectInspectorFactory.writableHiveDecimalObjectInspector - case StructType(fields) => + case StructType(fields, _) => ObjectInspectorFactory.getStandardStructObjectInspector( java.util.Arrays.asList(fields.map(f => f.name) : _*), java.util.Arrays.asList(fields.map(f => toWritableInspector(f.dataType)) : _*)) From d5e5a605bf8c6b0bcd510f56dac583dd30629859 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 4 Jul 2016 17:43:39 +0800 Subject: [PATCH 02/17] Remove commented code. --- .../execution/datasources/parquet/VectorizedColumnReader.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 941d5ff5bb98..6a2884c00f60 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -670,8 +670,6 @@ private void updateReptitionInfo( repCount = reptitionMap.get(1); } reptitionMap.put(1, repCount + 1); - // insertArrayForRepetition(column, rowIds, offsets, reptitionMap, total, - // 0, 1); insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, maxRepLevel - 1); } else { // Repeated values. We increase repetition count. From 5c4c1c84606763b78b6b8b774be87862df520f9c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 5 Jul 2016 10:41:59 +0800 Subject: [PATCH 03/17] Fix test. 
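Two adjustments here: the repetition bookkeeping is reset after inserting an array for a
parent column at the max definition level, and supportBatch() now also refuses schemas with a
top-level MapType, which the vectorized path does not handle. A rough sketch of what the new
guard accepts and rejects (the field names are made up):

    import org.apache.spark.sql.types._

    val withArray = new StructType()
      .add("id", IntegerType)
      .add("tags", ArrayType(StringType))
    val withMap = new StructType()
      .add("id", IntegerType)
      .add("props", MapType(StringType, StringType))

    // StructType is a Seq[StructField], so the guard is a plain scan over the fields.
    withArray.forall(!_.dataType.isInstanceOf[MapType])  // true  -> batch reading still allowed
    withMap.forall(!_.dataType.isInstanceOf[MapType])    // false -> falls back to row-based reading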
--- .../execution/datasources/parquet/VectorizedColumnReader.java | 1 + .../sql/execution/datasources/parquet/ParquetFileFormat.scala | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 6a2884c00f60..958a59225468 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -614,6 +614,7 @@ private void updateReptitionInfo( if (column.getParentColumn().getDefLevel() == maxDefLevel) { insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, repLevel); offsets.put(maxRepLevel, offset + 1); + prevRepLevel = -1; } else if (defLevel == 0) { // A null record at root level. // Obtain most-top column (repetition level 1). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 449542906e6b..8182d98d40b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -257,7 +257,8 @@ private[sql] class ParquetFileFormat override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { val conf = sparkSession.sessionState.conf conf.parquetVectorizedReaderEnabled && conf.wholeStageEnabled && - schema.length <= conf.wholeStageMaxNumFields + schema.length <= conf.wholeStageMaxNumFields && + schema.forall(!_.dataType.isInstanceOf[MapType]) } override def isSplitable( From 114a69b30bee0bb80f9028205fc020387c29ac24 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 5 Jul 2016 11:43:27 +0800 Subject: [PATCH 04/17] Fix test. --- .../apache/spark/sql/execution/vectorized/ColumnVector.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 6ef56ae0605c..6cf38a7e227a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -200,9 +200,7 @@ public short getShort(int ordinal) { public long getLong(int ordinal) { return data.getLong(offset + ordinal); } @Override - public float getFloat(int ordinal) { - throw new UnsupportedOperationException(); - } + public float getFloat(int ordinal) { return data.getFloat(offset + ordinal); } @Override public double getDouble(int ordinal) { return data.getDouble(offset + ordinal); } From 4dca939b131cdab042f437236b108686d12f9d94 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 6 Jul 2016 12:52:53 +0800 Subject: [PATCH 05/17] For array type 2, don't take repeated type in computing repetition level. 
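Background on the tweak, using a made-up schema: with the standard three-level list layout

    optional group f (LIST) {
      repeated group list {
        optional int32 element;
      }
    }

the repeated "list" node is what raises the repetition level of the individual elements; the
array field itself still starts records at the repetition level of the path up to the LIST
group. Parquet's MessageType reports

    getMaxRepetitionLevel("f")         == 0
    getMaxRepetitionLevel("f", "list") == 1

so the repetition level recorded in the ArrayType metadata is now taken from `path` (the LIST
group) rather than `curPath` (the inner repeated group).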
--- .../datasources/parquet/ParquetSchemaConverter.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 78f885c6d434..448ad72f2360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -251,8 +251,8 @@ private[parquet] class ParquetSchemaConverter( val elementType = repeatedType.asGroupType().getType(0) val optional = elementType.isRepetition(OPTIONAL) val curPath = path ++ Seq(repeatedType.getName) - val defLevel = messageType.getMaxDefinitionLevel(curPath: _*) - val repLevel = messageType.getMaxRepetitionLevel(curPath: _*) + val defLevel = messageType.getMaxDefinitionLevel(path: _*) + val repLevel = messageType.getMaxRepetitionLevel(path: _*) val metadata = builder.putLong("defLevel", defLevel).putLong("repLevel", repLevel).build() ArrayType(convertField(elementType, messageType, curPath), containsNull = optional, metadata = metadata) From ded41b2726a0905fded1a74c5005732d636e7d20 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 6 Jul 2016 13:04:20 +0800 Subject: [PATCH 06/17] supportBatch should check unsupported MapType recursively. --- .../sql/execution/datasources/parquet/ParquetFileFormat.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 8182d98d40b3..cbce77504f57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -258,7 +258,7 @@ private[sql] class ParquetFileFormat val conf = sparkSession.sessionState.conf conf.parquetVectorizedReaderEnabled && conf.wholeStageEnabled && schema.length <= conf.wholeStageMaxNumFields && - schema.forall(!_.dataType.isInstanceOf[MapType]) + !schema.existsRecursively(_.isInstanceOf[MapType]) } override def isSplitable( From bf61a752aa2801449cbc555b6dc9bf59e7426875 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 6 Jul 2016 16:24:25 +0800 Subject: [PATCH 07/17] Definition level of array type 2 should include repeated type. 
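Follow-up to the repetition-level note above, with the same made-up f / list / element layout:
the definition level, unlike the repetition level, does have to count the repeated node, since
reaching that level is what distinguishes an array with at least one entry from one that is
present but empty. Parquet reports

    getMaxDefinitionLevel("f")         == 1
    getMaxDefinitionLevel("f", "list") == 2

and on the element column a definition level of 0 means the array field is null, 1 means it is
present but empty, and 2 or more means at least one element slot exists. Hence the definition
level goes back to being computed from `curPath`, while the repetition level keeps using `path`.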
--- .../execution/datasources/parquet/ParquetSchemaConverter.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 448ad72f2360..dc8481c59f2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -251,7 +251,7 @@ private[parquet] class ParquetSchemaConverter( val elementType = repeatedType.asGroupType().getType(0) val optional = elementType.isRepetition(OPTIONAL) val curPath = path ++ Seq(repeatedType.getName) - val defLevel = messageType.getMaxDefinitionLevel(path: _*) + val defLevel = messageType.getMaxDefinitionLevel(curPath: _*) val repLevel = messageType.getMaxRepetitionLevel(path: _*) val metadata = builder.putLong("defLevel", defLevel).putLong("repLevel", repLevel).build() ArrayType(convertField(elementType, messageType, curPath), containsNull = optional, From d7194808febc08e35dfe939472d18dcff979fbab Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 7 Jul 2016 15:25:22 +0800 Subject: [PATCH 08/17] Fix test. --- .../datasources/parquet/VectorizedColumnReader.java | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 958a59225468..189e566c98ca 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -611,7 +611,7 @@ private void updateReptitionInfo( offset = offsets.get(maxRepLevel); } - if (column.getParentColumn().getDefLevel() == maxDefLevel) { + if (column.getParentColumn().getDefLevel() == maxDefLevel && maxRepLevel > 0) { insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, repLevel); offsets.put(maxRepLevel, offset + 1); prevRepLevel = -1; @@ -627,6 +627,14 @@ private void updateReptitionInfo( if (rowIds.containsKey(1)) { rowId = rowIds.get(1); } + // Add up previously accumulated count for repetition level 1. + // Otherwise, we will override previous non-null records. + int repCount = 0; + if (reptitionMap.containsKey(1)) { + repCount = reptitionMap.get(1); + } + reptitionMap.put(1, 0); + rowId += repCount; // Insert null record and increase row id. topColumn.putNull(rowId); rowIds.put(1, rowId + 1); From e3f74bdd58a7655e57d7f210dce39f6fa71ba5d9 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 8 Jul 2016 11:51:39 +0800 Subject: [PATCH 09/17] Support getBoolean. 
--- .../apache/spark/sql/execution/vectorized/ColumnVector.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 6cf38a7e227a..86b6f25be9d4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -181,9 +181,7 @@ public Object[] array() { public boolean isNullAt(int ordinal) { return data.isNullAt(offset + ordinal); } @Override - public boolean getBoolean(int ordinal) { - throw new UnsupportedOperationException(); - } + public boolean getBoolean(int ordinal) { return data.getBoolean(offset + ordinal); } @Override public byte getByte(int ordinal) { return data.getByte(offset + ordinal); } From 5d5e933f4817fb852da8500fcb8151e77c620611 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 11 Jul 2016 23:18:01 +0800 Subject: [PATCH 10/17] Consider more cases when the value is null. --- .../parquet/VectorizedColumnReader.java | 211 +++++++++++++----- 1 file changed, 156 insertions(+), 55 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 189e566c98ca..e5ac36154ac7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -163,6 +163,7 @@ public void readBatch(int total, ColumnVector column) throws IOException { } else { if (total <= 0) break; } + if (leftInPage == 0) { readPage(); leftInPage = (int) (endOfPageValueCount - valuesRead); @@ -574,6 +575,80 @@ private void insertRepeatedArray( } } + private ColumnVector findInnerElementWithDefLevel(ColumnVector column, int defLevel) { + while (true) { + if (column == null) { + return null; + } + ColumnVector parent = column.getParentColumn(); + if (parent != null && parent.getDefLevel() == defLevel) { + ColumnVector outside = parent.getParentColumn(); + if (outside == null || outside.getDefLevel() < defLevel) { + return column; + } + } + column = parent; + } + } + + private ColumnVector findHiddenInnerElementWithDefLevel(ColumnVector column, int defLevel) { + while (true) { + if (column == null) { + return null; + } + ColumnVector parent = column.getParentColumn(); + if (parent != null && parent.getDefLevel() <= defLevel) { + ColumnVector outside = parent.getParentColumn(); + if (outside == null || outside.getDefLevel() < defLevel) { + return column; + } + } + column = parent; + } + } + + private boolean isLegacyArray(ColumnVector column) { + ColumnVector parent = column.getNearestParentArrayColumn(); + if (parent == null) { + return false; + } else if (parent.getRepLevel() <= maxRepLevel && parent.getDefLevel() < maxDefLevel) { + return true; + } + return false; + } + + private void increaseRowId(Map rowIds, int level) { + int rowId = 0; + if (rowIds.containsKey(level)) { + rowId = rowIds.get(level); + } + rowIds.put(level, rowId + 1); + } + + private void insertNullRecord( + ColumnVector column, + Map rowIds, + Map reptitionMap) { + int repLevel = column.getRepLevel(); + + if (repLevel == 0) { + repLevel = 1; + } + + int rowId = 0; + if (rowIds.containsKey(repLevel)) { + 
rowId = rowIds.get(repLevel); + } + + if (reptitionMap.containsKey(repLevel)) { + rowId += reptitionMap.get(repLevel); + reptitionMap.put(repLevel, 0); + } + + column.putNull(rowId); + rowIds.put(repLevel, rowId + 1); + } + /** * Reads repetition level for each value and updates length and offset info for above columns, * recursively. @@ -587,21 +662,25 @@ private void updateReptitionInfo( // Keeps repetition levels and corresponding repetition counts. Map reptitionMap = new HashMap(); + for (int i = 0; i <= maxRepLevel; i++) { + reptitionMap.put(i, 0); + } + if (column.getParentColumn() != null) { - int prevRepLevel = -1; + boolean newBeginning = true; for (int i = 0; i < valuesReadInPage; i++) { int repLevel = repetitionLevelColumn.nextInt(); int defLevel = definitionLevelColumn.nextInt(); - if (prevRepLevel >= 0) { + if (!newBeginning) { // When a new record begins at lower repetition level, // we insert array into repeated column. if (repLevel < maxRepLevel) { insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, repLevel); } } - prevRepLevel = repLevel; + newBeginning = false; // When definition level is less than max definition level, // there is a null value. @@ -611,66 +690,88 @@ private void updateReptitionInfo( offset = offsets.get(maxRepLevel); } - if (column.getParentColumn().getDefLevel() == maxDefLevel && maxRepLevel > 0) { - insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, repLevel); - offsets.put(maxRepLevel, offset + 1); - prevRepLevel = -1; - } else if (defLevel == 0) { - // A null record at root level. - // Obtain most-top column (repetition level 1). - ColumnVector topColumn = column.getParentColumn(); - while (topColumn.getParentColumn() != null) { - topColumn = topColumn.getParentColumn(); - } - // Get its current row id. - int rowId = 0; - if (rowIds.containsKey(1)) { - rowId = rowIds.get(1); - } - // Add up previously accumulated count for repetition level 1. - // Otherwise, we will override previous non-null records. - int repCount = 0; - if (reptitionMap.containsKey(1)) { - repCount = reptitionMap.get(1); - } - reptitionMap.put(1, 0); - rowId += repCount; - // Insert null record and increase row id. - topColumn.putNull(rowId); - rowIds.put(1, rowId + 1); - - // Increse row id at later repetition levels. - for (int j = 2; j <= maxRepLevel; j++) { - rowId = 0; - if (rowIds.containsKey(j)) { - rowId = rowIds.get(j); + // The null value is defined at the root level. + // Insert a null record. + if (repLevel == 0 && defLevel == 0) { + ColumnVector parent = column.getParentColumn(); + if (parent != null && parent.getDefLevel() == maxDefLevel + && parent.getRepLevel() == maxRepLevel) { + // A repeated element at root level. + // E.g., The repeatedPrimitive at the following schema. + // Going to insert an empty record. + // messageType: message spark_schema { + // optional int32 optionalPrimitive; + // required int32 requiredPrimitive; + // + // repeated int32 repeatedPrimitive; + // + // optional group optionalMessage { + // optional int32 someId; + // } + // required group requiredMessage { + // optional int32 someId; + // } + // repeated group repeatedMessage { + // optional int32 someId; + // } + // } + insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, repLevel); + } else { + // Obtain most outside column. 
+ ColumnVector topColumn = column.getParentColumn(); + while (topColumn.getParentColumn() != null) { + topColumn = topColumn.getParentColumn(); } - rowIds.put(j, rowId + 1); - } + insertNullRecord(topColumn, rowIds, reptitionMap); + } // Move to next offset in max repetition level as we processed the current value. offsets.put(maxRepLevel, offset + 1); - - prevRepLevel = -1; - } else if (repLevel == maxRepLevel) { - // A null value at max repetition level. - // This null value is repeated in a wrapping group. Simply increase repetition count. - int repCount = 0; - if (reptitionMap.containsKey(repLevel)) { - repCount = reptitionMap.get(repLevel); - } - reptitionMap.put(repLevel, repCount + 1); + newBeginning = true; + } else if (isLegacyArray(column) && + column.getNearestParentArrayColumn().getDefLevel() == defLevel) { + // For a legacy array, if a null is defined at the repeated group column, it actually + // means an element with null value. + + reptitionMap.put(maxRepLevel, reptitionMap.get(maxRepLevel) + 1); + } else if (!column.getParentColumn().isArray() && + column.getParentColumn().getDefLevel() == defLevel) { + // A null element defined in the wrapping non-repeated group. + increaseRowId(rowIds, 1); } else { - // Null value at definition level > 0. - int repCount = 0; - if (reptitionMap.containsKey(maxRepLevel)) { - repCount = reptitionMap.get(maxRepLevel); + // An empty element defined in outside group. + // E.g., the element in the following schema. + // messageType: message spark_schema { + // required int32 index; + // optional group col { + // optional float f1; + // optional group f2 (LIST) { + // repeated group list { + // optional boolean element; + // } + // } + // } + // } + ColumnVector parent = findInnerElementWithDefLevel(column, defLevel); + if (parent != null) { + // Found the group with the same definition level. + // Insert a null record at definition level. + // E.g, R=0, D=1 for above schema. + insertNullRecord(parent, rowIds, reptitionMap); + offsets.put(maxRepLevel, offset + 1); + newBeginning = true; + } else { + // Found the group with lower definition level. + // Insert an empty record. + // E.g, R=0, D=2 for above schema. + parent = findHiddenInnerElementWithDefLevel(column, defLevel); + insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, repLevel); + offsets.put(maxRepLevel, offset + 1); + newBeginning = true; } - reptitionMap.put(maxRepLevel, repCount + 1); } } else { // Determine the repetition level of non-null values. - // A new record begins with non-null value. if (maxRepLevel == 0) { // A required record at root level. @@ -690,7 +791,7 @@ private void updateReptitionInfo( } } } - if (prevRepLevel >= 0) { + if (!newBeginning) { insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, 0); } } From 42f53de2af894f961468300b250907a7775e9aac Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 12 Jul 2016 08:52:22 +0800 Subject: [PATCH 11/17] Add more comments. 
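Since the new comments lean heavily on repetition and definition levels, here is a worked
example for the col.f2.list.element column of the sample schema quoted in the code (required
index, optional col holding an optional f2 LIST of optional boolean elements); the concrete
values are illustrative. The max repetition level is 1 (only `list` is repeated) and the max
definition level is 4 (col, f2, list and element each add one):

    col = null                       ->  (R=0, D=0)   null record at the root
    col = {f2: null}                 ->  (R=0, D=1)   null array
    col = {f2: []}                   ->  (R=0, D=2)   empty array
    col = {f2: [null]}               ->  (R=0, D=3)   array with a single null element
    col = {f2: [true, null, false]}  ->  (R=0, D=4), (R=1, D=3), (R=1, D=4)

R=0 always marks the start of a new top-level record, while R=1 means the value still belongs
to the current f2 array; that transition is what the construction code keys on when it closes
the current array and opens the next one.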
--- .../parquet/VectorizedColumnReader.java | 64 +++++++++++++++---- 1 file changed, 53 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index e5ac36154ac7..bcd049515fa5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -141,12 +141,12 @@ public void readBatch(int total, ColumnVector column) throws IOException { while (true) { // Compute the number of values we want to read in this page. int leftInPage = (int) (endOfPageValueCount - valuesRead); - // When we reach the end of this page, we update repetition info of this column + // When we reach the end of this page, we construct nested records of this column // and then read next page. if (leftInPage == 0) { - // Update repetition info for this column. + // Constructs nested records if needed. if (valuesReadInPage > 0 && isNestedColumn) { - updateReptitionInfo(column, rowIds, offsets, valuesReadInPage, total); + constructNestedRecords(column, rowIds, offsets, valuesReadInPage, total); if (rowIds.containsKey(1)) { repeatedRowId = rowIds.get(1); } @@ -507,7 +507,14 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) thr } /** - * Inserts arrays into parent repeated columns. + * Inserts records into parent columns of a column. These parent columns are repeated columns. As + * the real data are read into the column, we only need to insert array into its repeated columns. + * @param column The ColumnVector which the data in the page are read into. + * @param rowIds Mapping between repetition levels and their current row ids for constructing. + * @param offsets The beginning offsets in columns which we use to construct nested records. + * @param reptitionMap Mapping between repetition levels and their corresponding counts. + * @param total The total number of rows to construct. + * @param repLevel The current repetition level. */ private void insertRepeatedArray( ColumnVector column, @@ -522,10 +529,9 @@ private void insertRepeatedArray( parentRepeatedColumn = parentRepeatedColumn.getNearestParentArrayColumn(); if (parentRepeatedColumn != null) { int parentColRepLevel = parentRepeatedColumn.getRepLevel(); - // Only process the parent columns whose repetition levels are equal to or more than - // the given repetition level (less than or equal to max repetition level). - // E.g., when the current repetition level is 1 and max repetition level us 2, - // we only add arrays into the column whose repetition level is 1. + // The current repetition level means the beginning level of the current value. Thus, + // we only need to insert array into the parent columns whose repetition levels are + // equal to or more than the given repetition level. if (parentColRepLevel >= repLevel) { // Current row id at this column. int rowId = 0; @@ -562,6 +568,7 @@ private void insertRepeatedArray( reptitionMap.put(curRepLevel - 1, nextRepCount + 1); } + // In vectorization, the most outside repeated element is at the reptition 1. 
if (curRepLevel == 1 && rowId == total) { return; } @@ -575,6 +582,13 @@ private void insertRepeatedArray( } } + /** + * Finds the outside element of an inner element which is defined as Catalyst DataType, + * with the specified definition level. + * @param column The column as the beginning level for looking up the inner element. + * @param defLevel The specified definition level. + * @return the column which is the outside group element of the inner element. + */ private ColumnVector findInnerElementWithDefLevel(ColumnVector column, int defLevel) { while (true) { if (column == null) { @@ -591,6 +605,13 @@ private ColumnVector findInnerElementWithDefLevel(ColumnVector column, int defLe } } + /** + * Finds the outside element of the inner element which is not defined as Catalyst DataType, + * with the specified definition level. + * @param column The column as the beginning level for looking up the inner element. + * @param defLevel The specified definition level. + * @return the column which is the outside group element of the inner element. + */ private ColumnVector findHiddenInnerElementWithDefLevel(ColumnVector column, int defLevel) { while (true) { if (column == null) { @@ -607,6 +628,11 @@ private ColumnVector findHiddenInnerElementWithDefLevel(ColumnVector column, int } } + /** + * Checks if the given column is a legacy array in Parquet schema. + * @param column The column we want to check if it is legacy array. + * @return whether the given column is a legacy array in Parquet schema. + */ private boolean isLegacyArray(ColumnVector column) { ColumnVector parent = column.getNearestParentArrayColumn(); if (parent == null) { @@ -617,6 +643,11 @@ private boolean isLegacyArray(ColumnVector column) { return false; } + /** + * Increases row id for a specified repetition level. + * @param rowIds Mapping between repetition levels and their current row ids for constructing. + * @param level repetition level. + */ private void increaseRowId(Map rowIds, int level) { int rowId = 0; if (rowIds.containsKey(level)) { @@ -625,6 +656,12 @@ private void increaseRowId(Map rowIds, int level) { rowIds.put(level, rowId + 1); } + /** + * Inserts a null record at specified column. + * @param column The ColumnVector which the data in the page are read into. + * @param rowIds Mapping between repetition levels and their current row ids for constructing. + * @param reptitionMap Mapping between repetition levels and their corresponding counts. + */ private void insertNullRecord( ColumnVector column, Map rowIds, @@ -650,10 +687,15 @@ private void insertNullRecord( } /** - * Reads repetition level for each value and updates length and offset info for above columns, - * recursively. + * Iterates the values of definition and repetition levels for the values read in the page, + * and constructs nested records accordingly. + * @param column The ColumnVector which the data in the page are read into. + * @param rowIds Mapping between repetition levels and their current row ids for constructing. + * @param offsets The beginning offsets in columns which we use to construct nested records. + * @param valuesReadInPage The number of values read in the current page. + * @param total The total number of rows to construct. */ - private void updateReptitionInfo( + private void constructNestedRecords( ColumnVector column, Map rowIds, Map offsets, From 60f2d7cd59e34a57bbdfcd606a7cf975726ad846 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 12 Jul 2016 20:17:41 +0800 Subject: [PATCH 12/17] Fix null capacity issue. 
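Nested-record construction can address row ids beyond the capacity the batch was originally
allocated with (repeated columns may need more slots than the requested number of top-level
rows), so buffers are now grown on demand before writing, and the null buffer gets its own
reserveNulls() hook so it is also grown when dictionary ids are reserved. A minimal sketch of
the resulting write pattern, assuming a ColumnVector obtained elsewhere:

    import org.apache.spark.sql.execution.vectorized.ColumnVector

    def putNullAt(column: ColumnVector, rowId: Int): Unit = {
      // Grow the data and null buffers first; the target row id may lie past the
      // current capacity when nested records are being reconstructed.
      column.reserve(rowId + 1)
      column.putNull(rowId)
    }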
--- .../parquet/VectorizedColumnReader.java | 52 +++++++++++-------- .../execution/vectorized/ColumnVector.java | 6 +++ .../vectorized/OffHeapColumnVector.java | 9 +++- .../vectorized/OnHeapColumnVector.java | 13 +++-- 4 files changed, 54 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index bcd049515fa5..055f93fbf553 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -165,7 +165,10 @@ public void readBatch(int total, ColumnVector column) throws IOException { } if (leftInPage == 0) { - readPage(); + boolean pageExists = readPage(); + if (!pageExists) { + return; + } leftInPage = (int) (endOfPageValueCount - valuesRead); } int num = Math.min(total, leftInPage); @@ -449,30 +452,35 @@ private void readFixedLenByteArrayBatch(int rowId, int num, } } - private void readPage() throws IOException { + private boolean readPage() throws IOException { DataPage page = pageReader.readPage(); - // TODO: Why is this a visitor? - page.accept(new DataPage.Visitor() { - @Override - public Void visit(DataPageV1 dataPageV1) { - try { - readPageV1(dataPageV1); - return null; - } catch (IOException e) { - throw new RuntimeException(e); + if (page == null) { + return false; + } else { + // TODO: Why is this a visitor? + page.accept(new DataPage.Visitor() { + @Override + public Void visit(DataPageV1 dataPageV1) { + try { + readPageV1(dataPageV1); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } } - } - @Override - public Void visit(DataPageV2 dataPageV2) { - try { - readPageV2(dataPageV2); - return null; - } catch (IOException e) { - throw new RuntimeException(e); + @Override + public Void visit(DataPageV2 dataPageV2) { + try { + readPageV2(dataPageV2); + return null; + } catch (IOException e) { + throw new RuntimeException(e); + } } - } - }); + }); + return true; + } } private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) throws IOException { @@ -549,6 +557,7 @@ private void insertRepeatedArray( offset = offsets.get(curRepLevel); } + parentRepeatedColumn.reserve(rowId + 1); parentRepeatedColumn.putArray(rowId, offset, repCount); offset += repCount; @@ -682,6 +691,7 @@ private void insertNullRecord( reptitionMap.put(repLevel, 0); } + column.reserve(rowId + 1); column.putNull(rowId); rowIds.put(repLevel, rowId + 1); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 86b6f25be9d4..783bc6d5d113 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -296,6 +296,11 @@ public void reserve(int requiredCapacity) { */ protected abstract void reserveInternal(int capacity); + /** + * Ensures that there is enough storage to store null information. + */ + protected abstract void reserveNulls(int capacity); + /** * Returns the number of nulls in this column. 
*/ @@ -1054,6 +1059,7 @@ public ColumnVector reserveDictionaryIds(int capacity) { dictionaryIds.reset(); dictionaryIds.reserve(capacity); } + reserveNulls(capacity); return dictionaryIds; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index 913a05a0aa0e..f708f8c9e17d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -445,8 +445,13 @@ protected void reserveInternal(int newCapacity) { } else { throw new RuntimeException("Unhandled " + type); } - this.nulls = Platform.reallocateMemory(nulls, elementsAppended, newCapacity); - Platform.setMemory(nulls + elementsAppended, (byte)0, newCapacity - elementsAppended); + reserveNulls(newCapacity); capacity = newCapacity; } + + @Override + protected void reserveNulls(int capacity) { + this.nulls = Platform.reallocateMemory(nulls, elementsAppended, capacity); + Platform.setMemory(nulls + elementsAppended, (byte)0, capacity - elementsAppended); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 5542b4d82a59..48ac775c409a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -454,10 +454,17 @@ protected void reserveInternal(int newCapacity) { throw new RuntimeException("Unhandled " + type); } - byte[] newNulls = new byte[newCapacity]; - if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, nulls.length); - nulls = newNulls; + reserveNulls(newCapacity); capacity = newCapacity; } + + @Override + protected void reserveNulls(int capacity) { + if (nulls == null || nulls.length < capacity) { + byte[] newNulls = new byte[capacity]; + if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, nulls.length); + nulls = newNulls; + } + } } From 1b37fe810a9d8971e626fdc0655063c33afce815 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Jul 2016 16:33:29 +0800 Subject: [PATCH 13/17] Use primitive array instead of HashMap. --- .../parquet/VectorizedColumnReader.java | 132 ++++++------------ .../execution/vectorized/ColumnVector.java | 21 ++- 2 files changed, 60 insertions(+), 93 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 055f93fbf553..e4668c6597bb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -135,8 +135,8 @@ public void readBatch(int total, ColumnVector column) throws IOException { int valuesReadInPage = 0; int repeatedRowId = 0; - Map rowIds = new HashMap(); - Map offsets = new HashMap(); + int[] rowIds = new int[maxRepLevel + 2]; + int[] offsets = new int[maxRepLevel + 2]; while (true) { // Compute the number of values we want to read in this page. @@ -147,9 +147,7 @@ public void readBatch(int total, ColumnVector column) throws IOException { // Constructs nested records if needed. 
if (valuesReadInPage > 0 && isNestedColumn) { constructNestedRecords(column, rowIds, offsets, valuesReadInPage, total); - if (rowIds.containsKey(1)) { - repeatedRowId = rowIds.get(1); - } + repeatedRowId = rowIds[1]; valuesReadInPage = 0; } } @@ -520,15 +518,15 @@ private void initDataReader(Encoding dataEncoding, byte[] bytes, int offset) thr * @param column The ColumnVector which the data in the page are read into. * @param rowIds Mapping between repetition levels and their current row ids for constructing. * @param offsets The beginning offsets in columns which we use to construct nested records. - * @param reptitionMap Mapping between repetition levels and their corresponding counts. + * @param repetitions Mapping between repetition levels and their corresponding counts. * @param total The total number of rows to construct. * @param repLevel The current repetition level. */ private void insertRepeatedArray( ColumnVector column, - Map rowIds, - Map offsets, - Map reptitionMap, + int[] rowIds, + int[] offsets, + int[] repetitions, int total, int repLevel) throws IOException { ColumnVector parentRepeatedColumn = column; @@ -541,44 +539,21 @@ private void insertRepeatedArray( // we only need to insert array into the parent columns whose repetition levels are // equal to or more than the given repetition level. if (parentColRepLevel >= repLevel) { - // Current row id at this column. - int rowId = 0; - if (rowIds.containsKey(curRepLevel)) { - rowId = rowIds.get(curRepLevel); - } - // Repetition count. - int repCount = 0; - if (reptitionMap.containsKey(curRepLevel)) { - repCount = reptitionMap.get(curRepLevel); - } - // Offset of values. - int offset = 0; - if (offsets.containsKey(curRepLevel)) { - offset = offsets.get(curRepLevel); - } - - parentRepeatedColumn.reserve(rowId + 1); - parentRepeatedColumn.putArray(rowId, offset, repCount); + parentRepeatedColumn.reserve(rowIds[curRepLevel] + 1); + parentRepeatedColumn.putArray(rowIds[curRepLevel], + offsets[curRepLevel], repetitions[curRepLevel]); - offset += repCount; - repCount = 0; - rowId++; - - offsets.put(curRepLevel, offset); - reptitionMap.put(curRepLevel, repCount); - rowIds.put(curRepLevel, rowId); + offsets[curRepLevel] += repetitions[curRepLevel]; + repetitions[curRepLevel] = 0; + rowIds[curRepLevel]++; // Increase the repetition count for parent repetition level as we add a new record. if (curRepLevel > 1) { - int nextRepCount = 0; - if (reptitionMap.containsKey(curRepLevel - 1)) { - nextRepCount = reptitionMap.get(curRepLevel - 1); - } - reptitionMap.put(curRepLevel - 1, nextRepCount + 1); + repetitions[curRepLevel - 1]++; } - // In vectorization, the most outside repeated element is at the reptition 1. - if (curRepLevel == 1 && rowId == total) { + // In vectorization, the most outside repeated element is at the repetition 1. + if (curRepLevel == 1 && rowIds[curRepLevel] == total) { return; } curRepLevel--; @@ -669,31 +644,24 @@ private void increaseRowId(Map rowIds, int level) { * Inserts a null record at specified column. * @param column The ColumnVector which the data in the page are read into. * @param rowIds Mapping between repetition levels and their current row ids for constructing. - * @param reptitionMap Mapping between repetition levels and their corresponding counts. + * @param repetitions Mapping between repetition levels and their corresponding counts. 
*/ private void insertNullRecord( ColumnVector column, - Map rowIds, - Map reptitionMap) { + int[] rowIds, + int[] repetitions) { int repLevel = column.getRepLevel(); if (repLevel == 0) { repLevel = 1; } - int rowId = 0; - if (rowIds.containsKey(repLevel)) { - rowId = rowIds.get(repLevel); - } + rowIds[repLevel] += repetitions[repLevel]; + repetitions[repLevel] = 0; - if (reptitionMap.containsKey(repLevel)) { - rowId += reptitionMap.get(repLevel); - reptitionMap.put(repLevel, 0); - } - - column.reserve(rowId + 1); - column.putNull(rowId); - rowIds.put(repLevel, rowId + 1); + column.reserve(rowIds[repLevel] + 1); + column.putNull(rowIds[repLevel]); + rowIds[repLevel]++; } /** @@ -707,16 +675,12 @@ private void insertNullRecord( */ private void constructNestedRecords( ColumnVector column, - Map rowIds, - Map offsets, + int[] rowIds, + int[] offsets, int valuesReadInPage, int total) throws IOException { // Keeps repetition levels and corresponding repetition counts. - Map reptitionMap = new HashMap(); - - for (int i = 0; i <= maxRepLevel; i++) { - reptitionMap.put(i, 0); - } + int[] repetitions = new int[maxRepLevel + 2]; if (column.getParentColumn() != null) { boolean newBeginning = true; @@ -729,7 +693,7 @@ private void constructNestedRecords( // When a new record begins at lower repetition level, // we insert array into repeated column. if (repLevel < maxRepLevel) { - insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, repLevel); + insertRepeatedArray(column, rowIds, offsets, repetitions, total, repLevel); } } newBeginning = false; @@ -737,10 +701,7 @@ private void constructNestedRecords( // When definition level is less than max definition level, // there is a null value. if (defLevel < maxDefLevel) { - int offset = 0; - if (offsets.containsKey(maxRepLevel)) { - offset = offsets.get(maxRepLevel); - } + int offset = offsets[maxRepLevel]; // The null value is defined at the root level. // Insert a null record. @@ -767,7 +728,7 @@ private void constructNestedRecords( // optional int32 someId; // } // } - insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, repLevel); + insertRepeatedArray(column, rowIds, offsets, repetitions, total, repLevel); } else { // Obtain most outside column. ColumnVector topColumn = column.getParentColumn(); @@ -775,21 +736,22 @@ private void constructNestedRecords( topColumn = topColumn.getParentColumn(); } - insertNullRecord(topColumn, rowIds, reptitionMap); + insertNullRecord(topColumn, rowIds, repetitions); } // Move to next offset in max repetition level as we processed the current value. - offsets.put(maxRepLevel, offset + 1); + offsets[maxRepLevel]++; newBeginning = true; } else if (isLegacyArray(column) && column.getNearestParentArrayColumn().getDefLevel() == defLevel) { // For a legacy array, if a null is defined at the repeated group column, it actually // means an element with null value. - reptitionMap.put(maxRepLevel, reptitionMap.get(maxRepLevel) + 1); + repetitions[maxRepLevel]++; } else if (!column.getParentColumn().isArray() && column.getParentColumn().getDefLevel() == defLevel) { // A null element defined in the wrapping non-repeated group. - increaseRowId(rowIds, 1); + // increaseRowId(rowIds, 1); + rowIds[1]++; } else { // An empty element defined in outside group. // E.g., the element in the following schema. @@ -809,16 +771,16 @@ private void constructNestedRecords( // Found the group with the same definition level. // Insert a null record at definition level. // E.g, R=0, D=1 for above schema. 
- insertNullRecord(parent, rowIds, reptitionMap); - offsets.put(maxRepLevel, offset + 1); + insertNullRecord(parent, rowIds, repetitions); + offsets[maxRepLevel]++; newBeginning = true; } else { // Found the group with lower definition level. // Insert an empty record. // E.g, R=0, D=2 for above schema. parent = findHiddenInnerElementWithDefLevel(column, defLevel); - insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, repLevel); - offsets.put(maxRepLevel, offset + 1); + insertRepeatedArray(column, rowIds, offsets, repetitions, total, repLevel); + offsets[maxRepLevel]++; newBeginning = true; } } @@ -827,24 +789,16 @@ private void constructNestedRecords( // A new record begins with non-null value. if (maxRepLevel == 0) { // A required record at root level. - int repCount = 0; - if (reptitionMap.containsKey(1)) { - repCount = reptitionMap.get(1); - } - reptitionMap.put(1, repCount + 1); - insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, maxRepLevel - 1); + repetitions[1]++; + insertRepeatedArray(column, rowIds, offsets, repetitions, total, maxRepLevel - 1); } else { // Repeated values. We increase repetition count. - if (reptitionMap.containsKey(maxRepLevel)) { - reptitionMap.put(maxRepLevel, reptitionMap.get(maxRepLevel) + 1); - } else { - reptitionMap.put(maxRepLevel, 1); - } + repetitions[maxRepLevel]++; } } } if (!newBeginning) { - insertRepeatedArray(column, rowIds, offsets, reptitionMap, total, 0); + insertRepeatedArray(column, rowIds, offsets, repetitions, total, 0); } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 783bc6d5d113..ac5bc30786b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -1006,15 +1006,28 @@ public ColumnVector getParentColumn() { return this.parentColumn; } + /** + * The flag shows if the nearest parent column is initialized. + */ + private boolean isNearestParentArrayColumnInited = false; + + /** + * The nearest parent column which is an Array column. + */ + private ColumnVector nearestParentArrayColumn; + /** * Returns the nearest parent column which is an Array column. */ public ColumnVector getNearestParentArrayColumn() { - ColumnVector parentCol = this.parentColumn; - while (parentCol != null && !parentCol.isArray()) { - parentCol = parentCol.parentColumn; + if (!isNearestParentArrayColumnInited) { + nearestParentArrayColumn = this.parentColumn; + while (nearestParentArrayColumn != null && !nearestParentArrayColumn.isArray()) { + nearestParentArrayColumn = nearestParentArrayColumn.parentColumn; + } + isNearestParentArrayColumnInited = true; } - return parentCol; + return nearestParentArrayColumn; } /** From 17f3b82403deec965e230bf37dcb74795e96bc47 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 13 Jul 2016 17:19:33 +0800 Subject: [PATCH 14/17] Remove unused method. 
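The increaseRowId() helper became dead code after the previous commit replaced the Map-based bookkeeping with primitive int[] arrays, so callers can now bump the counter in place. A sketch of the simplification (for illustration; the actual removal is in the diff below):

    // Before: boxed lookup and put on a HashMap.
    //   increaseRowId(rowIds, 1);
    // After: direct increment on the primitive array.
    rowIds[1]++;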
--- .../parquet/VectorizedColumnReader.java | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index e4668c6597bb..3d0d7fd8a481 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -627,19 +627,6 @@ private boolean isLegacyArray(ColumnVector column) { return false; } - /** - * Increases row id for a specified repetition level. - * @param rowIds Mapping between repetition levels and their current row ids for constructing. - * @param level repetition level. - */ - private void increaseRowId(Map rowIds, int level) { - int rowId = 0; - if (rowIds.containsKey(level)) { - rowId = rowIds.get(level); - } - rowIds.put(level, rowId + 1); - } - /** * Inserts a null record at specified column. * @param column The ColumnVector which the data in the page are read into. @@ -750,7 +737,6 @@ private void constructNestedRecords( } else if (!column.getParentColumn().isArray() && column.getParentColumn().getDefLevel() == defLevel) { // A null element defined in the wrapping non-repeated group. - // increaseRowId(rowIds, 1); rowIds[1]++; } else { // An empty element defined in outside group. From 1788d4c3fb9d547390cdea2bcf28c597bee540d2 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 15 Jul 2016 11:35:57 +0800 Subject: [PATCH 15/17] Repetition level encoding may be split across pages, so constructing complex columns should handle that case. --- .../parquet/VectorizedColumnReader.java | 400 +++++++++++------- 1 file changed, 258 insertions(+), 142 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 3d0d7fd8a481..642ce2af4d49 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -82,9 +82,21 @@ public class VectorizedColumnReader { private SpecificParquetRecordReaderBase.IntIterator definitionLevelColumn; private ValuesReader dataColumn; + /** + * Repetition level values. Only populated when reading a repeated column. + */ + private int[] repLevelValues; + + /** + * The offset used to access `repLevelValues`. + */ + private int repLevelOffset; + // Only set if vectorized decoding is true. This is used instead of the row by row decoding // with `definitionLevelColumn`. private VectorizedRleValuesReader defColumn; + + // Only set when reading a complex column. private VectorizedRleValuesReader defColumnCopy; /** @@ -124,55 +136,88 @@ public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader throw new IOException("totalValueCount == 0"); } } - + /** * Reads `total` values from this columnReader into column.
*/ public void readBatch(int total, ColumnVector column) throws IOException { - boolean isNestedColumn = column.getParentColumn() != null; + boolean isComplexColumn = column.getParentColumn() != null; boolean isRepeatedColumn = maxRepLevel > 0; int rowId = 0; - int valuesReadInPage = 0; + int valuesReadInPreviousRun = 0; int repeatedRowId = 0; + int remaining = total; + + // The number of values to read. + int num = 0; + // Stores row ids and offsets used while constructing nested records. int[] rowIds = new int[maxRepLevel + 2]; int[] offsets = new int[maxRepLevel + 2]; + // Keeps repetition levels and corresponding repetition counts. + int[] repetitions = new int[maxRepLevel + 2]; + + // The flag used in constructing nested records. When it is true, the previous status + // will be reset. + boolean resetNestedRecord = true; + while (true) { // Compute the number of values we want to read in this page. int leftInPage = (int) (endOfPageValueCount - valuesRead); - // When we reach the end of this page, we construct nested records of this column - // and then read next page. - if (leftInPage == 0) { - // Constructs nested records if needed. - if (valuesReadInPage > 0 && isNestedColumn) { - constructNestedRecords(column, rowIds, offsets, valuesReadInPage, total); - repeatedRowId = rowIds[1]; - valuesReadInPage = 0; - } + + // Constructs nested/repeated records if needed. + if (isComplexColumn) { + boolean endOfPage = leftInPage == 0; + resetNestedRecord = constructNestedRecords(column, repetitions, rowIds, offsets, + valuesReadInPreviousRun, total, resetNestedRecord, endOfPage); + repeatedRowId = rowIds[1]; } + // Stop condition: // If we are going to read data in repeated column, the stop condition is that we // read `total` repeated columns. Eg., if we want to read 5 records of an array of int column. // we can't just read 5 integers. Instead, we have to read the integers until 5 arrays are put // into this array column. if (isRepeatedColumn) { - if (repeatedRowId >= total) break; + if (repeatedRowId == total) break; } else { - if (total <= 0) break; + if (remaining == 0) break; } + // Reaching the end of current page. if (leftInPage == 0) { boolean pageExists = readPage(); + if (isRepeatedColumn) { + // Repetition level encoding may be split across two pages. + // We need to check whether that has happened. + if (this.repLevelValues[0] == 0 || !pageExists) { + // No split of repetition level encoding; insert the last array, if any, that we skipped + // at the end of the last page. + if (!resetNestedRecord) { + insertRepeatedArray(column, rowIds, offsets, repetitions, total, 0); + resetNestedRecord = true; + repeatedRowId = rowIds[1]; + if (repeatedRowId == total) break; + } + } + } if (!pageExists) { - return; + // Should not reach here. + throw new IOException("Failed to read page. No page exists anymore!"); } leftInPage = (int) (endOfPageValueCount - valuesRead); } - int num = Math.min(total, leftInPage); + + if (isRepeatedColumn) { + num = getValueCountToReadForRepeatedRowNums(total - repeatedRowId); + } else { + num = Math.min(remaining, leftInPage); + } + if (useDictionary) { // Read and decode dictionary ids.
- int dictionaryCapacity = Math.max(total, rowId + num); + int dictionaryCapacity = Math.max(remaining, rowId + num); ColumnVector dictionaryIds = column.reserveDictionaryIds(dictionaryCapacity); defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); @@ -197,6 +242,9 @@ public void readBatch(int total, ColumnVector column) throws IOException { decodeDictionaryIds(0, rowId, column, column.getDictionaryIds()); } column.setDictionary(null); + if (isRepeatedColumn) { + column.reserve(Math.max(remaining, rowId + num)); + } switch (descriptor.getType()) { case BOOLEAN: readBooleanBatch(rowId, num, column); @@ -226,13 +274,10 @@ public void readBatch(int total, ColumnVector column) throws IOException { throw new IOException("Unsupported type: " + descriptor.getType()); } } - - valuesReadInPage += num; + valuesReadInPreviousRun = num; valuesRead += num; rowId += num; - if (!isRepeatedColumn) { - total -= num; - } + remaining -= num; } } @@ -651,149 +696,201 @@ private void insertNullRecord( rowIds[repLevel]++; } + /** + * Returns the array of repetition level values. + */ + private int[] getRepetitionLevels() throws IOException { + int[] repetitions = new int[this.pageValueCount]; + for (int i = 0; i < this.pageValueCount; i++) { + repetitions[i] = this.repetitionLevelColumn.nextInt(); + } + return repetitions; + } + + /** + * Returns the number of repeated rows in the current page. + */ + private int getRepeatedRowNumsForCurrentPage() { + if (this.repLevelValues != null) { + int rowNums = 0; + for (int i = 0; i < this.repLevelValues.length; i++) { + if (repLevelValues[i] == 0) rowNums++; + } + return rowNums; + } else { + return 0; + } + } + + /** + * Returns the number of values needed to read in order to have the number of repeated rows. + * @param num the number of repeated rows. + */ + private int getValueCountToReadForRepeatedRowNums(int num) { + int rowNum = 0; + if (this.repLevelValues[this.repLevelOffset] != 0) rowNum++; + for (int i = this.repLevelOffset; i < this.repLevelValues.length; i++) { + if (repLevelValues[i] == 0) { + rowNum++; + if (rowNum > num) { + return i - this.repLevelOffset; + } + } + } + return this.repLevelValues.length - this.repLevelOffset; + } + /** * Iterates the values of definition and repetition levels for the values read in the page, * and constructs nested records accordingly. * @param column The ColumnVector which the data in the page are read into. + * @param repetitions Mapping between repetition levels and their counts. * @param rowIds Mapping between repetition levels and their current row ids for constructing. * @param offsets The beginning offsets in columns which we use to construct nested records. - * @param valuesReadInPage The number of values read in the current page. + * @param num The number of values read. * @param total The total number of rows to construct. + * @param resetNestedRecord When it is true, the previous status will be reset. + * @param endOfPage Whether reaching the end of current page. + * @return the updated resetNestedRecord flag. */ - private void constructNestedRecords( + private boolean constructNestedRecords( ColumnVector column, + int[] repetitions, int[] rowIds, int[] offsets, - int valuesReadInPage, - int total) throws IOException { - // Keeps repetition levels and corresponding repetition counts.
- int[] repetitions = new int[maxRepLevel + 2]; - - if (column.getParentColumn() != null) { - boolean newBeginning = true; - - for (int i = 0; i < valuesReadInPage; i++) { - int repLevel = repetitionLevelColumn.nextInt(); - int defLevel = definitionLevelColumn.nextInt(); - - if (!newBeginning) { - // When a new record begins at lower repetition level, - // we insert array into repeated column. - if (repLevel < maxRepLevel) { - insertRepeatedArray(column, rowIds, offsets, repetitions, total, repLevel); - } + int num, + int total, + boolean resetNestedRecord, + boolean endOfPage) throws IOException { + for (int i = 0; i < num; i++) { + int repLevel; + if (repLevelValues != null) { + repLevel = repLevelValues[repLevelOffset++]; + } else { + repLevel = maxRepLevel; + } + int defLevel = definitionLevelColumn.nextInt(); + + // If there are previous values and counts needed to be consider. + if (!resetNestedRecord) { + // When a new record begins at lower repetition level, + // we insert array into repeated column. + if (repLevel < maxRepLevel) { + insertRepeatedArray(column, rowIds, offsets, repetitions, total, repLevel); } - newBeginning = false; - - // When definition level is less than max definition level, - // there is a null value. - if (defLevel < maxDefLevel) { - int offset = offsets[maxRepLevel]; - - // The null value is defined at the root level. - // Insert a null record. - if (repLevel == 0 && defLevel == 0) { - ColumnVector parent = column.getParentColumn(); - if (parent != null && parent.getDefLevel() == maxDefLevel - && parent.getRepLevel() == maxRepLevel) { - // A repeated element at root level. - // E.g., The repeatedPrimitive at the following schema. - // Going to insert an empty record. - // messageType: message spark_schema { - // optional int32 optionalPrimitive; - // required int32 requiredPrimitive; - // - // repeated int32 repeatedPrimitive; - // - // optional group optionalMessage { - // optional int32 someId; - // } - // required group requiredMessage { - // optional int32 someId; - // } - // repeated group repeatedMessage { - // optional int32 someId; - // } - // } - insertRepeatedArray(column, rowIds, offsets, repetitions, total, repLevel); - } else { - // Obtain most outside column. - ColumnVector topColumn = column.getParentColumn(); - while (topColumn.getParentColumn() != null) { - topColumn = topColumn.getParentColumn(); - } - - insertNullRecord(topColumn, rowIds, repetitions); - } - // Move to next offset in max repetition level as we processed the current value. - offsets[maxRepLevel]++; - newBeginning = true; - } else if (isLegacyArray(column) && - column.getNearestParentArrayColumn().getDefLevel() == defLevel) { - // For a legacy array, if a null is defined at the repeated group column, it actually - // means an element with null value. - - repetitions[maxRepLevel]++; - } else if (!column.getParentColumn().isArray() && - column.getParentColumn().getDefLevel() == defLevel) { - // A null element defined in the wrapping non-repeated group. - rowIds[1]++; - } else { - // An empty element defined in outside group. - // E.g., the element in the following schema. + } + resetNestedRecord = false; + + // When definition level is less than max definition level, + // there is a null value. + if (defLevel < maxDefLevel) { + int offset = offsets[maxRepLevel]; + + // The null value is defined at the root level. + // Insert a null record. 
+ if (repLevel == 0 && defLevel == 0) { + ColumnVector parent = column.getParentColumn(); + if (parent != null && parent.getDefLevel() == maxDefLevel + && parent.getRepLevel() == maxRepLevel) { + // A repeated element at root level. + // E.g., The repeatedPrimitive at the following schema. + // Going to insert an empty record. // messageType: message spark_schema { - // required int32 index; - // optional group col { - // optional float f1; - // optional group f2 (LIST) { - // repeated group list { - // optional boolean element; - // } - // } + // optional int32 optionalPrimitive; + // required int32 requiredPrimitive; + // + // repeated int32 repeatedPrimitive; + // + // optional group optionalMessage { + // optional int32 someId; + // } + // required group requiredMessage { + // optional int32 someId; + // } + // repeated group repeatedMessage { + // optional int32 someId; // } // } - ColumnVector parent = findInnerElementWithDefLevel(column, defLevel); - if (parent != null) { - // Found the group with the same definition level. - // Insert a null record at definition level. - // E.g, R=0, D=1 for above schema. - insertNullRecord(parent, rowIds, repetitions); - offsets[maxRepLevel]++; - newBeginning = true; - } else { - // Found the group with lower definition level. - // Insert an empty record. - // E.g, R=0, D=2 for above schema. - parent = findHiddenInnerElementWithDefLevel(column, defLevel); - insertRepeatedArray(column, rowIds, offsets, repetitions, total, repLevel); - offsets[maxRepLevel]++; - newBeginning = true; + insertRepeatedArray(column, rowIds, offsets, repetitions, total, repLevel); + } else { + // Obtain most outside column. + ColumnVector topColumn = column.getParentColumn(); + while (topColumn.getParentColumn() != null) { + topColumn = topColumn.getParentColumn(); } + + insertNullRecord(topColumn, rowIds, repetitions); } + // Move to next offset in max repetition level as we processed the current value. + offsets[maxRepLevel]++; + resetNestedRecord = true; + } else if (isLegacyArray(column) && + column.getNearestParentArrayColumn().getDefLevel() == defLevel) { + // For a legacy array, if a null is defined at the repeated group column, it actually + // means an element with null value. + + repetitions[maxRepLevel]++; + } else if (!column.getParentColumn().isArray() && + column.getParentColumn().getDefLevel() == defLevel) { + // A null element defined in the wrapping non-repeated group. + rowIds[1]++; } else { - // Determine the repetition level of non-null values. - // A new record begins with non-null value. - if (maxRepLevel == 0) { - // A required record at root level. - repetitions[1]++; - insertRepeatedArray(column, rowIds, offsets, repetitions, total, maxRepLevel - 1); + // An empty element defined in outside group. + // E.g., the element in the following schema. + // messageType: message spark_schema { + // required int32 index; + // optional group col { + // optional float f1; + // optional group f2 (LIST) { + // repeated group list { + // optional boolean element; + // } + // } + // } + // } + ColumnVector parent = findInnerElementWithDefLevel(column, defLevel); + if (parent != null) { + // Found the group with the same definition level. + // Insert a null record at definition level. + // E.g, R=0, D=1 for above schema. + insertNullRecord(parent, rowIds, repetitions); + offsets[maxRepLevel]++; + resetNestedRecord = true; } else { - // Repeated values. We increase repetition count. - repetitions[maxRepLevel]++; + // Found the group with lower definition level. 
+ // Insert an empty record. + // E.g, R=0, D=2 for above schema. + parent = findHiddenInnerElementWithDefLevel(column, defLevel); + insertRepeatedArray(column, rowIds, offsets, repetitions, total, repLevel); + offsets[maxRepLevel]++; + resetNestedRecord = true; } } + } else { + // Determine the repetition level of non-null values. + // A new record begins with non-null value. + if (maxRepLevel == 0) { + // A required record at root level. + repetitions[1]++; + insertRepeatedArray(column, rowIds, offsets, repetitions, total, maxRepLevel - 1); + } else { + // Repeated values. We increase repetition count. + repetitions[maxRepLevel]++; + } } - if (!newBeginning) { - insertRepeatedArray(column, rowIds, offsets, repetitions, total, 0); - } } + // Insert the last repeated record if any. + if (!endOfPage && !resetNestedRecord) { + insertRepeatedArray(column, rowIds, offsets, repetitions, total, 0); + resetNestedRecord = true; + } + return resetNestedRecord; } private void readPageV1(DataPageV1 page) throws IOException { this.pageValueCount = page.getValueCount(); ValuesReader rlReader = page.getRlEncoding().getValuesReader(descriptor, REPETITION_LEVEL); ValuesReader dlReader; - ValuesReader dlReaderCopy; // Initialize the decoders. if (page.getDlEncoding() != Encoding.RLE && descriptor.getMaxDefinitionLevel() != 0) { @@ -801,17 +898,27 @@ private void readPageV1(DataPageV1 page) throws IOException { } int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); this.defColumn = new VectorizedRleValuesReader(bitWidth); - this.defColumnCopy = new VectorizedRleValuesReader(bitWidth); dlReader = this.defColumn; - dlReaderCopy = this.defColumnCopy; this.repetitionLevelColumn = new ValuesReaderIntIterator(rlReader); - this.definitionLevelColumn = new ValuesReaderIntIterator(dlReaderCopy); try { byte[] bytes = page.getBytes().toByteArray(); rlReader.initFromPage(pageValueCount, bytes, 0); int next = rlReader.getNextOffset(); dlReader.initFromPage(pageValueCount, bytes, next); + + ValuesReader dlReaderCopy; + this.defColumnCopy = new VectorizedRleValuesReader(bitWidth); + dlReaderCopy = this.defColumnCopy; + this.definitionLevelColumn = new ValuesReaderIntIterator(dlReaderCopy); dlReaderCopy.initFromPage(pageValueCount, bytes, next); + + // If this is a repeated column, read repetition level values for this page. + if (maxRepLevel > 0) { + this.repLevelValues = getRepetitionLevels(); + int numOfRepeatedRow = getRepeatedRowNumsForCurrentPage(); + this.repLevelOffset = 0; + } + next = dlReader.getNextOffset(); initDataReader(page.getValueEncoding(), bytes, next); } catch (IOException e) { @@ -826,12 +933,21 @@ private void readPageV2(DataPageV2 page) throws IOException { int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); this.defColumn = new VectorizedRleValuesReader(bitWidth); - this.defColumnCopy = new VectorizedRleValuesReader(bitWidth); - this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumnCopy); this.defColumn.initFromBuffer( this.pageValueCount, page.getDefinitionLevels().toByteArray()); + + this.defColumnCopy = new VectorizedRleValuesReader(bitWidth); + this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumnCopy); this.defColumnCopy.initFromBuffer( - this.pageValueCount, page.getDefinitionLevels().toByteArray()); + this.pageValueCount, page.getDefinitionLevels().toByteArray()); + + // If this is a repeated column, read repetition level values for this page. 
+ if (maxRepLevel > 0) { + this.repLevelValues = getRepetitionLevels(); + int numOfRepeatedRow = getRepeatedRowNumsForCurrentPage(); + this.repLevelOffset = 0; + } + try { initDataReader(page.getDataEncoding(), page.getData().toByteArray(), 0); } catch (IOException e) { From 3b8c3ce36c1fcb2f201fe3588e2cbd3a4b32e7b4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 15 Jul 2016 15:01:08 +0800 Subject: [PATCH 16/17] Fix a bug. --- .../execution/datasources/parquet/VectorizedColumnReader.java | 1 + .../datasources/parquet/VectorizedRleValuesReader.java | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 642ce2af4d49..bc844537e3af 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -878,6 +878,7 @@ private boolean constructNestedRecords( repetitions[maxRepLevel]++; } } + if (rowIds[1] == total) return resetNestedRecord; } // Insert the last repeated record if any. if (!endOfPage && !resetNestedRecord) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 6413e4e595c8..62157389013b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -498,7 +498,7 @@ public void readLongs(int total, ColumnVector c, int rowId) { public void readBinary(int total, ColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); } - + @Override public void readBooleans(int total, ColumnVector c, int rowId) { throw new UnsupportedOperationException("only readInts is valid."); From cc35cabac105b3778c26afc22ac4f4ca1b295585 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 20 Jul 2016 11:47:24 +0800 Subject: [PATCH 17/17] Improve the algorithm. --- .../parquet/VectorizedColumnReader.java | 168 +++++------------- 1 file changed, 49 insertions(+), 119 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index bc844537e3af..08091dd8ab39 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -82,16 +82,6 @@ public class VectorizedColumnReader { private SpecificParquetRecordReaderBase.IntIterator definitionLevelColumn; private ValuesReader dataColumn; - /** - * Repetition level values. Only poplulated when reading a repeated column. - */ - private int[] repLevelValues; - - /** - * The offset used to access `repLevelValues`. - */ - private int repLevelOffset; - // Only set if vectorized decoding is true. This is used instead of the row by row decoding // with `definitionLevelColumn`. 
private VectorizedRleValuesReader defColumn; @@ -136,15 +126,24 @@ public VectorizedColumnReader(ColumnDescriptor descriptor, PageReader pageReader throw new IOException("totalValueCount == 0"); } } + /** + * Whether this column is an element of a complex column. + */ + boolean asComplexColElement; + + /** + * The flag used in constructing nested records. When it is true, the previous status + * will be reset. + */ + boolean resetNestedRecord = true; /** * Reads `total` values from this columnReader into column. */ public void readBatch(int total, ColumnVector column) throws IOException { - boolean isComplexColumn = column.getParentColumn() != null; + asComplexColElement = column.getParentColumn() != null; boolean isRepeatedColumn = maxRepLevel > 0; int rowId = 0; - int valuesReadInPreviousRun = 0; int repeatedRowId = 0; int remaining = total; @@ -158,22 +157,10 @@ public void readBatch(int total, ColumnVector column) throws IOException { // Keeps repetition levels and corresponding repetition counts. int[] repetitions = new int[maxRepLevel + 2]; - // The flag used in constructing nested records. When it is true, the previous status - // will be reset. - boolean resetNestedRecord = true; - while (true) { // Compute the number of values we want to read in this page. int leftInPage = (int) (endOfPageValueCount - valuesRead); - // Constructs nested/repeated records if needed. - if (isComplexColumn) { - boolean endOfPage = leftInPage == 0; - resetNestedRecord = constructNestedRecords(column, repetitions, rowIds, offsets, - valuesReadInPreviousRun, total, resetNestedRecord, endOfPage); - repeatedRowId = rowIds[1]; - } - // Stop condition: // If we are going to read data in repeated column, the stop condition is that we // read `total` repeated columns. Eg., if we want to read 5 records of an array of int column. @@ -188,30 +175,29 @@ public void readBatch(int total, ColumnVector column) throws IOException { // Reaching the end of current page. if (leftInPage == 0) { boolean pageExists = readPage(); - if (isRepeatedColumn) { - // Repetition level encoding may be split across two pages. - // We need to check whether that has happened. - if (this.repLevelValues[0] == 0 || !pageExists) { - // No split of repetition level encoding; insert the last array, if any, that we skipped - // at the end of the last page. - if (!resetNestedRecord) { - insertRepeatedArray(column, rowIds, offsets, repetitions, total, 0); - resetNestedRecord = true; - repeatedRowId = rowIds[1]; - if (repeatedRowId == total) break; - } - } - } if (!pageExists) { + if (!resetNestedRecord) { + insertRepeatedArray(column, rowIds, offsets, repetitions, total, 0); + resetNestedRecord = true; + repeatedRowId = rowIds[1]; + if (repeatedRowId == total) break; + } // Should not reach here. throw new IOException("Failed to read page. No page exists anymore!"); } leftInPage = (int) (endOfPageValueCount - valuesRead); } - if (isRepeatedColumn) { - num = getValueCountToReadForRepeatedRowNums(total - repeatedRowId); + // Determine the number of values to read for this column in the current page. + if (asComplexColElement) { + // Using repetition and definition level encodings to construct nested/repeated records. + // When constructing nested/repeated records, we return the number of values to read in + // this page for this column.
+ num = constructComplexRecords(column, repetitions, rowIds, offsets, leftInPage, total); + repeatedRowId = rowIds[1]; } else { + // If this column is not a repeated/nested column, just read the minimum of the remaining + // values and the values left in the current page. num = Math.min(remaining, leftInPage); } @@ -274,7 +260,6 @@ public void readBatch(int total, ColumnVector column) throws IOException { throw new IOException("Unsupported type: " + descriptor.getType()); } } - valuesReadInPreviousRun = num; valuesRead += num; rowId += num; remaining -= num; @@ -707,68 +692,26 @@ private int[] getRepetitionLevels() throws IOException { return repetitions; } - /** - * Returns the number of repeated rows in the current page. - */ - private int getRepeatedRowNumsForCurrentPage() { - if (this.repLevelValues != null) { - int rowNums = 0; - for (int i = 0; i < this.repLevelValues.length; i++) { - if (repLevelValues[i] == 0) rowNums++; - } - return rowNums; - } else { - return 0; - } - } - - /** - * Returns the number of values needed to read in order to have the number of repeated rows. - * @param num the number of repeated rows. - */ - private int getValueCountToReadForRepeatedRowNums(int num) { - int rowNum = 0; - if (this.repLevelValues[this.repLevelOffset] != 0) rowNum++; - for (int i = this.repLevelOffset; i < this.repLevelValues.length; i++) { - if (repLevelValues[i] == 0) { - rowNum++; - if (rowNum > num) { - return i - this.repLevelOffset; - } - } - } - return this.repLevelValues.length - this.repLevelOffset; - } - /** * Iterates the values of definition and repetition levels for the values read in the page, - * and constructs nested records accordingly. + * and constructs complex records accordingly. * @param column The ColumnVector which the data in the page are read into. * @param repetitions Mapping between repetition levels and their counts. * @param rowIds Mapping between repetition levels and their current row ids for constructing. * @param offsets The beginning offsets in columns which we use to construct nested records. - * @param num The number of values read. + * @param leftInPage The number of values that can be read in the current page. * @param total The total number of rows to construct. - * @param resetNestedRecord When it is true, the previous status will be reset. - * @param endOfPage Whether reaching the end of current page. - * @return the updated resetNestedRecord flag. + * @return the number of values that need to be read in the current page. */ - private boolean constructNestedRecords( + private int constructComplexRecords( ColumnVector column, int[] repetitions, int[] rowIds, int[] offsets, - int num, - int total, - boolean resetNestedRecord, - boolean endOfPage) throws IOException { - for (int i = 0; i < num; i++) { - int repLevel; - if (repLevelValues != null) { - repLevel = repLevelValues[repLevelOffset++]; - } else { - repLevel = maxRepLevel; - } + int leftInPage, + int total) throws IOException { + for (int i = 0; i < leftInPage; i++) { + int repLevel = repetitionLevelColumn.nextInt(); int defLevel = definitionLevelColumn.nextInt(); // If there are previous values and counts needed to be consider. @@ -878,6 +821,11 @@ private boolean constructNestedRecords( repetitions[maxRepLevel]++; } } - if (rowIds[1] == total) return resetNestedRecord; + // If we have constructed `total` records, return the number of values to read. + if (rowIds[1] == total) return i + 1; } - // Insert the last repeated record if any.
- if (!endOfPage && !resetNestedRecord) { - insertRepeatedArray(column, rowIds, offsets, repetitions, total, 0); - resetNestedRecord = true; - } - return resetNestedRecord; + // All `leftInPage` values in the current page need to be read. + return leftInPage; } private void readPageV1(DataPageV1 page) throws IOException { @@ -907,17 +847,12 @@ private void readPageV1(DataPageV1 page) throws IOException { int next = rlReader.getNextOffset(); dlReader.initFromPage(pageValueCount, bytes, next); - ValuesReader dlReaderCopy; - this.defColumnCopy = new VectorizedRleValuesReader(bitWidth); - dlReaderCopy = this.defColumnCopy; - this.definitionLevelColumn = new ValuesReaderIntIterator(dlReaderCopy); - dlReaderCopy.initFromPage(pageValueCount, bytes, next); - - // If this is a repeated column, read repetition level values for this page. - if (maxRepLevel > 0) { - this.repLevelValues = getRepetitionLevels(); - int numOfRepeatedRow = getRepeatedRowNumsForCurrentPage(); - this.repLevelOffset = 0; + if (asComplexColElement) { + ValuesReader dlReaderCopy; + this.defColumnCopy = new VectorizedRleValuesReader(bitWidth); + dlReaderCopy = this.defColumnCopy; + this.definitionLevelColumn = new ValuesReaderIntIterator(dlReaderCopy); + dlReaderCopy.initFromPage(pageValueCount, bytes, next); } next = dlReader.getNextOffset(); @@ -937,16 +872,11 @@ private void readPageV2(DataPageV2 page) throws IOException { this.defColumn.initFromBuffer( this.pageValueCount, page.getDefinitionLevels().toByteArray()); - this.defColumnCopy = new VectorizedRleValuesReader(bitWidth); - this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumnCopy); - this.defColumnCopy.initFromBuffer( - this.pageValueCount, page.getDefinitionLevels().toByteArray()); - - // If this is a repeated column, read repetition level values for this page. - if (maxRepLevel > 0) { - this.repLevelValues = getRepetitionLevels(); - int numOfRepeatedRow = getRepeatedRowNumsForCurrentPage(); - this.repLevelOffset = 0; + if (asComplexColElement) { + this.defColumnCopy = new VectorizedRleValuesReader(bitWidth); + this.definitionLevelColumn = new ValuesReaderIntIterator(this.defColumnCopy); + this.defColumnCopy.initFromBuffer( + this.pageValueCount, page.getDefinitionLevels().toByteArray()); } try {