
Commit ec1a8bc

Add widening type promotion from integers to decimals in Parquet vectorized reader
1 parent c4af64e · commit ec1a8bc

4 files changed, +144 -27 lines

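
Taken together, the changes below let the vectorized Parquet reader widen integer columns into sufficiently large decimals at read time. A minimal usage sketch of what this enables, assuming a local SparkSession; the path and column name are illustrative, not from the commit:

    // Scala sketch: read a Parquet INT32 column back as a wider decimal.
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    Seq(1, Int.MaxValue).toDF("a").write.mode("overwrite").parquet("/tmp/ints")

    // DECIMAL(10, 0) can represent every 32-bit integer, so this read is now
    // accepted by the vectorized reader instead of failing with a schema mismatch.
    spark.read.schema("a DECIMAL(10, 0)").parquet("/tmp/ints").show()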

sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java

Lines changed: 29 additions & 4 deletions
@@ -1406,7 +1406,11 @@ private static class IntegerToDecimalUpdater extends DecimalUpdater {
       super(sparkType);
       LogicalTypeAnnotation typeAnnotation =
         descriptor.getPrimitiveType().getLogicalTypeAnnotation();
-      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
+        this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      } else {
+        this.parquetScale = 0;
+      }
     }

     @Override
@@ -1435,14 +1439,18 @@ public void decodeSingleDictionaryId(
     }
   }

-  private static class LongToDecimalUpdater extends DecimalUpdater {
+  private static class LongToDecimalUpdater extends DecimalUpdater {
     private final int parquetScale;

-    LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
+    LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
       super(sparkType);
       LogicalTypeAnnotation typeAnnotation =
         descriptor.getPrimitiveType().getLogicalTypeAnnotation();
-      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
+        this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      } else {
+        this.parquetScale = 0;
+      }
     }

     @Override
@@ -1651,6 +1659,20 @@ private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataTyp
       int scaleIncrease = requestedType.scale() - parquetType.getScale();
       int precisionIncrease = requestedType.precision() - parquetType.getPrecision();
       return scaleIncrease >= 0 && precisionIncrease >= scaleIncrease;
+    } else if (typeAnnotation == null || typeAnnotation instanceof IntLogicalTypeAnnotation) {
+      // Allow reading integers (which may be un-annotated) as decimal as long as the requested
+      // decimal type is large enough to represent all possible values.
+      PrimitiveType.PrimitiveTypeName typeName =
+        descriptor.getPrimitiveType().getPrimitiveTypeName();
+      int integerPrecision = requestedType.precision() - requestedType.scale();
+      switch (typeName) {
+        case INT32:
+          return integerPrecision >= DecimalType$.MODULE$.IntDecimal().precision();
+        case INT64:
+          return integerPrecision >= DecimalType$.MODULE$.LongDecimal().precision();
+        default:
+          return false;
+      }
     }
     return false;
   }
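
The new branch in isDecimalTypeMatched compares the integer digits of the requested decimal (precision minus scale) against the precision needed for the full INT32 or INT64 range: Spark's DecimalType.IntDecimal is DECIMAL(10, 0) and DecimalType.LongDecimal is DECIMAL(20, 0). A standalone Scala sketch of the same predicate; the helper name is hypothetical:

    // Hypothetical helper mirroring the check in isDecimalTypeMatched above.
    def integersFitInDecimal(parquetTypeName: String, precision: Int, scale: Int): Boolean = {
      val integerDigits = precision - scale // digits left of the decimal point
      parquetTypeName match {
        case "INT32" => integerDigits >= 10 // DecimalType.IntDecimal.precision
        case "INT64" => integerDigits >= 20 // DecimalType.LongDecimal.precision
        case _       => false
      }
    }

    assert(integersFitInDecimal("INT32", 10, 0))  // DECIMAL(10, 0) holds any int
    assert(integersFitInDecimal("INT32", 11, 1))  // still 10 integer digits
    assert(!integersFitInDecimal("INT32", 9, 0))  // Int.MaxValue would overflow
    assert(!integersFitInDecimal("INT64", 20, 1)) // only 19 integer digits left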
@@ -1661,6 +1683,9 @@ private static boolean isSameDecimalScale(ColumnDescriptor descriptor, DataType
     if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
       DecimalLogicalTypeAnnotation decimalType = (DecimalLogicalTypeAnnotation) typeAnnotation;
       return decimalType.getScale() == d.scale();
+    } else if (typeAnnotation == null || typeAnnotation instanceof IntLogicalTypeAnnotation) {
+      // Consider integers (which may be un-annotated) as having scale 0.
+      return d.scale() == 0;
     }
     return false;
   }
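
Treating integers as scale 0 means a requested decimal with a non-zero scale fails the same-scale check, steering the reader to an updater that rescales each value. A small self-contained sketch of that promotion, using a hypothetical helper rather than the commit's code:

    import java.math.{BigDecimal => JBigDecimal, BigInteger}

    // Promote an integer to a decimal of the requested scale by scaling the
    // unscaled representation: v becomes v * 10^scale stored at that scale.
    def intAsDecimal(v: Long, targetScale: Int): JBigDecimal =
      new JBigDecimal(BigInteger.valueOf(v).multiply(BigInteger.TEN.pow(targetScale)), targetScale)

    assert(intAsDecimal(7, 0).toPlainString == "7")
    assert(intAsDecimal(7, 1).toPlainString == "7.0") // e.g. INT32 value 7 read as DECIMAL(11, 1)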

sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java

Lines changed: 3 additions & 4 deletions
@@ -151,18 +151,17 @@ private boolean isLazyDecodingSupported(
     // rebasing.
     switch (typeName) {
       case INT32: {
-        boolean isDate = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation;
-        boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
+        boolean isDecimal = sparkType instanceof DecimalType;
         boolean needsUpcast = sparkType == LongType || sparkType == DoubleType ||
-          (isDate && sparkType == TimestampNTZType) ||
+          sparkType == TimestampNTZType ||
           (isDecimal && !DecimalType.is32BitDecimalType(sparkType));
         boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation &&
           !"CORRECTED".equals(datetimeRebaseMode);
         isSupported = !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType);
         break;
       }
       case INT64: {
-        boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
+        boolean isDecimal = sparkType instanceof DecimalType;
         boolean needsUpcast = (isDecimal && !DecimalType.is64BitDecimalType(sparkType)) ||
           updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
         boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) &&
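
The lazy-decoding check now keys off the requested Spark type rather than the Parquet logical type annotation: an integer column being read as a decimal carries no DecimalLogicalTypeAnnotation, so the annotation-based check would miss the required upcast. A toy Scala model of the difference; all names here are illustrative:

    case class ReadPlan(parquetAnnotation: Option[String], sparkTypeIsDecimal: Boolean)

    def oldIsDecimal(p: ReadPlan): Boolean = p.parquetAnnotation.contains("DECIMAL")
    def newIsDecimal(p: ReadPlan): Boolean = p.sparkTypeIsDecimal

    // An un-annotated INT32 column requested as DECIMAL(18, 0):
    val plan = ReadPlan(parquetAnnotation = None, sparkTypeIsDecimal = true)
    assert(!oldIsDecimal(plan)) // old check: the needed upcast went undetected
    assert(newIsDecimal(plan))  // new check: upcast detected, lazy decoding disabled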

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala

Lines changed: 6 additions & 2 deletions
@@ -1037,8 +1037,10 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS

     withAllParquetReaders {
       // We can read the decimal parquet field with a larger precision, if scale is the same.
-      val schema = "a DECIMAL(9, 1), b DECIMAL(18, 2), c DECIMAL(38, 2)"
-      checkAnswer(readParquet(schema, path), df)
+      val schema1 = "a DECIMAL(9, 1), b DECIMAL(18, 2), c DECIMAL(38, 2)"
+      checkAnswer(readParquet(schema1, path), df)
+      val schema2 = "a DECIMAL(18, 1), b DECIMAL(38, 2), c DECIMAL(38, 2)"
+      checkAnswer(readParquet(schema2, path), df)
     }

     withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
@@ -1067,10 +1069,12 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS

     withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
       checkAnswer(readParquet("a DECIMAL(3, 2)", path), sql("SELECT 1.00"))
+      checkAnswer(readParquet("a DECIMAL(11, 2)", path), sql("SELECT 1.00"))
       checkAnswer(readParquet("b DECIMAL(3, 2)", path), Row(null))
       checkAnswer(readParquet("b DECIMAL(11, 1)", path), sql("SELECT 123456.0"))
       checkAnswer(readParquet("c DECIMAL(11, 1)", path), Row(null))
       checkAnswer(readParquet("c DECIMAL(13, 0)", path), df.select("c"))
+      checkAnswer(readParquet("c DECIMAL(22, 0)", path), df.select("c"))
       val e = intercept[SparkException] {
         readParquet("d DECIMAL(3, 2)", path).collect()
       }.getCause
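
Outside the test suite, the widened reads exercised above correspond to this end-user pattern, assuming an active SparkSession named spark; path and data are illustrative:

    import spark.implicits._

    Seq("123456.78").toDF("b")
      .selectExpr("CAST(b AS DECIMAL(18, 2)) AS b")
      .write.mode("overwrite").parquet("/tmp/decimals")

    // Widening precision at read time while keeping the scale is accepted by
    // both the vectorized and the parquet-mr readers.
    spark.read.schema("b DECIMAL(38, 2)").parquet("/tmp/decimals").show()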

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypeWideningSuite.scala

Lines changed: 106 additions & 17 deletions
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet

 import java.io.File

+import org.apache.parquet.column.{Encoding, ParquetProperties}
 import org.apache.hadoop.fs.Path
 import org.apache.parquet.format.converter.ParquetMetadataConverter
 import org.apache.parquet.hadoop.{ParquetFileReader, ParquetOutputFormat}
@@ -31,6 +32,7 @@ import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.DecimalType.{ByteDecimal, IntDecimal, LongDecimal, ShortDecimal}

 class ParquetTypeWideningSuite
   extends QueryTest
@@ -121,6 +123,19 @@ class ParquetTypeWideningSuite
     if (dictionaryEnabled && !DecimalType.isByteArrayDecimalType(dataType)) {
       assertAllParquetFilesDictionaryEncoded(dir)
     }
+
+    // Check which encoding was used when writing Parquet V2 files.
+    val isParquetV2 = spark.conf.getOption(ParquetOutputFormat.WRITER_VERSION)
+      .contains(ParquetProperties.WriterVersion.PARQUET_2_0.toString)
+    if (isParquetV2) {
+      if (dictionaryEnabled) {
+        assertParquetV2Encoding(dir, Encoding.PLAIN)
+      } else if (DecimalType.is64BitDecimalType(dataType)) {
+        assertParquetV2Encoding(dir, Encoding.DELTA_BINARY_PACKED)
+      } else if (DecimalType.isByteArrayDecimalType(dataType)) {
+        assertParquetV2Encoding(dir, Encoding.DELTA_BYTE_ARRAY)
+      }
+    }
     df
   }

@@ -145,6 +160,27 @@ class ParquetTypeWideningSuite
     }
   }

+  /**
+   * Asserts that all parquet files in the given directory have all their columns encoded with the
+   * given encoding.
+   */
+  private def assertParquetV2Encoding(dir: File, expected_encoding: Encoding): Unit = {
+    dir.listFiles(_.getName.endsWith(".parquet")).foreach { file =>
+      val parquetMetadata = ParquetFileReader.readFooter(
+        spark.sessionState.newHadoopConf(),
+        new Path(dir.toString, file.getName),
+        ParquetMetadataConverter.NO_FILTER)
+      parquetMetadata.getBlocks.forEach { block =>
+        block.getColumns.forEach { col =>
+          assert(
+            col.getEncodings.contains(expected_encoding),
+            s"Expected column '${col.getPath.toDotString}' to use encoding $expected_encoding " +
+            s"but found ${col.getEncodings}.")
+        }
+      }
+    }
+  }
+
   for {
     (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
       (Seq("1", "2", Short.MinValue.toString), ShortType, IntegerType),
@@ -157,24 +193,77 @@ class ParquetTypeWideningSuite
       (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampNTZType)
     )
   }
-  test(s"parquet widening conversion $fromType -> $toType") {
-    checkAllParquetReaders(values, fromType, toType, expectError = false)
-  }
+  test(s"parquet widening conversion $fromType -> $toType") {
+    checkAllParquetReaders(values, fromType, toType, expectError = false)
+  }
+
+  for {
+    (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
+      (Seq("1", Byte.MaxValue.toString), ByteType, IntDecimal),
+      (Seq("1", Byte.MaxValue.toString), ByteType, LongDecimal),
+      (Seq("1", Short.MaxValue.toString), ShortType, IntDecimal),
+      (Seq("1", Short.MaxValue.toString), ShortType, LongDecimal),
+      (Seq("1", Short.MaxValue.toString), ShortType, DecimalType(DecimalType.MAX_PRECISION, 0)),
+      (Seq("1", Int.MaxValue.toString), IntegerType, IntDecimal),
+      (Seq("1", Int.MaxValue.toString), IntegerType, LongDecimal),
+      (Seq("1", Int.MaxValue.toString), IntegerType, DecimalType(DecimalType.MAX_PRECISION, 0)),
+      (Seq("1", Long.MaxValue.toString), LongType, LongDecimal),
+      (Seq("1", Long.MaxValue.toString), LongType, DecimalType(DecimalType.MAX_PRECISION, 0)),
+      (Seq("1", Byte.MaxValue.toString), ByteType, DecimalType(IntDecimal.precision + 1, 1)),
+      (Seq("1", Short.MaxValue.toString), ShortType, DecimalType(IntDecimal.precision + 1, 1)),
+      (Seq("1", Int.MaxValue.toString), IntegerType, DecimalType(IntDecimal.precision + 1, 1)),
+      (Seq("1", Long.MaxValue.toString), LongType, DecimalType(LongDecimal.precision + 1, 1))
+    )
+  }
+  test(s"parquet widening conversion $fromType -> $toType") {
+    checkAllParquetReaders(values, fromType, toType, expectError = false)
+  }

   for {
     (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
       (Seq("1", "2", Int.MinValue.toString), LongType, IntegerType),
       (Seq("1.23", "10.34"), DoubleType, FloatType),
       (Seq("1.23", "10.34"), FloatType, LongType),
+      (Seq("1", "10"), LongType, DoubleType),
       (Seq("1", "10"), LongType, DateType),
       (Seq("1", "10"), IntegerType, TimestampType),
       (Seq("1", "10"), IntegerType, TimestampNTZType),
       (Seq("2020-01-01", "2020-01-02", "1312-02-27"), DateType, TimestampType)
     )
   }
-  test(s"unsupported parquet conversion $fromType -> $toType") {
-    checkAllParquetReaders(values, fromType, toType, expectError = true)
-  }
+  test(s"unsupported parquet conversion $fromType -> $toType") {
+    checkAllParquetReaders(values, fromType, toType, expectError = true)
+  }
+
+  for {
+    (values: Seq[String], fromType: DataType, toType: DecimalType) <- Seq(
+      // Parquet stores byte, short, int values as INT32, which then requires using a decimal that
+      // can hold at least 4 byte integers.
+      (Seq("1", "2"), ByteType, DecimalType(1, 0)),
+      (Seq("1", "2"), ByteType, ByteDecimal),
+      (Seq("1", "2"), ShortType, ByteDecimal),
+      (Seq("1", "2"), ShortType, ShortDecimal),
+      (Seq("1", "2"), IntegerType, ShortDecimal),
+      (Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision + 1, 1)),
+      (Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision + 1, 1)),
+      (Seq("1", "2"), LongType, IntDecimal),
+      (Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision - 1, 0)),
+      (Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision - 1, 0)),
+      (Seq("1", "2"), IntegerType, DecimalType(IntDecimal.precision - 1, 0)),
+      (Seq("1", "2"), LongType, DecimalType(LongDecimal.precision - 1, 0)),
+      (Seq("1", "2"), ByteType, DecimalType(ByteDecimal.precision, 1)),
+      (Seq("1", "2"), ShortType, DecimalType(ShortDecimal.precision, 1)),
+      (Seq("1", "2"), IntegerType, DecimalType(IntDecimal.precision, 1)),
+      (Seq("1", "2"), LongType, DecimalType(LongDecimal.precision, 1))
+    )
+  }
+  test(s"unsupported parquet conversion $fromType -> $toType") {
+    checkAllParquetReaders(values, fromType, toType,
+      expectError =
+        // parquet-mr allows reading decimals into a smaller precision decimal type without
+        // checking for overflows. See test below checking for the overflow case in parquet-mr.
+        spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
+  }

   for {
     (values: Seq[String], fromType: DataType, toType: DataType) <- Seq(
@@ -201,17 +290,17 @@ class ParquetTypeWideningSuite
       Seq(5 -> 7, 5 -> 10, 5 -> 20, 10 -> 12, 10 -> 20, 20 -> 22) ++
       Seq(7 -> 5, 10 -> 5, 20 -> 5, 12 -> 10, 20 -> 10, 22 -> 20)
   }
-  test(
-    s"parquet decimal precision change Decimal($fromPrecision, 2) -> Decimal($toPrecision, 2)") {
-    checkAllParquetReaders(
-      values = Seq("1.23", "10.34"),
-      fromType = DecimalType(fromPrecision, 2),
-      toType = DecimalType(toPrecision, 2),
-      expectError = fromPrecision > toPrecision &&
-        // parquet-mr allows reading decimals into a smaller precision decimal type without
-        // checking for overflows. See test below checking for the overflow case in parquet-mr.
-        spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
-  }
+  test(
+    s"parquet decimal precision change Decimal($fromPrecision, 2) -> Decimal($toPrecision, 2)") {
+    checkAllParquetReaders(
+      values = Seq("1.23", "10.34"),
+      fromType = DecimalType(fromPrecision, 2),
+      toType = DecimalType(toPrecision, 2),
+      expectError = fromPrecision > toPrecision &&
+        // parquet-mr allows reading decimals into a smaller precision decimal type without
+        // checking for overflows. See test below checking for the overflow case in parquet-mr.
+        spark.conf.get(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean)
+  }

   for {
     ((fromPrecision, fromScale), (toPrecision, toScale)) <-
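
A condensed, hypothetical version of what checkAllParquetReaders exercises for the new cases, combining one supported and one rejected conversion; paths are illustrative and the rejection is specific to the vectorized reader, as noted in the diff:

    // Supported: LongType widened to DECIMAL(20, 0) (LongDecimal), both readers.
    for (vectorized <- Seq("true", "false")) {
      spark.conf.set("spark.sql.parquet.enableVectorizedReader", vectorized)
      Seq(1L, Long.MaxValue).toDF("v").write.mode("overwrite").parquet("/tmp/longs")
      val read = spark.read.schema("v DECIMAL(20, 0)").parquet("/tmp/longs").collect()
      assert(read.map(_.getDecimal(0).longValueExact()).sorted
        .sameElements(Array(1L, Long.MaxValue)))
    }

    // Rejected: ByteType is stored as Parquet INT32, so DECIMAL(3, 0) could
    // overflow; DECIMAL(10, 0) covers the full INT32 range and is accepted.
    Seq(1.toByte).toDF("b").write.mode("overwrite").parquet("/tmp/bytes")
    spark.read.schema("b DECIMAL(10, 0)").parquet("/tmp/bytes").show()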
