Commit 0356ac0

johanl-db authored and dongjoon-hyun committed
[SPARK-40876][SQL] Widening type promotion from integers to decimal in Parquet vectorized reader
### What changes were proposed in this pull request?

This is a follow-up from #44368 and #44513, implementing an additional type promotion from integers to decimals in the Parquet vectorized reader, bringing it to parity with the non-vectorized reader in that regard.

### Why are the changes needed?

This allows reading Parquet files that have different schemas and mix decimals and integers - e.g. reading files containing either `Decimal(15, 2)` or `INT32` as `Decimal(15, 2)` - as long as the requested decimal type is large enough to accommodate the integer values without precision loss.

### Does this PR introduce _any_ user-facing change?

Yes, the following now succeeds when using the vectorized Parquet reader:
```
Seq(20).toDF("a").select($"a".cast(IntegerType)).write.parquet(path)
spark.read.schema("a decimal(12, 0)").parquet(path).collect()
```
It failed before with the vectorized reader and succeeded with the non-vectorized reader.

### How was this patch tested?

- Tests added to `ParquetTypeWideningSuite`
- Updated the relevant `ParquetQuerySuite` test.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #44803 from johanl-db/SPARK-40876-widening-promotion-int-to-decimal.

Authored-by: Johan Lasperas <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
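A slightly fuller sketch of the mixed-schema scenario described above. It assumes a `SparkSession` in scope as `spark` with `spark.implicits._` imported, and a hypothetical directory `dir`; the expected behavior follows the PR description:

```scala
import spark.implicits._

// Two files under the same directory: one with an INT32 column, one with DECIMAL(15, 2).
Seq(20).toDF("a").selectExpr("CAST(a AS INT) AS a").write.parquet(s"$dir/part1")
Seq(BigDecimal("1.50")).toDF("a").selectExpr("CAST(a AS DECIMAL(15, 2)) AS a")
  .write.parquet(s"$dir/part2")

// With this change, the vectorized reader accepts both as DECIMAL(15, 2):
// the integer part has 15 - 2 = 13 digits, enough for any INT32 value (10 digits).
spark.read.schema("a DECIMAL(15, 2)")
  .parquet(s"$dir/part1", s"$dir/part2")
  .collect()
```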
1 parent fa60a7e commit 0356ac0

File tree

4 files changed, +150 -27 lines changed


sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java

Lines changed: 35 additions & 4 deletions
```diff
@@ -1407,7 +1407,11 @@ private static class IntegerToDecimalUpdater extends DecimalUpdater {
       super(sparkType);
       LogicalTypeAnnotation typeAnnotation =
         descriptor.getPrimitiveType().getLogicalTypeAnnotation();
-      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
+        this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      } else {
+        this.parquetScale = 0;
+      }
     }
 
     @Override
@@ -1436,14 +1440,18 @@ public void decodeSingleDictionaryId(
     }
   }
 
- private static class LongToDecimalUpdater extends DecimalUpdater {
+  private static class LongToDecimalUpdater extends DecimalUpdater {
     private final int parquetScale;
 
-   LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
+    LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
       super(sparkType);
       LogicalTypeAnnotation typeAnnotation =
         descriptor.getPrimitiveType().getLogicalTypeAnnotation();
-      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
+        this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      } else {
+        this.parquetScale = 0;
+      }
     }
 
     @Override
@@ -1641,6 +1649,12 @@ private static boolean isDateTypeMatched(ColumnDescriptor descriptor) {
     return typeAnnotation instanceof DateLogicalTypeAnnotation;
   }
 
+  private static boolean isSignedIntAnnotation(LogicalTypeAnnotation typeAnnotation) {
+    if (!(typeAnnotation instanceof IntLogicalTypeAnnotation)) return false;
+    IntLogicalTypeAnnotation intAnnotation = (IntLogicalTypeAnnotation) typeAnnotation;
+    return intAnnotation.isSigned();
+  }
+
   private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataType dt) {
     DecimalType requestedType = (DecimalType) dt;
     LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
@@ -1652,6 +1666,20 @@ private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataTyp
       int scaleIncrease = requestedType.scale() - parquetType.getScale();
       int precisionIncrease = requestedType.precision() - parquetType.getPrecision();
       return scaleIncrease >= 0 && precisionIncrease >= scaleIncrease;
+    } else if (typeAnnotation == null || isSignedIntAnnotation(typeAnnotation)) {
+      // Allow reading signed integers (which may be un-annotated) as decimal as long as the
+      // requested decimal type is large enough to represent all possible values.
+      PrimitiveType.PrimitiveTypeName typeName =
+        descriptor.getPrimitiveType().getPrimitiveTypeName();
+      int integerPrecision = requestedType.precision() - requestedType.scale();
+      switch (typeName) {
+        case INT32:
+          return integerPrecision >= DecimalType$.MODULE$.IntDecimal().precision();
+        case INT64:
+          return integerPrecision >= DecimalType$.MODULE$.LongDecimal().precision();
+        default:
+          return false;
+      }
     }
     return false;
   }
@@ -1662,6 +1690,9 @@ private static boolean isSameDecimalScale(ColumnDescriptor descriptor, DataType
     if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
       DecimalLogicalTypeAnnotation decimalType = (DecimalLogicalTypeAnnotation) typeAnnotation;
       return decimalType.getScale() == d.scale();
+    } else if (typeAnnotation == null || isSignedIntAnnotation(typeAnnotation)) {
+      // Consider signed integers (which may be un-annotated) as having scale 0.
+      return d.scale() == 0;
     }
     return false;
   }
```
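The new `isDecimalTypeMatched` branch is the core of the promotion: an un-annotated or signed-integer column matches a requested decimal only if the decimal's integer part can hold every possible physical value. A minimal sketch of that rule in Scala, using the real `DecimalType.IntDecimal` (`DecimalType(10, 0)`) and `DecimalType.LongDecimal` (`DecimalType(20, 0)`) helpers; `canReadIntegerAsDecimal` itself is a hypothetical name:

```scala
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.types.DecimalType.{IntDecimal, LongDecimal}

// Hypothetical helper mirroring the INT32/INT64 cases in isDecimalTypeMatched:
// the integer part (precision - scale) must cover the physical type's full range.
def canReadIntegerAsDecimal(physicalType: String, requested: DecimalType): Boolean = {
  val integerPrecision = requested.precision - requested.scale
  physicalType match {
    case "INT32" => integerPrecision >= IntDecimal.precision   // 10 digits
    case "INT64" => integerPrecision >= LongDecimal.precision  // 20 digits
    case _ => false
  }
}

// canReadIntegerAsDecimal("INT32", DecimalType(12, 0)) == true
// canReadIntegerAsDecimal("INT32", DecimalType(11, 1)) == true  (11 - 1 = 10)
// canReadIntegerAsDecimal("INT32", DecimalType(9, 0))  == false (Int.MaxValue has 10 digits)
```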

sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java

Lines changed: 3 additions & 4 deletions
```diff
@@ -153,18 +153,17 @@ private boolean isLazyDecodingSupported(
     // rebasing.
     switch (typeName) {
       case INT32: {
-        boolean isDate = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation;
-        boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
+        boolean isDecimal = sparkType instanceof DecimalType;
         boolean needsUpcast = sparkType == LongType || sparkType == DoubleType ||
-          (isDate && sparkType == TimestampNTZType) ||
+          sparkType == TimestampNTZType ||
           (isDecimal && !DecimalType.is32BitDecimalType(sparkType));
         boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation &&
           !"CORRECTED".equals(datetimeRebaseMode);
         isSupported = !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType);
         break;
       }
       case INT64: {
-        boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
+        boolean isDecimal = sparkType instanceof DecimalType;
         boolean needsUpcast = (isDecimal && !DecimalType.is64BitDecimalType(sparkType)) ||
           updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
         boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) &&
```
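Since an INT32 column may now be read as a wide decimal, checking the Parquet annotation is no longer enough; whether upcasting is needed depends on the requested Spark type. A rough sketch of the resulting INT32 rule, using the real `DecimalType.is32BitDecimalType` helper; `int32NeedsUpcast` is a hypothetical name condensing the diff above:

```scala
import org.apache.spark.sql.types._

// Hypothetical condensation of the INT32 case above: lazy dictionary decoding
// stays enabled only when values do not need to be widened while decoding.
def int32NeedsUpcast(sparkType: DataType): Boolean =
  sparkType == LongType || sparkType == DoubleType ||
    sparkType == TimestampNTZType ||
    (sparkType.isInstanceOf[DecimalType] && !DecimalType.is32BitDecimalType(sparkType))

// int32NeedsUpcast(IntegerType)        == false  -> lazy decoding possible
// int32NeedsUpcast(DecimalType(9, 2))  == false  -> still a 32-bit decimal
// int32NeedsUpcast(DecimalType(12, 0)) == true   -> values widen past 32 bits
```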

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala

Lines changed: 6 additions & 2 deletions
```diff
@@ -1037,8 +1037,10 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
 
     withAllParquetReaders {
       // We can read the decimal parquet field with a larger precision, if scale is the same.
-      val schema = "a DECIMAL(9, 1), b DECIMAL(18, 2), c DECIMAL(38, 2)"
-      checkAnswer(readParquet(schema, path), df)
+      val schema1 = "a DECIMAL(9, 1), b DECIMAL(18, 2), c DECIMAL(38, 2)"
+      checkAnswer(readParquet(schema1, path), df)
+      val schema2 = "a DECIMAL(18, 1), b DECIMAL(38, 2), c DECIMAL(38, 2)"
+      checkAnswer(readParquet(schema2, path), df)
     }
 
     withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
@@ -1067,10 +1069,12 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
 
     withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
       checkAnswer(readParquet("a DECIMAL(3, 2)", path), sql("SELECT 1.00"))
+      checkAnswer(readParquet("a DECIMAL(11, 2)", path), sql("SELECT 1.00"))
       checkAnswer(readParquet("b DECIMAL(3, 2)", path), Row(null))
       checkAnswer(readParquet("b DECIMAL(11, 1)", path), sql("SELECT 123456.0"))
       checkAnswer(readParquet("c DECIMAL(11, 1)", path), Row(null))
       checkAnswer(readParquet("c DECIMAL(13, 0)", path), df.select("c"))
+      checkAnswer(readParquet("c DECIMAL(22, 0)", path), df.select("c"))
       val e = intercept[SparkException] {
         readParquet("d DECIMAL(3, 2)", path).collect()
       }.getCause
```
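These assertions exercise the decimal-to-decimal matching rule from `ParquetVectorUpdaterFactory`: precision may grow as long as the scale does not grow faster. A small hedged restatement of that arithmetic (`canWidenDecimal` is a hypothetical name; the two increase checks are copied from the first diff above):

```scala
// Widening from Decimal(fromP, fromS) to Decimal(toP, toS) is accepted when the
// scale does not shrink and the integer part does not shrink either.
def canWidenDecimal(fromP: Int, fromS: Int, toP: Int, toS: Int): Boolean = {
  val scaleIncrease = toS - fromS
  val precisionIncrease = toP - fromP
  scaleIncrease >= 0 && precisionIncrease >= scaleIncrease
}

// Examples matching the requested schemas above:
// canWidenDecimal(9, 1, 18, 1)  == true   ("a DECIMAL(18, 1)")
// canWidenDecimal(18, 2, 38, 2) == true   ("b DECIMAL(38, 2)")
// canWidenDecimal(9, 1, 9, 2)   == false  (scale grows but precision does not)
```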

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala

Lines changed: 106 additions & 17 deletions
```diff
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet
 import java.io.File
 
 import org.apache.hadoop.fs.Path
+import org.apache.parquet.column.{Encoding, ParquetProperties}
 import org.apache.parquet.format.converter.ParquetMetadataConverter
 import org.apache.parquet.hadoop.{ParquetFileReader, ParquetOutputFormat}
 
@@ -31,6 +32,7 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.DecimalType.{ByteDecimal, IntDecimal, LongDecimal, ShortDecimal}
 
 class ParquetTypeWideningSuite
   extends QueryTest
@@ -121,6 +123,19 @@ class ParquetTypeWideningSuite
     if (dictionaryEnabled && !DecimalType.isByteArrayDecimalType(dataType)) {
       assertAllParquetFilesDictionaryEncoded(dir)
     }
+
+    // Check which encoding was used when writing Parquet V2 files.
+    val isParquetV2 = spark.conf.getOption(ParquetOutputFormat.WRITER_VERSION)
+      .contains(ParquetProperties.WriterVersion.PARQUET_2_0.toString)
+    if (isParquetV2) {
+      if (dictionaryEnabled) {
+        assertParquetV2Encoding(dir, Encoding.PLAIN)
+      } else if (DecimalType.is64BitDecimalType(dataType)) {
+        assertParquetV2Encoding(dir, Encoding.DELTA_BINARY_PACKED)
+      } else if (DecimalType.isByteArrayDecimalType(dataType)) {
+        assertParquetV2Encoding(dir, Encoding.DELTA_BYTE_ARRAY)
+      }
+    }
     df
   }
 
@@ -145,6 +160,27 @@ class ParquetTypeWideningSuite
     }
   }
 
+  /**
+   * Asserts that all parquet files in the given directory have all their columns encoded with the
+   * given encoding.
+   */
+  private def assertParquetV2Encoding(dir: File, expected_encoding: Encoding): Unit = {
+    dir.listFiles(_.getName.endsWith(".parquet")).foreach { file =>
+      val parquetMetadata = ParquetFileReader.readFooter(
+        spark.sessionState.newHadoopConf(),
+        new Path(dir.toString, file.getName),
+        ParquetMetadataConverter.NO_FILTER)
+      parquetMetadata.getBlocks.forEach { block =>
+        block.getColumns.forEach { col =>
+          assert(
+            col.getEncodings.contains(expected_encoding),
+            s"Expected column '${col.getPath.toDotString}' to use encoding $expected_encoding " +
+            s"but found ${col.getEncodings}.")
+        }
+      }
+    }
+  }
+
   for {
     (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
       (Seq("1", "2", Short.MinValue.toString), ShortType, IntegerType),
@@ -157,24 +193,77 @@
       (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampNTZType)
     )
   }
-  test(s"parquet widening conversion $fromType -> $toType") {
-    checkAllParquetReaders(values, fromType, toType, expectError = false)
-  }
+    test(s"parquet widening conversion $fromType -> $toType") {
+      checkAllParquetReaders(values, fromType, toType, expectError = false)
+    }
+
+  for {
+    (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
+      (Seq("1", Byte.MaxValue.toString), ByteType, IntDecimal),
+      (Seq("1", Byte.MaxValue.toString), ByteType, LongDecimal),
+      (Seq("1", Short.MaxValue.toString), ShortType, IntDecimal),
+      (Seq("1", Short.MaxValue.toString), ShortType, LongDecimal),
+      (Seq("1", Short.MaxValue.toString), ShortType, DecimalType(DecimalType.MAX_PRECISION, 0)),
+      (Seq("1", Int.MaxValue.toString), IntegerType, IntDecimal),
+      (Seq("1", Int.MaxValue.toString), IntegerType, LongDecimal),
+      (Seq("1", Int.MaxValue.toString), IntegerType, DecimalType(DecimalType.MAX_PRECISION, 0)),
+      (Seq("1", Long.MaxValue.toString), LongType, LongDecimal),
+      (Seq("1", Long.MaxValue.toString), LongType, DecimalType(DecimalType.MAX_PRECISION, 0)),
+      (Seq("1", Byte.MaxValue.toString), ByteType, DecimalType(IntDecimal.precision + 1, 1)),
+      (Seq("1", Short.MaxValue.toString), ShortType, DecimalType(IntDecimal.precision + 1, 1)),
+      (Seq("1", Int.MaxValue.toString), IntegerType, DecimalType(IntDecimal.precision + 1, 1)),
+      (Seq("1", Long.MaxValue.toString), LongType, DecimalType(LongDecimal.precision + 1, 1))
+    )
+  }
+    test(s"parquet widening conversion $fromType -> $toType") {
+      checkAllParquetReaders(values, fromType, toType, expectError = false)
+    }
 
   for {
     (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
       (Seq("1", "2", Int.MinValue.toString), LongType, IntegerType),
       (Seq("1.23", "10.34"), DoubleType, FloatType),
       (Seq("1.23", "10.34"), FloatType, LongType),
+      (Seq("1", "10"), LongType, DoubleType),
       (Seq("1", "10"), LongType, DateType),
       (Seq("1", "10"), IntegerType, TimestampType),
       (Seq("1", "10"), IntegerType, TimestampNTZType),
       (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampType)
     )
   }
-  test(s"unsupported parquet conversion $fromType -> $toType") {
-    checkAllParquetReaders(values, fromType, toType, expectError = true)
-  }
+    test(s"unsupported parquet conversion $fromType -> $toType") {
+      checkAllParquetReaders(values, fromType, toType, expectError = true)
+    }
+
+  for {
+    (values: Seq[String], fromType: DataType, toType: DecimalType) <- Seq(
+      // Parquet stores byte, short, int values as INT32, which then requires using a decimal that
+      // can hold at least 4 byte integers.
+      (Seq("1", "2"), ByteType, DecimalType(1, 0)),
+      (Seq("1", "2"), ByteType, ByteDecimal),
+      (Seq("1", "2"), ShortType, ByteDecimal),
+      (Seq("1", "2"), ShortType, ShortDecimal),
+      (Seq("1", "2"), IntegerType, ShortDecimal),
+      (Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision + 1, 1)),
+      (Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision + 1, 1)),
+      (Seq("1", "2"), LongType, IntDecimal),
+      (Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision - 1, 0)),
+      (Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision - 1, 0)),
+      (Seq("1", "2"), IntegerType, DecimalType(IntDecimal.precision - 1, 0)),
+      (Seq("1", "2"), LongType, DecimalType(LongDecimal.precision - 1, 0)),
+      (Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision, 1)),
+      (Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision, 1)),
+      (Seq("1", "2"), IntegerType, DecimalType(IntDecimal.precision, 1)),
+      (Seq("1", "2"), LongType, DecimalType(LongDecimal.precision, 1))
+    )
+  }
+    test(s"unsupported parquet conversion $fromType -> $toType") {
+      checkAllParquetReaders(values, fromType, toType,
+        expectError =
+          // parquet-mr allows reading decimals into a smaller precision decimal type without
+          // checking for overflows. See test below checking for the overflow case in parquet-mr.
+          spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
+    }
 
   for {
     (values: Seq[String], fromType: DataType, toType: DataType) <-
@@ -201,17 +290,17 @@
       Seq(5 -> 7, 5 -> 10, 5 -> 20, 10 -> 12, 10 -> 20, 20 -> 22) ++
       Seq(7 -> 5, 10 -> 5, 20 -> 5, 12 -> 10, 20 -> 10, 22 -> 20)
   }
-  test(
-    s"parquet decimal precision change Decimal($fromPrecision, 2) -> Decimal($toPrecision, 2)") {
-    checkAllParquetReaders(
-      values = Seq("1.23", "10.34"),
-      fromType = DecimalType(fromPrecision, 2),
-      toType = DecimalType(toPrecision, 2),
-      expectError = fromPrecision > toPrecision &&
-        // parquet-mr allows reading decimals into a smaller precision decimal type without
-        // checking for overflows. See test below checking for the overflow case in parquet-mr.
-        spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
-  }
+    test(
+      s"parquet decimal precision change Decimal($fromPrecision, 2) -> Decimal($toPrecision, 2)") {
+      checkAllParquetReaders(
+        values = Seq("1.23", "10.34"),
+        fromType = DecimalType(fromPrecision, 2),
+        toType = DecimalType(toPrecision, 2),
+        expectError = fromPrecision > toPrecision &&
+          // parquet-mr allows reading decimals into a smaller precision decimal type without
+          // checking for overflows. See test below checking for the overflow case in parquet-mr.
+          spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
+    }
 
   for {
     ((fromPrecision, fromScale), (toPrecision, toScale)) <-
```
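One subtlety encoded in the unsupported-conversion matrix above: Parquet has no 1-byte or 2-byte physical type, so byte and short columns are written as INT32, and the reader consequently requires a decimal wide enough for any 32-bit value, not merely any byte. A hedged end-to-end illustration, assuming a `spark` session and a hypothetical `path`:

```scala
// Bytes land in Parquet as INT32. Even though every Byte value fits in
// Decimal(3, 0), the file schema only promises "a signed 32-bit integer",
// so the reader insists on Decimal(10, 0) or wider.
spark.range(3).selectExpr("CAST(id AS BYTE) AS a").write.parquet(path)

spark.read.schema("a DECIMAL(10, 0)").parquet(path).collect() // succeeds
spark.read.schema("a DECIMAL(3, 0)").parquet(path).collect()  // fails in the vectorized reader
```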
