From a4e8048b0b1a7f6e7818bc4c0fa1a17604d2c239 Mon Sep 17 00:00:00 2001 From: Johan Lasperas Date: Fri, 15 Dec 2023 10:54:59 +0100 Subject: [PATCH 1/6] Type promotion in Parquet readers --- .../parquet/ParquetVectorUpdaterFactory.java | 241 ++++++++++++++++++ .../parquet/VectorizedColumnReader.java | 44 +++- .../parquet/ParquetRowConverter.scala | 22 +- .../datasources/parquet/ParquetIOSuite.scala | 11 - .../parquet/ParquetQuerySuite.scala | 12 +- .../parquet/ParquetTypeWideningSuite.scala | 215 ++++++++++++++++ 6 files changed, 513 insertions(+), 32 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java index 1ed6c4329ebd0..0ece8fefac09a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java @@ -81,6 +81,10 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa // For unsigned int32, it stores as plain signed int32 in Parquet when dictionary // fallbacks. We read them as long values. return new UnsignedIntegerUpdater(); + } else if (sparkType == DataTypes.LongType || canReadAsLongDecimal(descriptor, sparkType)) { + return new IntegerToLongUpdater(); + } else if (canReadAsBinaryDecimal(descriptor, sparkType)) { + return new IntegerToBinaryUpdater(); } else if (sparkType == DataTypes.ByteType) { return new ByteUpdater(); } else if (sparkType == DataTypes.ShortType) { @@ -92,6 +96,13 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode); return new IntegerWithRebaseUpdater(failIfRebase); } + } else if (sparkType == DataTypes.TimestampNTZType) { + if ("CORRECTED".equals(datetimeRebaseMode)) { + return new DateToTimestampNTZUpdater(); + } else { + boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode); + return new DateToTimestampNTZWithRebaseUpdater(failIfRebase); + } } else if (sparkType instanceof YearMonthIntervalType) { return new IntegerUpdater(); } @@ -104,6 +115,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa } else { return new LongUpdater(); } + } else if (canReadAsBinaryDecimal(descriptor, sparkType)) { + return new LongToBinaryUpdater(); } else if (isLongDecimal(sparkType) && isUnsignedIntTypeMatched(64)) { // In `ParquetToSparkSchemaConverter`, we map parquet UINT64 to our Decimal(20, 0). 
// For unsigned int64, it stores as plain signed int64 in Parquet when dictionary @@ -134,6 +147,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa case FLOAT -> { if (sparkType == DataTypes.FloatType) { return new FloatUpdater(); + } else if (sparkType == DataTypes.DoubleType) { + return new FloatToDoubleUpdater(); } } case DOUBLE -> { @@ -281,6 +296,121 @@ public void decodeSingleDictionaryId( } } + static class IntegerToLongUpdater implements ParquetVectorUpdater { + @Override + public void readValues( + int total, + int offset, + WritableColumnVector values, + VectorizedValuesReader valuesReader) { + for (int i = 0; i < total; ++i) { + values.putLong(offset + i, valuesReader.readInteger()); + } + } + + @Override + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipIntegers(total); + } + + @Override + public void readValue( + int offset, + WritableColumnVector values, + VectorizedValuesReader valuesReader) { + values.putLong(offset, valuesReader.readInteger()); + } + + @Override + public void decodeSingleDictionaryId( + int offset, + WritableColumnVector values, + WritableColumnVector dictionaryIds, + Dictionary dictionary) { + values.putLong(offset, dictionary.decodeToInt(dictionaryIds.getDictId(offset))); + } + } + + static class DateToTimestampNTZUpdater implements ParquetVectorUpdater { + @Override + public void readValues( + int total, + int offset, + WritableColumnVector values, + VectorizedValuesReader valuesReader) { + for (int i = 0; i < total; ++i) { + values.putLong(offset + i, DateTimeUtils.daysToMicros(valuesReader.readInteger(), ZoneOffset.UTC)); + } + } + + @Override + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipIntegers(total); + } + + @Override + public void readValue( + int offset, + WritableColumnVector values, + VectorizedValuesReader valuesReader) { + values.putLong(offset, DateTimeUtils.daysToMicros(valuesReader.readInteger(), ZoneOffset.UTC)); + } + + @Override + public void decodeSingleDictionaryId( + int offset, + WritableColumnVector values, + WritableColumnVector dictionaryIds, + Dictionary dictionary) { + int days = dictionary.decodeToInt(dictionaryIds.getDictId(offset)); + values.putLong(offset, DateTimeUtils.daysToMicros(days, ZoneOffset.UTC)); + } + } + + private static class DateToTimestampNTZWithRebaseUpdater implements ParquetVectorUpdater { + private final boolean failIfRebase; + + DateToTimestampNTZWithRebaseUpdater(boolean failIfRebase) { + this.failIfRebase = failIfRebase; + } + + @Override + public void readValues( + int total, + int offset, + WritableColumnVector values, + VectorizedValuesReader valuesReader) { + for (int i = 0; i < total; ++i) { + int rebasedDays = rebaseDays(valuesReader.readInteger(), failIfRebase); + values.putLong(offset + i, DateTimeUtils.daysToMicros(rebasedDays, ZoneOffset.UTC)); + } + } + + @Override + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipIntegers(total); + } + + @Override + public void readValue( + int offset, + WritableColumnVector values, + VectorizedValuesReader valuesReader) { + int rebasedDays = rebaseDays(valuesReader.readInteger(), failIfRebase); + values.putLong(offset, DateTimeUtils.daysToMicros(rebasedDays, ZoneOffset.UTC)); + } + + @Override + public void decodeSingleDictionaryId( + int offset, + WritableColumnVector values, + WritableColumnVector dictionaryIds, + Dictionary dictionary) { + int rebasedDays = 
rebaseDays(dictionary.decodeToInt(dictionaryIds.getDictId(offset)), failIfRebase); + values.putLong(offset, DateTimeUtils.daysToMicros(rebasedDays, ZoneOffset.UTC)); + } + } + private static class UnsignedIntegerUpdater implements ParquetVectorUpdater { @Override public void readValues( @@ -684,6 +814,41 @@ public void decodeSingleDictionaryId( } } + static class FloatToDoubleUpdater implements ParquetVectorUpdater { + @Override + public void readValues( + int total, + int offset, + WritableColumnVector values, + VectorizedValuesReader valuesReader) { + for (int i = 0; i < total; ++i) { + values.putDouble(offset + i, valuesReader.readFloat()); + } + } + + @Override + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFloats(total); + } + + @Override + public void readValue( + int offset, + WritableColumnVector values, + VectorizedValuesReader valuesReader) { + values.putDouble(offset, valuesReader.readFloat()); + } + + @Override + public void decodeSingleDictionaryId( + int offset, + WritableColumnVector values, + WritableColumnVector dictionaryIds, + Dictionary dictionary) { + values.putDouble(offset, dictionary.decodeToFloat(dictionaryIds.getDictId(offset))); + } + } + private static class DoubleUpdater implements ParquetVectorUpdater { @Override public void readValues( @@ -751,6 +916,82 @@ public void decodeSingleDictionaryId( } } + private static class IntegerToBinaryUpdater implements ParquetVectorUpdater { + + @Override + public void readValues( + int total, + int offset, + WritableColumnVector values, + VectorizedValuesReader valuesReader) { + for (int i = 0; i < total; i++) { + readValue(offset + i, values, valuesReader); + } + } + + @Override + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipIntegers(total); + } + + @Override + public void readValue( + int offset, + WritableColumnVector values, + VectorizedValuesReader valuesReader) { + BigInteger value = BigInteger.valueOf(valuesReader.readInteger()); + values.putByteArray(offset, value.toByteArray()); + } + + @Override + public void decodeSingleDictionaryId( + int offset, + WritableColumnVector values, + WritableColumnVector dictionaryIds, + Dictionary dictionary) { + BigInteger value = BigInteger.valueOf(dictionary.decodeToInt(dictionaryIds.getDictId(offset))); + values.putByteArray(offset, value.toByteArray()); + } + } + + private static class LongToBinaryUpdater implements ParquetVectorUpdater { + + @Override + public void readValues( + int total, + int offset, + WritableColumnVector values, + VectorizedValuesReader valuesReader) { + for (int i = 0; i < total; i++) { + readValue(offset + i, values, valuesReader); + } + } + + @Override + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipLongs(total); + } + + @Override + public void readValue( + int offset, + WritableColumnVector values, + VectorizedValuesReader valuesReader) { + BigInteger value = BigInteger.valueOf(valuesReader.readLong()); + values.putByteArray(offset, value.toByteArray()); + } + + @Override + public void decodeSingleDictionaryId( + int offset, + WritableColumnVector values, + WritableColumnVector dictionaryIds, + Dictionary dictionary) { + BigInteger value = BigInteger.valueOf(dictionary.decodeToLong(dictionaryIds.getDictId(offset))); + values.putByteArray(offset, value.toByteArray()); + } + } + private static class BinaryToSQLTimestampUpdater implements ParquetVectorUpdater { @Override public void readValues( diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 04fbe716ad92f..33b9412c37663 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -38,10 +38,13 @@ import org.apache.parquet.schema.PrimitiveType; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.DecimalType; import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BOOLEAN; import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64; +import static org.apache.spark.sql.types.DataTypes.*; /** * Decoder to return values from a single column. @@ -140,23 +143,42 @@ public VectorizedColumnReader( this.writerVersion = writerVersion; } - private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName) { + private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName, + DataType sparkType) { + // Don't use lazy dictionary decoding if the column needs extra processing: upcasting or date / + // decimal scale rebasing. return switch (typeName) { - case INT32 -> - !(logicalTypeAnnotation instanceof DateLogicalTypeAnnotation) || - "CORRECTED".equals(datetimeRebaseMode); + case INT32 -> { + boolean needsUpcast = sparkType == LongType || sparkType == TimestampNTZType || + !DecimalType.is32BitDecimalType(sparkType); + boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation && !"CORRECTED".equals(datetimeRebaseMode); + yield !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType); + } case INT64 -> { - if (updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS)) { - yield "CORRECTED".equals(datetimeRebaseMode); - } else { - yield !updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS); - } + boolean needsUpcast = !DecimalType.is64BitDecimalType(sparkType) || + updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS); + boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) && !"CORRECTED".equals(datetimeRebaseMode); + yield !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType); } - case FLOAT, DOUBLE, BINARY -> true; + case FLOAT -> sparkType == FloatType; + case DOUBLE, BINARY -> !needsDecimalScaleRebase(sparkType); default -> false; }; } + /** + * Returns whether the Parquet type of this column and the given spark type are two decimal types + * with different scales. + */ + private boolean needsDecimalScaleRebase(DataType sparkType) { + LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation(); + if (!(typeAnnotation instanceof DecimalLogicalTypeAnnotation)) return false; + if (!(sparkType instanceof DecimalType)) return false; + DecimalLogicalTypeAnnotation parquetDecimal = (DecimalLogicalTypeAnnotation) typeAnnotation; + DecimalType sparkDecimal = (DecimalType) sparkType; + return parquetDecimal.getScale() != sparkDecimal.scale(); +} + /** * Reads `total` rows from this columnReader into column. */ @@ -205,7 +227,7 @@ void readBatch( // TIMESTAMP_MILLIS encoded as INT64 can't be lazily decoded as we need to post process // the values to add microseconds precision. 
if (column.hasDictionary() || (startRowId == pageFirstRowIndex && - isLazyDecodingSupported(typeName))) { + isLazyDecodingSupported(typeName, column.dataType()))) { // Column vector supports lazy decoding of dictionary values so just set the dictionary. // We can't do this if startRowId is not the first row index in the page AND the column // doesn't have a dictionary (i.e. some non-dictionary encoded values have already been diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index b3be89085014e..dd72ce6b31961 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -29,7 +29,7 @@ import org.apache.parquet.io.ColumnIOFactory import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} import org.apache.parquet.schema.{GroupType, Type, Types} import org.apache.parquet.schema.LogicalTypeAnnotation._ -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, FIXED_LEN_BYTE_ARRAY, INT32, INT64, INT96} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, FIXED_LEN_BYTE_ARRAY, FLOAT, INT32, INT64, INT96} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow @@ -313,6 +313,16 @@ private[parquet] class ParquetRowConverter( override def addInt(value: Int): Unit = this.updater.setLong(Integer.toUnsignedLong(value)) } + case LongType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 => + new ParquetPrimitiveConverter(updater) { + override def addInt(value: Int): Unit = + this.updater.setLong(value) + } + case DoubleType if parquetType.asPrimitiveType().getPrimitiveTypeName == FLOAT => + new ParquetPrimitiveConverter(updater) { + override def addFloat(value: Float): Unit = + this.updater.setDouble(value) + } case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType | _: AnsiIntervalType => new ParquetPrimitiveConverter(updater) @@ -438,6 +448,16 @@ private[parquet] class ParquetRowConverter( } } + // Allow upcasting INT32 date to timestampNTZ. 
+ case TimestampNTZType if schemaConverter.isTimestampNTZEnabled() && + parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 && + parquetType.getLogicalTypeAnnotation.isInstanceOf[DateLogicalTypeAnnotation] => + new ParquetPrimitiveConverter(updater) { + override def addInt(value: Int): Unit = { + this.updater.set(DateTimeUtils.daysToMicros(dateRebaseFunc(value), ZoneOffset.UTC)) + } + } + case DateType => new ParquetPrimitiveConverter(updater) { override def addInt(value: Int): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 374a8a8078edc..6319d47ffb78b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -1070,17 +1070,6 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession } } - test("SPARK-35640: int as long should throw schema incompatible error") { - val data = (1 to 4).map(i => Tuple1(i)) - val readSchema = StructType(Seq(StructField("_1", DataTypes.LongType))) - - withParquetFile(data) { path => - val errMsg = intercept[Exception](spark.read.schema(readSchema).parquet(path).collect()) - .getMessage - assert(errMsg.contains("Parquet column cannot be converted in file")) - } - } - test("write metadata") { val hadoopConf = spark.sessionState.newHadoopConf() withTempPath { file => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 43103db522bac..41019c83f7896 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -1098,19 +1098,13 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS test("row group skipping doesn't overflow when reading into larger type") { withTempPath { path => Seq(0).toDF("a").write.parquet(path.toString) - // The vectorized and non-vectorized readers will produce different exceptions, we don't need - // to test both as this covers row group skipping. - withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") { - // Reading integer 'a' as a long isn't supported. Check that an exception is raised instead - // of incorrectly skipping the single row group and producing incorrect results. - val exception = intercept[SparkException] { + withAllParquetReaders { + val result = spark.read .schema("a LONG") .parquet(path.toString) .where(s"a < ${Long.MaxValue}") - .collect() - } - assert(exception.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException]) + checkAnswer(result, Row(0)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala new file mode 100644 index 0000000000000..e612ada038f38 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import org.apache.hadoop.fs.Path +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.{ParquetFileReader, ParquetOutputFormat} +import org.apache.spark.SparkException +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +class ParquetTypeWideningSuite + extends QueryTest + with ParquetTest + with SharedSparkSession + with AdaptiveSparkPlanHelper { + + import testImplicits._ + + /** + * Write a Parquet file with the given values stored using type `fromType` and read it back + * using type `toType` with each Parquet reader. If `expectError` returns true, check that an + * error is thrown during the read. Otherwise check that the data read matches the data written. + */ + private def checkAllParquetReaders( + values: Seq[String], + fromType: DataType, + toType: DataType, + expectError: => Boolean): Unit = { + val timestampRebaseModes = toType match { + case _: TimestampNTZType | _: DateType => Seq("CORRECTED", "LEGACY") + case _ => Seq("CORRECTED") + } + for { + dictionaryEnabled <- Seq(true, false) + timestampRebaseMode <- timestampRebaseModes + } + withClue( + s"with dictionary encoding '$dictionaryEnabled' with timestamp rebase mode " + + s"'$timestampRebaseMode''") { + withAllParquetWriters { + withTempDir { dir => + val expected = + writeParquetFiles(dir, values, fromType, dictionaryEnabled, timestampRebaseMode) + withAllParquetReaders { + if (expectError) { + val exception = intercept[SparkException] { + readParquetFiles(dir, toType).collect() + } + assert( + exception.getCause.getCause + .isInstanceOf[SchemaColumnConvertNotSupportedException] || + exception.getCause.getCause + .isInstanceOf[org.apache.parquet.io.ParquetDecodingException]) + } else { + checkAnswer(readParquetFiles(dir, toType), expected.select($"a".cast(toType))) + } + } + } + } + } + } + + /** + * Reads all parquet files in the given directory using the given type. + */ + private def readParquetFiles(dir: File, dataType: DataType): DataFrame = { + spark.read.schema(s"a ${dataType.sql}").parquet(dir.getAbsolutePath) + } + + /** + * Writes values to a parquet file in the given directory using the given type and returns a + * DataFrame corresponding to the data written. If dictionaryEnabled is true, the columns will + * be dictionary encoded. Each provided value is repeated 10 times to allow dictionary encoding + * to be used. 
timestampRebaseMode can be either "CORRECTED" or "LEGACY", see + * [[SQLConf.PARQUET_REBASE_MODE_IN_WRITE]] + */ + private def writeParquetFiles( + dir: File, + values: Seq[String], + dataType: DataType, + dictionaryEnabled: Boolean, + timestampRebaseMode: String = "CORRECTED"): DataFrame = { + val repeatedValues = List.fill(if (dictionaryEnabled) 10 else 1)(values).flatten + val df = repeatedValues.toDF("a").select(col("a").cast(dataType)) + withSQLConf( + ParquetOutputFormat.ENABLE_DICTIONARY -> dictionaryEnabled.toString, + SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key -> timestampRebaseMode) { + df.write.mode("overwrite").parquet(dir.getAbsolutePath) + } + + // Decimals stored as byte arrays (precision > 18) are not dictionary encoded. + if (dictionaryEnabled && !DecimalType.isByteArrayDecimalType(dataType)) { + assertAllParquetFilesDictionaryEncoded(dir) + } + df + } + + /** + * Asserts that all parquet files in the given directory have all their columns dictionary + * encoded. + */ + private def assertAllParquetFilesDictionaryEncoded(dir: File): Unit = { + dir.listFiles(_.getName.endsWith(".parquet")).foreach { file => + val parquetMetadata = ParquetFileReader.readFooter( + spark.sessionState.newHadoopConf(), + new Path(dir.toString, file.getName), + ParquetMetadataConverter.NO_FILTER) + parquetMetadata.getBlocks.forEach { block => + block.getColumns.forEach { col => + assert( + col.hasDictionaryPage, + "This test covers dictionary encoding but column " + + s"'${col.getPath.toDotString}' in the test data is not dictionary encoded.") + } + } + } + } + + for { + case (values: Seq[String], fromType: DataType, toType: DataType) <- Seq( + (Seq("1", "2", Short.MinValue.toString), ShortType, IntegerType), + // Int->Short isn't a widening conversion but Parquet stores both as INT32 so it just works. + (Seq("1", "2", Short.MinValue.toString), IntegerType, ShortType), + (Seq("1", "2", Int.MinValue.toString), IntegerType, LongType), + (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampNTZType), + (Seq("1.23", "10.34"), FloatType, DoubleType)) + } + test(s"parquet widening conversion $fromType -> $toType") { + checkAllParquetReaders(values, fromType, toType, expectError = false) + } + + for { + case (values: Seq[String], fromType: DataType, toType: DataType) <- Seq( + (Seq("1", "2", Int.MinValue.toString), LongType, IntegerType), + // Test different timestamp types + (Seq("2020-01-01", "2020-01-02", "1312-02-27"), TimestampNTZType, DateType), + (Seq("1.23", "10.34"), DoubleType, FloatType)) + } + test(s"unsupported parquet conversion $fromType -> $toType") { + checkAllParquetReaders(values, fromType, toType, expectError = true) + } + + for { + (fromPrecision, toPrecision) <- + // Test widening and narrowing precision between the same and different decimal physical + // parquet types: + // - INT32: precisions 5, 7 + // - INT64: precisions 10, 12 + // - FIXED_LEN_BYTE_ARRAY: precisions 20, 22 + Seq(5 -> 7, 5 -> 10, 5 -> 20, 10 -> 12, 10 -> 20, 20 -> 22) ++ + Seq(7 -> 5, 10 -> 5, 20 -> 5, 12 -> 10, 20 -> 10, 22 -> 20) + } + test( + s"parquet decimal precision change Decimal($fromPrecision, 2) -> Decimal($toPrecision, 2)") { + checkAllParquetReaders( + values = Seq("1.23", "10.34"), + fromType = DecimalType(fromPrecision, 2), + toType = DecimalType(toPrecision, 2), + expectError = fromPrecision > toPrecision && + // parquet-mr allows reading decimals into a smaller precision decimal type without + // checking for overflows. See test below. 
+ spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean) + } + + test("parquet decimal type change Decimal(5, 2) -> Decimal(3, 2) overflows with parquet-mr") { + withTempDir { dir => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + writeParquetFiles( + dir, + values = Seq("123.45", "999.99"), + DecimalType(5, 2), + dictionaryEnabled = false) + checkAnswer(readParquetFiles(dir, DecimalType(3, 2)), Row(null) :: Row(null) :: Nil) + } + } + } + + test("parquet decimal type change IntegerType -> ShortType overflows") { + withTempDir { dir => + withAllParquetReaders { + // Int & Short are both stored as INT32 in Parquet but Int.MinValue will overflow when + // reading as Short in Spark. + val overflowValue = Short.MaxValue.toInt + 1 + writeParquetFiles( + dir, + Seq(overflowValue.toString), + IntegerType, + dictionaryEnabled = false) + checkAnswer(readParquetFiles(dir, ShortType), Row(Short.MinValue)) + } + } + } +} From cb1487edc57ab47827db245d741be408d300682a Mon Sep 17 00:00:00 2001 From: Johan Lasperas Date: Fri, 15 Dec 2023 15:11:56 +0100 Subject: [PATCH 2/6] Fix formatting --- .../parquet/ParquetVectorUpdaterFactory.java | 15 ++++++++++----- .../parquet/VectorizedColumnReader.java | 6 ++++-- .../parquet/ParquetTypeWideningSuite.scala | 7 ++++--- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java index 0ece8fefac09a..729812252b102 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java @@ -339,7 +339,8 @@ public void readValues( WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; ++i) { - values.putLong(offset + i, DateTimeUtils.daysToMicros(valuesReader.readInteger(), ZoneOffset.UTC)); + long days = DateTimeUtils.daysToMicros(valuesReader.readInteger(), ZoneOffset.UTC); + values.putLong(offset + i, days); } } @@ -353,7 +354,8 @@ public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { - values.putLong(offset, DateTimeUtils.daysToMicros(valuesReader.readInteger(), ZoneOffset.UTC)); + long days = DateTimeUtils.daysToMicros(valuesReader.readInteger(), ZoneOffset.UTC); + values.putLong(offset, days); } @Override @@ -406,7 +408,8 @@ public void decodeSingleDictionaryId( WritableColumnVector values, WritableColumnVector dictionaryIds, Dictionary dictionary) { - int rebasedDays = rebaseDays(dictionary.decodeToInt(dictionaryIds.getDictId(offset)), failIfRebase); + int rebasedDays = + rebaseDays(dictionary.decodeToInt(dictionaryIds.getDictId(offset)), failIfRebase); values.putLong(offset, DateTimeUtils.daysToMicros(rebasedDays, ZoneOffset.UTC)); } } @@ -949,7 +952,8 @@ public void decodeSingleDictionaryId( WritableColumnVector values, WritableColumnVector dictionaryIds, Dictionary dictionary) { - BigInteger value = BigInteger.valueOf(dictionary.decodeToInt(dictionaryIds.getDictId(offset))); + BigInteger value = + BigInteger.valueOf(dictionary.decodeToInt(dictionaryIds.getDictId(offset))); values.putByteArray(offset, value.toByteArray()); } } @@ -987,7 +991,8 @@ public void decodeSingleDictionaryId( WritableColumnVector values, WritableColumnVector dictionaryIds, 
Dictionary dictionary) { - BigInteger value = BigInteger.valueOf(dictionary.decodeToLong(dictionaryIds.getDictId(offset))); + BigInteger value = + BigInteger.valueOf(dictionary.decodeToLong(dictionaryIds.getDictId(offset))); values.putByteArray(offset, value.toByteArray()); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 33b9412c37663..ba4667d709791 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -151,13 +151,15 @@ private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName case INT32 -> { boolean needsUpcast = sparkType == LongType || sparkType == TimestampNTZType || !DecimalType.is32BitDecimalType(sparkType); - boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation && !"CORRECTED".equals(datetimeRebaseMode); + boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation && + !"CORRECTED".equals(datetimeRebaseMode); yield !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType); } case INT64 -> { boolean needsUpcast = !DecimalType.is64BitDecimalType(sparkType) || updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS); - boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) && !"CORRECTED".equals(datetimeRebaseMode); + boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) && + !"CORRECTED".equals(datetimeRebaseMode); yield !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType); } case FLOAT -> sparkType == FloatType; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala index e612ada038f38..b862c3385592b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala @@ -21,11 +21,12 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.parquet.format.converter.ParquetMetadataConverter import org.apache.parquet.hadoop.{ParquetFileReader, ParquetOutputFormat} + import org.apache.spark.SparkException +import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException import org.apache.spark.sql.functions.col -import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -139,7 +140,7 @@ class ParquetTypeWideningSuite } for { - case (values: Seq[String], fromType: DataType, toType: DataType) <- Seq( + (values: Seq[String], fromType: DataType, toType: DataType) <- Seq( (Seq("1", "2", Short.MinValue.toString), ShortType, IntegerType), // Int->Short isn't a widening conversion but Parquet stores both as INT32 so it just works. 
(Seq("1", "2", Short.MinValue.toString), IntegerType, ShortType), @@ -152,7 +153,7 @@ class ParquetTypeWideningSuite } for { - case (values: Seq[String], fromType: DataType, toType: DataType) <- Seq( + (values: Seq[String], fromType: DataType, toType: DataType) <- Seq( (Seq("1", "2", Int.MinValue.toString), LongType, IntegerType), // Test different timestamp types (Seq("2020-01-01", "2020-01-02", "1312-02-27"), TimestampNTZType, DateType), From dc8b489ca79b6ee8c7572acbc6dea41b775e5532 Mon Sep 17 00:00:00 2001 From: Johan Lasperas Date: Mon, 18 Dec 2023 13:07:57 +0100 Subject: [PATCH 3/6] Add int->double + more tests --- .../parquet/ParquetVectorUpdaterFactory.java | 37 +++++++++++++++++++ .../parquet/ParquetRowConverter.scala | 5 +++ .../parquet/ParquetTypeWideningSuite.scala | 14 ++++--- 3 files changed, 51 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java index 729812252b102..8fce52051d7dc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java @@ -89,6 +89,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa return new ByteUpdater(); } else if (sparkType == DataTypes.ShortType) { return new ShortUpdater(); + } else if (sparkType == DataTypes.DoubleType) { + return new IntegerToDoubleUpdater(); } else if (sparkType == DataTypes.DateType) { if ("CORRECTED".equals(datetimeRebaseMode)) { return new IntegerUpdater(); @@ -331,6 +333,41 @@ public void decodeSingleDictionaryId( } } + static class IntegerToDoubleUpdater implements ParquetVectorUpdater { + @Override + public void readValues( + int total, + int offset, + WritableColumnVector values, + VectorizedValuesReader valuesReader) { + for (int i = 0; i < total; ++i) { + values.putDouble(offset + i, valuesReader.readInteger()); + } + } + + @Override + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipIntegers(total); + } + + @Override + public void readValue( + int offset, + WritableColumnVector values, + VectorizedValuesReader valuesReader) { + values.putDouble(offset, valuesReader.readInteger()); + } + + @Override + public void decodeSingleDictionaryId( + int offset, + WritableColumnVector values, + WritableColumnVector dictionaryIds, + Dictionary dictionary) { + values.putDouble(offset, dictionary.decodeToInt(dictionaryIds.getDictId(offset))); + } + } + static class DateToTimestampNTZUpdater implements ParquetVectorUpdater { @Override public void readValues( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index dd72ce6b31961..b2222f4297e90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -318,6 +318,11 @@ private[parquet] class ParquetRowConverter( override def addInt(value: Int): Unit = this.updater.setLong(value) } + case DoubleType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 => + new ParquetPrimitiveConverter(updater) { + 
override def addInt(value: Int): Unit = + this.updater.setDouble(value) + } case DoubleType if parquetType.asPrimitiveType().getPrimitiveTypeName == FLOAT => new ParquetPrimitiveConverter(updater) { override def addFloat(value: Float): Unit = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala index b862c3385592b..811907e39c202 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala @@ -145,8 +145,11 @@ class ParquetTypeWideningSuite // Int->Short isn't a widening conversion but Parquet stores both as INT32 so it just works. (Seq("1", "2", Short.MinValue.toString), IntegerType, ShortType), (Seq("1", "2", Int.MinValue.toString), IntegerType, LongType), - (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampNTZType), - (Seq("1.23", "10.34"), FloatType, DoubleType)) + (Seq("1", "2", Short.MinValue.toString), ShortType, DoubleType), + (Seq("1", "2", Int.MinValue.toString), IntegerType, DoubleType), + (Seq("1.23", "10.34"), FloatType, DoubleType), + (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampNTZType) + ) } test(s"parquet widening conversion $fromType -> $toType") { checkAllParquetReaders(values, fromType, toType, expectError = false) @@ -155,9 +158,10 @@ class ParquetTypeWideningSuite for { (values: Seq[String], fromType: DataType, toType: DataType) <- Seq( (Seq("1", "2", Int.MinValue.toString), LongType, IntegerType), - // Test different timestamp types - (Seq("2020-01-01", "2020-01-02", "1312-02-27"), TimestampNTZType, DateType), - (Seq("1.23", "10.34"), DoubleType, FloatType)) + (Seq("1.23", "10.34"), DoubleType, FloatType), + (Seq("1.23", "10.34"), FloatType, LongType), + (Seq("2020-01-01", "2020-01-02", "1312-02-27"), TimestampNTZType, DateType) + ) } test(s"unsupported parquet conversion $fromType -> $toType") { checkAllParquetReaders(values, fromType, toType, expectError = true) From dd1975c2030cc88db13e225fe32a30816ab59244 Mon Sep 17 00:00:00 2001 From: Johan Lasperas Date: Tue, 19 Dec 2023 18:03:19 +0100 Subject: [PATCH 4/6] Address comments + add test cases --- .../parquet/ParquetVectorUpdaterFactory.java | 14 ++++-- .../parquet/VectorizedColumnReader.java | 47 +++++++++---------- .../parquet/ParquetTypeWideningSuite.scala | 40 ++++++++++++---- 3 files changed, 62 insertions(+), 39 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java index 8fce52051d7dc..2bccba0fc94ca 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java @@ -22,6 +22,7 @@ import org.apache.parquet.io.api.Binary; import org.apache.parquet.schema.LogicalTypeAnnotation; import org.apache.parquet.schema.LogicalTypeAnnotation.IntLogicalTypeAnnotation; +import org.apache.parquet.schema.LogicalTypeAnnotation.DateLogicalTypeAnnotation; import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation; import 
org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation; import org.apache.parquet.schema.PrimitiveType; @@ -98,7 +99,7 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode); return new IntegerWithRebaseUpdater(failIfRebase); } - } else if (sparkType == DataTypes.TimestampNTZType) { + } else if (sparkType == DataTypes.TimestampNTZType && isDateTypeMatched(descriptor)) { if ("CORRECTED".equals(datetimeRebaseMode)) { return new DateToTimestampNTZUpdater(); } else { @@ -376,8 +377,7 @@ public void readValues( WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; ++i) { - long days = DateTimeUtils.daysToMicros(valuesReader.readInteger(), ZoneOffset.UTC); - values.putLong(offset + i, days); + readValue(offset + i, values, valuesReader); } } @@ -420,8 +420,7 @@ public void readValues( WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; ++i) { - int rebasedDays = rebaseDays(valuesReader.readInteger(), failIfRebase); - values.putLong(offset + i, DateTimeUtils.daysToMicros(rebasedDays, ZoneOffset.UTC)); + readValue(offset + i, values, valuesReader); } } @@ -1436,6 +1435,11 @@ private static boolean isTimestamp(DataType dt) { return dt == DataTypes.TimestampType || dt == DataTypes.TimestampNTZType; } + boolean isDateTypeMatched(ColumnDescriptor descriptor) { + LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation(); + return typeAnnotation instanceof DateLogicalTypeAnnotation; + } + private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataType dt) { DecimalType d = (DecimalType) dt; LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index ba4667d709791..8273ce3fa6487 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -145,42 +145,37 @@ public VectorizedColumnReader( private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName, DataType sparkType) { - // Don't use lazy dictionary decoding if the column needs extra processing: upcasting or date / - // decimal scale rebasing. - return switch (typeName) { - case INT32 -> { + boolean isSupported = false; + // Don't use lazy dictionary decoding if the column needs extra processing: upcasting or date + // rebasing. 
+ switch (typeName) { + case INT32: { boolean needsUpcast = sparkType == LongType || sparkType == TimestampNTZType || !DecimalType.is32BitDecimalType(sparkType); boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation && - !"CORRECTED".equals(datetimeRebaseMode); - yield !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType); + !"CORRECTED".equals(datetimeRebaseMode); + isSupported = !needsUpcast && !needsRebase; + break; } - case INT64 -> { + case INT64: { boolean needsUpcast = !DecimalType.is64BitDecimalType(sparkType) || updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS); boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) && - !"CORRECTED".equals(datetimeRebaseMode); - yield !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType); + !"CORRECTED".equals(datetimeRebaseMode); + isSupported = !needsUpcast && !needsRebase; + break; } - case FLOAT -> sparkType == FloatType; - case DOUBLE, BINARY -> !needsDecimalScaleRebase(sparkType); - default -> false; - }; + case FLOAT: + isSupported = sparkType == FloatType; + break; + case DOUBLE: + case BINARY: + isSupported = true; + break; + } + return isSupported; } - /** - * Returns whether the Parquet type of this column and the given spark type are two decimal types - * with different scales. - */ - private boolean needsDecimalScaleRebase(DataType sparkType) { - LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation(); - if (!(typeAnnotation instanceof DecimalLogicalTypeAnnotation)) return false; - if (!(sparkType instanceof DecimalType)) return false; - DecimalLogicalTypeAnnotation parquetDecimal = (DecimalLogicalTypeAnnotation) typeAnnotation; - DecimalType sparkDecimal = (DecimalType) sparkType; - return parquetDecimal.getScale() != sparkDecimal.scale(); -} - /** * Reads `total` rows from this columnReader into column. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala index 811907e39c202..72580f7078e23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala @@ -27,7 +27,8 @@ import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException import org.apache.spark.sql.functions.col -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} +import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -50,8 +51,10 @@ class ParquetTypeWideningSuite toType: DataType, expectError: => Boolean): Unit = { val timestampRebaseModes = toType match { - case _: TimestampNTZType | _: DateType => Seq("CORRECTED", "LEGACY") - case _ => Seq("CORRECTED") + case _: TimestampNTZType | _: DateType => + Seq(LegacyBehaviorPolicy.CORRECTED, LegacyBehaviorPolicy.LEGACY) + case _ => + Seq(LegacyBehaviorPolicy.CORRECTED) } for { dictionaryEnabled <- Seq(true, false) @@ -72,8 +75,10 @@ class ParquetTypeWideningSuite assert( exception.getCause.getCause .isInstanceOf[SchemaColumnConvertNotSupportedException] || - exception.getCause.getCause - .isInstanceOf[org.apache.parquet.io.ParquetDecodingException]) + exception.getCause.getCause + .isInstanceOf[org.apache.parquet.io.ParquetDecodingException] || + exception.getCause.getMessage.contains( + "Unable to create Parquet converter for data type")) } else { checkAnswer(readParquetFiles(dir, toType), expected.select($"a".cast(toType))) } @@ -102,12 +107,13 @@ class ParquetTypeWideningSuite values: Seq[String], dataType: DataType, dictionaryEnabled: Boolean, - timestampRebaseMode: String = "CORRECTED"): DataFrame = { + timestampRebaseMode: LegacyBehaviorPolicy.Value = LegacyBehaviorPolicy.CORRECTED) + : DataFrame = { val repeatedValues = List.fill(if (dictionaryEnabled) 10 else 1)(values).flatten val df = repeatedValues.toDF("a").select(col("a").cast(dataType)) withSQLConf( ParquetOutputFormat.ENABLE_DICTIONARY -> dictionaryEnabled.toString, - SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key -> timestampRebaseMode) { + SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key -> timestampRebaseMode.toString) { df.write.mode("overwrite").parquet(dir.getAbsolutePath) } @@ -160,13 +166,31 @@ class ParquetTypeWideningSuite (Seq("1", "2", Int.MinValue.toString), LongType, IntegerType), (Seq("1.23", "10.34"), DoubleType, FloatType), (Seq("1.23", "10.34"), FloatType, LongType), - (Seq("2020-01-01", "2020-01-02", "1312-02-27"), TimestampNTZType, DateType) + (Seq("1.23", "10.34"), LongType, DateType), + (Seq("1.23", "10.34"), IntegerType, TimestampType), + (Seq("1.23", "10.34"), IntegerType, TimestampNTZType), + (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampType) ) } test(s"unsupported parquet conversion $fromType -> $toType") { checkAllParquetReaders(values, fromType, toType, expectError = true) } + for { + (values: Seq[String], fromType: DataType, toType: DataType) <- Seq( + (Seq("2020-01-01", "2020-01-02", "1312-02-27"), TimestampType, DateType), 
+ (Seq("2020-01-01", "2020-01-02", "1312-02-27"), TimestampNTZType, DateType)) + outputTimestampType <- ParquetOutputTimestampType.values + } + test(s"unsupported parquet timestamp conversion $fromType ($outputTimestampType) -> $toType") { + withSQLConf( + SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> outputTimestampType.toString, + SQLConf.PARQUET_INT96_REBASE_MODE_IN_WRITE.key -> LegacyBehaviorPolicy.CORRECTED.toString + ) { + checkAllParquetReaders(values, fromType, toType, expectError = true) + } + } + for { (fromPrecision, toPrecision) <- // Test widening and narrowing precision between the same and different decimal physical From d6316408115f76cacebacbb7cf2179add0bc9d29 Mon Sep 17 00:00:00 2001 From: Johan Lasperas Date: Tue, 19 Dec 2023 18:19:45 +0100 Subject: [PATCH 5/6] nit --- .../datasources/parquet/ParquetVectorUpdaterFactory.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java index 2bccba0fc94ca..12b0ec911b625 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java @@ -1435,7 +1435,7 @@ private static boolean isTimestamp(DataType dt) { return dt == DataTypes.TimestampType || dt == DataTypes.TimestampNTZType; } - boolean isDateTypeMatched(ColumnDescriptor descriptor) { + private static boolean isDateTypeMatched(ColumnDescriptor descriptor) { LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation(); return typeAnnotation instanceof DateLogicalTypeAnnotation; } From 6e01b4bfbe41bcead3746dcc71b21373b144a534 Mon Sep 17 00:00:00 2001 From: Johan Lasperas Date: Thu, 21 Dec 2023 10:14:10 +0100 Subject: [PATCH 6/6] Address comments --- .../parquet/VectorizedColumnReader.java | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 8273ce3fa6487..7c9bca6710aa0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -143,25 +143,27 @@ public VectorizedColumnReader( this.writerVersion = writerVersion; } - private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName, - DataType sparkType) { + private boolean isLazyDecodingSupported( + PrimitiveType.PrimitiveTypeName typeName, + DataType sparkType) { boolean isSupported = false; // Don't use lazy dictionary decoding if the column needs extra processing: upcasting or date // rebasing. 
switch (typeName) { case INT32: { - boolean needsUpcast = sparkType == LongType || sparkType == TimestampNTZType || - !DecimalType.is32BitDecimalType(sparkType); + boolean isDate = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation; + boolean needsUpcast = sparkType == LongType || (isDate && sparkType == TimestampNTZType) || + !DecimalType.is32BitDecimalType(sparkType); boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation && - !"CORRECTED".equals(datetimeRebaseMode); + !"CORRECTED".equals(datetimeRebaseMode); isSupported = !needsUpcast && !needsRebase; break; } case INT64: { boolean needsUpcast = !DecimalType.is64BitDecimalType(sparkType) || - updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS); + updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS); boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) && - !"CORRECTED".equals(datetimeRebaseMode); + !"CORRECTED".equals(datetimeRebaseMode); isSupported = !needsUpcast && !needsRebase; break; }
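
Editor's note: the sketch below is not part of the patch series; it is a minimal, self-contained illustration of the user-visible behavior these commits enable, matching the updated ParquetQuerySuite/ParquetTypeWideningSuite tests (reading Parquet INT32 as LONG and FLOAT as DOUBLE, which previously failed in the vectorized reader with SchemaColumnConvertNotSupportedException). The application name, local master, and temp path are assumptions for the example only.

// TypeWideningExample.scala -- illustrative only, assumes a local Spark build containing these patches.
import org.apache.spark.sql.SparkSession

object TypeWideningExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("parquet-type-widening-example")   // assumed name
      .master("local[1]")
      .getOrCreate()
    import spark.implicits._

    val path = "/tmp/parquet-widening-example"    // assumed scratch path

    // Write data with the narrow types: the columns are stored as INT32 and FLOAT in Parquet.
    Seq((1, 1.23f), (2, 10.34f)).toDF("i", "f")
      .write.mode("overwrite").parquet(path)

    // With this patch series, the same files can be read back using wider Spark types.
    // Before these changes the vectorized reader rejected this schema with
    // "Parquet column cannot be converted in file" (see the removed SPARK-35640 test).
    spark.read.schema("i LONG, f DOUBLE").parquet(path).show()

    spark.stop()
  }
}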