diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
index 31a1957b4fb91..4961b52f4bb53 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
@@ -22,6 +22,7 @@
 import org.apache.parquet.io.api.Binary;
 import org.apache.parquet.schema.LogicalTypeAnnotation;
 import org.apache.parquet.schema.LogicalTypeAnnotation.IntLogicalTypeAnnotation;
+import org.apache.parquet.schema.LogicalTypeAnnotation.DateLogicalTypeAnnotation;
 import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
 import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation;
 import org.apache.parquet.schema.PrimitiveType;
@@ -81,10 +82,16 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
           // For unsigned int32, it stores as plain signed int32 in Parquet when dictionary
           // fallbacks. We read them as long values.
           return new UnsignedIntegerUpdater();
+        } else if (sparkType == DataTypes.LongType || canReadAsLongDecimal(descriptor, sparkType)) {
+          return new IntegerToLongUpdater();
+        } else if (canReadAsBinaryDecimal(descriptor, sparkType)) {
+          return new IntegerToBinaryUpdater();
         } else if (sparkType == DataTypes.ByteType) {
           return new ByteUpdater();
         } else if (sparkType == DataTypes.ShortType) {
           return new ShortUpdater();
+        } else if (sparkType == DataTypes.DoubleType) {
+          return new IntegerToDoubleUpdater();
         } else if (sparkType == DataTypes.DateType) {
           if ("CORRECTED".equals(datetimeRebaseMode)) {
             return new IntegerUpdater();
@@ -92,6 +99,13 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
             boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
             return new IntegerWithRebaseUpdater(failIfRebase);
           }
+        } else if (sparkType == DataTypes.TimestampNTZType && isDateTypeMatched(descriptor)) {
+          if ("CORRECTED".equals(datetimeRebaseMode)) {
+            return new DateToTimestampNTZUpdater();
+          } else {
+            boolean failIfRebase = "EXCEPTION".equals(datetimeRebaseMode);
+            return new DateToTimestampNTZWithRebaseUpdater(failIfRebase);
+          }
         } else if (sparkType instanceof YearMonthIntervalType) {
           return new IntegerUpdater();
         }
@@ -104,6 +118,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
           } else {
             return new LongUpdater();
           }
+        } else if (canReadAsBinaryDecimal(descriptor, sparkType)) {
+          return new LongToBinaryUpdater();
         } else if (isLongDecimal(sparkType) && isUnsignedIntTypeMatched(64)) {
           // In `ParquetToSparkSchemaConverter`, we map parquet UINT64 to our Decimal(20, 0).
           // For unsigned int64, it stores as plain signed int64 in Parquet when dictionary
@@ -142,6 +158,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
       case FLOAT -> {
         if (sparkType == DataTypes.FloatType) {
           return new FloatUpdater();
+        } else if (sparkType == DataTypes.DoubleType) {
+          return new FloatToDoubleUpdater();
         }
       }
       case DOUBLE -> {
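For context, the dispatch above is what makes widening reads work end to end in the vectorized reader. A minimal usage sketch (assuming a local SparkSession `spark`, the `spark.implicits._` import, and a scratch path of your choosing):

```scala
import spark.implicits._

// Write an INT32 column, then read it back with wider Spark types.
Seq(1, 2, Int.MinValue).toDF("a").write.mode("overwrite").parquet("/tmp/widening_demo")

// INT32 read as LONG goes through IntegerToLongUpdater.
spark.read.schema("a LONG").parquet("/tmp/widening_demo").show()

// INT32 read as DOUBLE goes through IntegerToDoubleUpdater.
spark.read.schema("a DOUBLE").parquet("/tmp/widening_demo").show()
```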
@@ -288,6 +306,157 @@ public void decodeSingleDictionaryId(
     }
   }
 
+  static class IntegerToLongUpdater implements ParquetVectorUpdater {
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; ++i) {
+        values.putLong(offset + i, valuesReader.readInteger());
+      }
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipIntegers(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      values.putLong(offset, valuesReader.readInteger());
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      values.putLong(offset, dictionary.decodeToInt(dictionaryIds.getDictId(offset)));
+    }
+  }
+
+  static class IntegerToDoubleUpdater implements ParquetVectorUpdater {
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; ++i) {
+        values.putDouble(offset + i, valuesReader.readInteger());
+      }
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipIntegers(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      values.putDouble(offset, valuesReader.readInteger());
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      values.putDouble(offset, dictionary.decodeToInt(dictionaryIds.getDictId(offset)));
+    }
+  }
+
+  static class DateToTimestampNTZUpdater implements ParquetVectorUpdater {
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; ++i) {
+        readValue(offset + i, values, valuesReader);
+      }
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipIntegers(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      long micros = DateTimeUtils.daysToMicros(valuesReader.readInteger(), ZoneOffset.UTC);
+      values.putLong(offset, micros);
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      int days = dictionary.decodeToInt(dictionaryIds.getDictId(offset));
+      values.putLong(offset, DateTimeUtils.daysToMicros(days, ZoneOffset.UTC));
+    }
+  }
+
+  private static class DateToTimestampNTZWithRebaseUpdater implements ParquetVectorUpdater {
+    private final boolean failIfRebase;
+
+    DateToTimestampNTZWithRebaseUpdater(boolean failIfRebase) {
+      this.failIfRebase = failIfRebase;
+    }
+
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; ++i) {
+        readValue(offset + i, values, valuesReader);
+      }
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipIntegers(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      int rebasedDays = rebaseDays(valuesReader.readInteger(), failIfRebase);
+      values.putLong(offset, DateTimeUtils.daysToMicros(rebasedDays, ZoneOffset.UTC));
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      int rebasedDays =
+        rebaseDays(dictionary.decodeToInt(dictionaryIds.getDictId(offset)), failIfRebase);
+      values.putLong(offset, DateTimeUtils.daysToMicros(rebasedDays, ZoneOffset.UTC));
+    }
+  }
+
   private static class UnsignedIntegerUpdater implements ParquetVectorUpdater {
     @Override
     public void readValues(
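The two date-to-TimestampNTZ updaters above map a Parquet DATE (days since the epoch) to microseconds at midnight UTC, via the same `DateTimeUtils.daysToMicros` helper they call. A small sketch of that per-value conversion (the concrete day count is just an illustration):

```scala
import java.time.ZoneOffset
import org.apache.spark.sql.catalyst.util.DateTimeUtils

val days = 19000  // 2022-01-08 expressed as days since 1970-01-01
val micros = DateTimeUtils.daysToMicros(days, ZoneOffset.UTC)
// One day is 86_400 seconds, i.e. 86_400_000_000 microseconds; UTC has no DST shifts.
assert(micros == days.toLong * 24 * 60 * 60 * 1000 * 1000)
```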
@@ -691,6 +860,41 @@ public void decodeSingleDictionaryId(
     }
   }
 
+  static class FloatToDoubleUpdater implements ParquetVectorUpdater {
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; ++i) {
+        values.putDouble(offset + i, valuesReader.readFloat());
+      }
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipFloats(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      values.putDouble(offset, valuesReader.readFloat());
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      values.putDouble(offset, dictionary.decodeToFloat(dictionaryIds.getDictId(offset)));
+    }
+  }
+
   private static class DoubleUpdater implements ParquetVectorUpdater {
     @Override
     public void readValues(
@@ -758,6 +962,84 @@ public void decodeSingleDictionaryId(
     }
   }
 
+  private static class IntegerToBinaryUpdater implements ParquetVectorUpdater {
+
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; i++) {
+        readValue(offset + i, values, valuesReader);
+      }
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipIntegers(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      BigInteger value = BigInteger.valueOf(valuesReader.readInteger());
+      values.putByteArray(offset, value.toByteArray());
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      BigInteger value =
+        BigInteger.valueOf(dictionary.decodeToInt(dictionaryIds.getDictId(offset)));
+      values.putByteArray(offset, value.toByteArray());
+    }
+  }
+
+  private static class LongToBinaryUpdater implements ParquetVectorUpdater {
+
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; i++) {
+        readValue(offset + i, values, valuesReader);
+      }
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipLongs(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      BigInteger value = BigInteger.valueOf(valuesReader.readLong());
+      values.putByteArray(offset, value.toByteArray());
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      BigInteger value =
+        BigInteger.valueOf(dictionary.decodeToLong(dictionaryIds.getDictId(offset)));
+      values.putByteArray(offset, value.toByteArray());
+    }
+  }
+
   private static class BinaryToSQLTimestampUpdater implements ParquetVectorUpdater {
     @Override
     public void readValues(
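`IntegerToBinaryUpdater` and `LongToBinaryUpdater` cover reads into decimals with precision above 18, which Spark's column vectors store as the big-endian two's-complement bytes of the unscaled value. A tiny sketch of that representation using plain JDK APIs (no Spark involved):

```scala
import java.math.BigInteger

// The unscaled value of Decimal(123.45) at scale 2 is 12345.
val unscaled = 12345L
val bytes = BigInteger.valueOf(unscaled).toByteArray
// Round-tripping through the byte representation preserves the value.
assert(new BigInteger(bytes).longValueExact == unscaled)
```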
@@ -1156,6 +1438,11 @@ private static boolean isLongDecimal(DataType dt) {
     return false;
   }
 
+  private static boolean isDateTypeMatched(ColumnDescriptor descriptor) {
+    LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
+    return typeAnnotation instanceof DateLogicalTypeAnnotation;
+  }
+
   private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataType dt) {
     DecimalType d = (DecimalType) dt;
     LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index 04fbe716ad92f..7c9bca6710aa0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -38,10 +38,13 @@
 import org.apache.parquet.schema.PrimitiveType;
 
 import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
+import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.DecimalType;
 
 import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BOOLEAN;
 import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64;
+import static org.apache.spark.sql.types.DataTypes.*;
 
 /**
  * Decoder to return values from a single column.
@@ -140,21 +143,43 @@ public VectorizedColumnReader(
     this.writerVersion = writerVersion;
   }
 
-  private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName) {
-    return switch (typeName) {
-      case INT32 ->
-        !(logicalTypeAnnotation instanceof DateLogicalTypeAnnotation) ||
-          "CORRECTED".equals(datetimeRebaseMode);
-      case INT64 -> {
-        if (updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS)) {
-          yield "CORRECTED".equals(datetimeRebaseMode);
-        } else {
-          yield !updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
-        }
+  private boolean isLazyDecodingSupported(
+      PrimitiveType.PrimitiveTypeName typeName,
+      DataType sparkType) {
+    boolean isSupported = false;
+    // Don't use lazy dictionary decoding if the column needs extra processing: upcasting or date
+    // rebasing.
+    switch (typeName) {
+      case INT32: {
+        boolean isDate = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation;
+        boolean isDecimal = sparkType instanceof DecimalType;
+        boolean needsUpcast = sparkType == LongType || sparkType == DoubleType ||
+          (isDate && sparkType == TimestampNTZType) ||
+          (isDecimal && !DecimalType.is32BitDecimalType(sparkType));
+        boolean needsRebase = isDate && !"CORRECTED".equals(datetimeRebaseMode);
+        isSupported = !needsUpcast && !needsRebase;
+        break;
       }
-      case FLOAT, DOUBLE, BINARY -> true;
-      default -> false;
-    };
+      case INT64: {
+        boolean isDecimal = sparkType instanceof DecimalType;
+        boolean needsUpcast = (isDecimal && !DecimalType.is64BitDecimalType(sparkType)) ||
+          updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
+        boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) &&
+          !"CORRECTED".equals(datetimeRebaseMode);
+        isSupported = !needsUpcast && !needsRebase;
+        break;
+      }
+      case FLOAT:
+        isSupported = sparkType == FloatType;
+        break;
+      case DOUBLE:
+      case BINARY:
+        isSupported = true;
+        break;
+      default:
+        break;
+    }
+    return isSupported;
   }
 
 /**
@@ -205,7 +230,7 @@ void readBatch(
       // TIMESTAMP_MILLIS encoded as INT64 can't be lazily decoded as we need to post process
       // the values to add microseconds precision.
       if (column.hasDictionary() || (startRowId == pageFirstRowIndex &&
-          isLazyDecodingSupported(typeName))) {
+          isLazyDecodingSupported(typeName, column.dataType()))) {
         // Column vector supports lazy decoding of dictionary values so just set the dictionary.
         // We can't do this if startRowId is not the first row index in the page AND the column
         // doesn't have a dictionary (i.e. some non-dictionary encoded values have already been
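The net effect of the stricter `isLazyDecodingSupported` check is that dictionary-encoded pages still decode correctly when the requested Spark type is wider than the Parquet physical type: the reader skips lazy dictionary decoding and upcasts eagerly through the new updaters. A rough way to exercise that path (a sketch, assuming a SparkSession `spark` and a scratch path; the repeated values are there to make the writer emit a dictionary page):

```scala
import spark.implicits._

// Many repeated values => dictionary-encoded INT32 column.
Seq.fill(100)(7).toDF("a").write.mode("overwrite").parquet("/tmp/dict_demo")

// Reading the column as LONG must not reuse the int dictionary lazily;
// the values are upcast eagerly instead, and the result is still correct.
assert(spark.read.schema("a LONG").parquet("/tmp/dict_demo").distinct().count() == 1)
```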
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index b3be89085014e..b2222f4297e90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -29,7 +29,7 @@ import org.apache.parquet.io.ColumnIOFactory
 import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter}
 import org.apache.parquet.schema.{GroupType, Type, Types}
 import org.apache.parquet.schema.LogicalTypeAnnotation._
-import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, FIXED_LEN_BYTE_ARRAY, INT32, INT64, INT96}
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, FIXED_LEN_BYTE_ARRAY, FLOAT, INT32, INT64, INT96}
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
@@ -313,6 +313,21 @@ private[parquet] class ParquetRowConverter(
         override def addInt(value: Int): Unit =
           this.updater.setLong(Integer.toUnsignedLong(value))
       }
+    case LongType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 =>
+      new ParquetPrimitiveConverter(updater) {
+        override def addInt(value: Int): Unit =
+          this.updater.setLong(value)
+      }
+    case DoubleType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 =>
+      new ParquetPrimitiveConverter(updater) {
+        override def addInt(value: Int): Unit =
+          this.updater.setDouble(value)
+      }
+    case DoubleType if parquetType.asPrimitiveType().getPrimitiveTypeName == FLOAT =>
+      new ParquetPrimitiveConverter(updater) {
+        override def addFloat(value: Float): Unit =
+          this.updater.setDouble(value)
+      }
     case BooleanType | IntegerType | LongType | FloatType | DoubleType | BinaryType |
       _: AnsiIntervalType =>
       new ParquetPrimitiveConverter(updater)
@@ -438,6 +453,16 @@ private[parquet] class ParquetRowConverter(
         }
       }
 
+    // Allow upcasting INT32 date to timestampNTZ.
+    case TimestampNTZType if schemaConverter.isTimestampNTZEnabled() &&
+      parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 &&
+      parquetType.getLogicalTypeAnnotation.isInstanceOf[DateLogicalTypeAnnotation] =>
+      new ParquetPrimitiveConverter(updater) {
+        override def addInt(value: Int): Unit = {
+          this.updater.set(DateTimeUtils.daysToMicros(dateRebaseFunc(value), ZoneOffset.UTC))
+        }
+      }
+
     case DateType =>
       new ParquetPrimitiveConverter(updater) {
         override def addInt(value: Int): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 374a8a8078edc..6319d47ffb78b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -1070,17 +1070,6 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
     }
   }
 
-  test("SPARK-35640: int as long should throw schema incompatible error") {
-    val data = (1 to 4).map(i => Tuple1(i))
-    val readSchema = StructType(Seq(StructField("_1", DataTypes.LongType)))
-
-    withParquetFile(data) { path =>
-      val errMsg = intercept[Exception](spark.read.schema(readSchema).parquet(path).collect())
-        .getMessage
-      assert(errMsg.contains("Parquet column cannot be converted in file"))
-    }
-  }
-
   test("write metadata") {
     val hadoopConf = spark.sessionState.newHadoopConf()
     withTempPath { file =>
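The `ParquetRowConverter` cases above are the parquet-mr (non-vectorized) counterpart of the new vectorized updaters. A quick sketch of forcing that code path (assuming a SparkSession `spark` and the file written earlier, or any Parquet file with an INT32 column; the config key is the standard Spark SQL one):

```scala
// Temporarily disable the vectorized reader so the row converter handles the upcast.
spark.conf.set("spark.sql.parquet.enableVectorizedReader", "false")
spark.read.schema("a DOUBLE").parquet("/tmp/widening_demo").show()
spark.conf.set("spark.sql.parquet.enableVectorizedReader", "true")
```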
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 2ba01eea51e20..73a9222c73386 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -1110,19 +1110,13 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
   test("row group skipping doesn't overflow when reading into larger type") {
     withTempPath { path =>
       Seq(0).toDF("a").write.parquet(path.toString)
-      // The vectorized and non-vectorized readers will produce different exceptions, we don't need
-      // to test both as this covers row group skipping.
-      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
-        // Reading integer 'a' as a long isn't supported. Check that an exception is raised instead
-        // of incorrectly skipping the single row group and producing incorrect results.
-        val exception = intercept[SparkException] {
+      withAllParquetReaders {
+        val result =
           spark.read
             .schema("a LONG")
             .parquet(path.toString)
             .where(s"a < ${Long.MaxValue}")
-            .collect()
-        }
-        assert(exception.getCause.getCause.isInstanceOf[SchemaColumnConvertNotSupportedException])
+        checkAnswer(result, Row(0))
       }
     }
   }
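The rewritten test above also documents the row-group skipping angle: min/max statistics from an INT32 row group are now compared using the requested LONG type, so a predicate near `Long.MaxValue` must neither overflow nor incorrectly skip the group. A standalone sketch of the same scenario (assumes a SparkSession `spark` and a scratch path):

```scala
import spark.implicits._

Seq(0).toDF("a").write.mode("overwrite").parquet("/tmp/skip_demo")
val n = spark.read.schema("a LONG").parquet("/tmp/skip_demo")
  .where(s"a < ${Long.MaxValue}")
  .count()
assert(n == 1)
```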
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
new file mode 100644
index 0000000000000..72580f7078e23
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala
@@ -0,0 +1,244 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.datasources.parquet
+
+import java.io.File
+
+import org.apache.hadoop.fs.Path
+import org.apache.parquet.format.converter.ParquetMetadataConverter
+import org.apache.parquet.hadoop.{ParquetFileReader, ParquetOutputFormat}
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
+import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+class ParquetTypeWideningSuite
+    extends QueryTest
+    with ParquetTest
+    with SharedSparkSession
+    with AdaptiveSparkPlanHelper {
+
+  import testImplicits._
+
+  /**
+   * Write a Parquet file with the given values stored using type `fromType` and read it back
+   * using type `toType` with each Parquet reader. If `expectError` returns true, check that an
+   * error is thrown during the read. Otherwise check that the data read matches the data written.
+   */
+  private def checkAllParquetReaders(
+      values: Seq[String],
+      fromType: DataType,
+      toType: DataType,
+      expectError: => Boolean): Unit = {
+    val timestampRebaseModes = toType match {
+      case _: TimestampNTZType | _: DateType =>
+        Seq(LegacyBehaviorPolicy.CORRECTED, LegacyBehaviorPolicy.LEGACY)
+      case _ =>
+        Seq(LegacyBehaviorPolicy.CORRECTED)
+    }
+    for {
+      dictionaryEnabled <- Seq(true, false)
+      timestampRebaseMode <- timestampRebaseModes
+    }
+    withClue(
+      s"with dictionary encoding '$dictionaryEnabled' with timestamp rebase mode " +
+        s"'$timestampRebaseMode'") {
+      withAllParquetWriters {
+        withTempDir { dir =>
+          val expected =
+            writeParquetFiles(dir, values, fromType, dictionaryEnabled, timestampRebaseMode)
+          withAllParquetReaders {
+            if (expectError) {
+              val exception = intercept[SparkException] {
+                readParquetFiles(dir, toType).collect()
+              }
+              assert(
+                exception.getCause.getCause
+                  .isInstanceOf[SchemaColumnConvertNotSupportedException] ||
+                exception.getCause.getCause
+                  .isInstanceOf[org.apache.parquet.io.ParquetDecodingException] ||
+                exception.getCause.getMessage.contains(
+                  "Unable to create Parquet converter for data type"))
+            } else {
+              checkAnswer(readParquetFiles(dir, toType), expected.select($"a".cast(toType)))
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Reads all parquet files in the given directory using the given type.
+   */
+  private def readParquetFiles(dir: File, dataType: DataType): DataFrame = {
+    spark.read.schema(s"a ${dataType.sql}").parquet(dir.getAbsolutePath)
+  }
+
+  /**
+   * Writes values to a parquet file in the given directory using the given type and returns a
+   * DataFrame corresponding to the data written. If dictionaryEnabled is true, the columns will
+   * be dictionary encoded. Each provided value is repeated 10 times to allow dictionary encoding
+   * to be used. timestampRebaseMode can be either "CORRECTED" or "LEGACY", see
+   * [[SQLConf.PARQUET_REBASE_MODE_IN_WRITE]]
+   */
+  private def writeParquetFiles(
+      dir: File,
+      values: Seq[String],
+      dataType: DataType,
+      dictionaryEnabled: Boolean,
+      timestampRebaseMode: LegacyBehaviorPolicy.Value = LegacyBehaviorPolicy.CORRECTED)
+    : DataFrame = {
+    val repeatedValues = List.fill(if (dictionaryEnabled) 10 else 1)(values).flatten
+    val df = repeatedValues.toDF("a").select(col("a").cast(dataType))
+    withSQLConf(
+      ParquetOutputFormat.ENABLE_DICTIONARY -> dictionaryEnabled.toString,
+      SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key -> timestampRebaseMode.toString) {
+      df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+    }
+
+    // Decimals stored as byte arrays (precision > 18) are not dictionary encoded.
+    if (dictionaryEnabled && !DecimalType.isByteArrayDecimalType(dataType)) {
+      assertAllParquetFilesDictionaryEncoded(dir)
+    }
+    df
+  }
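For readability, this is how the test loops further down invoke the helper above; a single call expands into every combination of dictionary encoding, rebase mode, Parquet writer version, and reader (vectorized and parquet-mr). Illustrative call only, with values mirroring one of the widening cases below:

```scala
checkAllParquetReaders(
  values = Seq("1", "2", Int.MinValue.toString),
  fromType = IntegerType,
  toType = LongType,
  expectError = false)
```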
+
+  /**
+   * Asserts that all parquet files in the given directory have all their columns dictionary
+   * encoded.
+   */
+  private def assertAllParquetFilesDictionaryEncoded(dir: File): Unit = {
+    dir.listFiles(_.getName.endsWith(".parquet")).foreach { file =>
+      val parquetMetadata = ParquetFileReader.readFooter(
+        spark.sessionState.newHadoopConf(),
+        new Path(dir.toString, file.getName),
+        ParquetMetadataConverter.NO_FILTER)
+      parquetMetadata.getBlocks.forEach { block =>
+        block.getColumns.forEach { col =>
+          assert(
+            col.hasDictionaryPage,
+            "This test covers dictionary encoding but column " +
+              s"'${col.getPath.toDotString}' in the test data is not dictionary encoded.")
+        }
+      }
+    }
+  }
+
+  for {
+    (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
+      (Seq("1", "2", Short.MinValue.toString), ShortType, IntegerType),
+      // Int->Short isn't a widening conversion but Parquet stores both as INT32 so it just works.
+      (Seq("1", "2", Short.MinValue.toString), IntegerType, ShortType),
+      (Seq("1", "2", Int.MinValue.toString), IntegerType, LongType),
+      (Seq("1", "2", Short.MinValue.toString), ShortType, DoubleType),
+      (Seq("1", "2", Int.MinValue.toString), IntegerType, DoubleType),
+      (Seq("1.23", "10.34"), FloatType, DoubleType),
+      (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampNTZType)
+    )
+  }
+  test(s"parquet widening conversion $fromType -> $toType") {
+    checkAllParquetReaders(values, fromType, toType, expectError = false)
+  }
+
+  for {
+    (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
+      (Seq("1", "2", Int.MinValue.toString), LongType, IntegerType),
+      (Seq("1.23", "10.34"), DoubleType, FloatType),
+      (Seq("1.23", "10.34"), FloatType, LongType),
+      (Seq("1.23", "10.34"), LongType, DateType),
+      (Seq("1.23", "10.34"), IntegerType, TimestampType),
+      (Seq("1.23", "10.34"), IntegerType, TimestampNTZType),
+      (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampType)
+    )
+  }
+  test(s"unsupported parquet conversion $fromType -> $toType") {
+    checkAllParquetReaders(values, fromType, toType, expectError = true)
+  }
+
+  for {
+    (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
+      (Seq("2020-01-01", "2020-01-02", "1312-02-27"), TimestampType, DateType),
+      (Seq("2020-01-01", "2020-01-02", "1312-02-27"), TimestampNTZType, DateType))
+    outputTimestampType <- ParquetOutputTimestampType.values
+  }
+  test(s"unsupported parquet timestamp conversion $fromType ($outputTimestampType) -> $toType") {
+    withSQLConf(
+      SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> outputTimestampType.toString,
+      SQLConf.PARQUET_INT96_REBASE_MODE_IN_WRITE.key -> LegacyBehaviorPolicy.CORRECTED.toString
+    ) {
+      checkAllParquetReaders(values, fromType, toType, expectError = true)
+    }
+  }
+
+  for {
+    (fromPrecision, toPrecision) <-
+      // Test widening and narrowing precision between the same and different decimal physical
+      // parquet types:
+      // - INT32: precisions 5, 7
+      // - INT64: precisions 10, 12
+      // - FIXED_LEN_BYTE_ARRAY: precisions 20, 22
+      Seq(5 -> 7, 5 -> 10, 5 -> 20, 10 -> 12, 10 -> 20, 20 -> 22) ++
+        Seq(7 -> 5, 10 -> 5, 20 -> 5, 12 -> 10, 20 -> 10, 22 -> 20)
+  }
+  test(
+    s"parquet decimal precision change Decimal($fromPrecision, 2) -> Decimal($toPrecision, 2)") {
+    checkAllParquetReaders(
+      values = Seq("1.23", "10.34"),
+      fromType = DecimalType(fromPrecision, 2),
+      toType = DecimalType(toPrecision, 2),
+      expectError = fromPrecision > toPrecision &&
+        // parquet-mr allows reading decimals into a smaller precision decimal type without
+        // checking for overflows. See test below.
+        spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
+  }
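The precision pairs above are chosen to cross the boundaries of Parquet's physical decimal encodings. Spark exposes the same thresholds through `DecimalType` helpers, which the reader-side checks in this patch rely on; a small sketch:

```scala
import org.apache.spark.sql.types.DecimalType

assert(DecimalType.is32BitDecimalType(DecimalType(5, 2)))      // precision <= 9  -> INT32
assert(DecimalType.is64BitDecimalType(DecimalType(12, 2)))     // precision <= 18 -> INT64
assert(DecimalType.isByteArrayDecimalType(DecimalType(22, 2))) // larger -> FIXED_LEN_BYTE_ARRAY
```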
+
+  test("parquet decimal type change Decimal(5, 2) -> Decimal(3, 2) overflows with parquet-mr") {
+    withTempDir { dir =>
+      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+        writeParquetFiles(
+          dir,
+          values = Seq("123.45", "999.99"),
+          DecimalType(5, 2),
+          dictionaryEnabled = false)
+        checkAnswer(readParquetFiles(dir, DecimalType(3, 2)), Row(null) :: Row(null) :: Nil)
+      }
+    }
+  }
+
+  test("parquet conversion IntegerType -> ShortType overflows") {
+    withTempDir { dir =>
+      withAllParquetReaders {
+        // Int & Short are both stored as INT32 in Parquet, but a value just above Short.MaxValue
+        // wraps around to Short.MinValue when it is read back as a Short.
+        val overflowValue = Short.MaxValue.toInt + 1
+        writeParquetFiles(
+          dir,
+          Seq(overflowValue.toString),
+          IntegerType,
+          dictionaryEnabled = false)
+        checkAnswer(readParquetFiles(dir, ShortType), Row(Short.MinValue))
+      }
+    }
+  }
+}
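As a closing note on the last test: the overflow behaviour it asserts is ordinary two's-complement truncation rather than an error. A one-line sketch of the same arithmetic outside Spark:

```scala
// 32768 does not fit in a Short and wraps around to Short.MinValue (-32768).
assert((Short.MaxValue.toInt + 1).toShort == Short.MinValue)
```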