Commit d439e34

johanl-dbcloud-fan authored and committed
[SPARK-40876][SQL] Widening type promotion for decimals with larger scale in Parquet readers
### What changes were proposed in this pull request?

This is a follow-up from #44368 implementing an additional type promotion, to decimals with larger precision and scale. As long as the precision increases by at least as much as the scale, the decimal values can be promoted without loss of precision: Decimal(6, 2) -> Decimal(8, 4), e.g. 1234.56 -> 1234.5600.

The non-vectorized reader (parquet-mr) is already able to do this type promotion; this PR implements it for the vectorized reader.

### Why are the changes needed?

This allows reading multiple Parquet files that contain decimals with different precision and scale.

### Does this PR introduce _any_ user-facing change?

Yes, the following now succeeds when using the vectorized Parquet reader:
```
Seq(20).toDF("a").select($"a".cast(DecimalType(4, 2))).write.parquet(path)
spark.read.schema("a decimal(6, 4)").parquet(path).collect()
```
It failed before with the vectorized reader and succeeded with the non-vectorized reader.

### How was this patch tested?

- Tests added to `ParquetWideningTypeSuite` to cover decimal promotion between decimals with different physical types: INT32, INT64, FIXED_LEN_BYTE_ARRAY.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #44513 from johanl-db/SPARK-40876-parquet-type-promotion-decimal-scale.

Authored-by: Johan Lasperas <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
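Editor's note: a minimal standalone sketch (not part of this patch) of why the promotion is lossless. Rescaling to a larger scale only appends trailing zeros, so `RoundingMode.UNNECESSARY` never throws as long as the target scale is at least the stored scale; the class name is illustrative.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;

// Illustrative only: widening Decimal(6, 2) to Decimal(8, 4) appends zeros,
// so rescaling with RoundingMode.UNNECESSARY cannot lose digits or throw.
public class DecimalWideningSketch {
  public static void main(String[] args) {
    BigDecimal v = new BigDecimal("1234.56");                      // fits Decimal(6, 2)
    BigDecimal widened = v.setScale(4, RoundingMode.UNNECESSARY);  // fits Decimal(8, 4)
    System.out.println(widened);                                   // prints 1234.5600
  }
}
```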
1 parent f54ecd6 commit d439e34

File tree

4 files changed: +281 -14 lines changed


sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java

Lines changed: 218 additions & 6 deletions
@@ -34,7 +34,9 @@
 import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
 import org.apache.spark.sql.types.*;
 
+import java.math.BigDecimal;
 import java.math.BigInteger;
+import java.math.RoundingMode;
 import java.time.ZoneId;
 import java.time.ZoneOffset;
 import java.util.Arrays;
@@ -108,6 +110,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
         }
       } else if (sparkType instanceof YearMonthIntervalType) {
         return new IntegerUpdater();
+      } else if (canReadAsDecimal(descriptor, sparkType)) {
+        return new IntegerToDecimalUpdater(descriptor, (DecimalType) sparkType);
       }
     }
     case INT64 -> {
@@ -153,6 +157,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
        return new LongAsMicrosUpdater();
       } else if (sparkType instanceof DayTimeIntervalType) {
        return new LongUpdater();
+      } else if (canReadAsDecimal(descriptor, sparkType)) {
+        return new LongToDecimalUpdater(descriptor, (DecimalType) sparkType);
       }
     }
     case FLOAT -> {
@@ -194,6 +200,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
       if (sparkType == DataTypes.StringType || sparkType == DataTypes.BinaryType ||
         canReadAsBinaryDecimal(descriptor, sparkType)) {
         return new BinaryUpdater();
+      } else if (canReadAsDecimal(descriptor, sparkType)) {
+        return new BinaryToDecimalUpdater(descriptor, (DecimalType) sparkType);
       }
     }
     case FIXED_LEN_BYTE_ARRAY -> {
@@ -206,6 +214,8 @@ public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType spa
         return new FixedLenByteArrayUpdater(arrayLen);
       } else if (sparkType == DataTypes.BinaryType) {
         return new FixedLenByteArrayUpdater(arrayLen);
+      } else if (canReadAsDecimal(descriptor, sparkType)) {
+        return new FixedLenByteArrayToDecimalUpdater(descriptor, (DecimalType) sparkType);
       }
     }
     default -> {}
@@ -1358,6 +1368,188 @@ public void decodeSingleDictionaryId(
     }
   }
 
+  private abstract static class DecimalUpdater implements ParquetVectorUpdater {
+
+    private final DecimalType sparkType;
+
+    DecimalUpdater(DecimalType sparkType) {
+      this.sparkType = sparkType;
+    }
+
+    @Override
+    public void readValues(
+        int total,
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      for (int i = 0; i < total; i++) {
+        readValue(offset + i, values, valuesReader);
+      }
+    }
+
+    protected void writeDecimal(int offset, WritableColumnVector values, BigDecimal decimal) {
+      BigDecimal scaledDecimal = decimal.setScale(sparkType.scale(), RoundingMode.UNNECESSARY);
+      if (DecimalType.is32BitDecimalType(sparkType)) {
+        values.putInt(offset, scaledDecimal.unscaledValue().intValue());
+      } else if (DecimalType.is64BitDecimalType(sparkType)) {
+        values.putLong(offset, scaledDecimal.unscaledValue().longValue());
+      } else {
+        values.putByteArray(offset, scaledDecimal.unscaledValue().toByteArray());
+      }
+    }
+  }
+
+  private static class IntegerToDecimalUpdater extends DecimalUpdater {
+    private final int parquetScale;
+
+    IntegerToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
+      super(sparkType);
+      LogicalTypeAnnotation typeAnnotation =
+        descriptor.getPrimitiveType().getLogicalTypeAnnotation();
+      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipIntegers(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      BigDecimal decimal = BigDecimal.valueOf(valuesReader.readInteger(), parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      BigDecimal decimal =
+        BigDecimal.valueOf(dictionary.decodeToInt(dictionaryIds.getDictId(offset)), parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+  }
+
+  private static class LongToDecimalUpdater extends DecimalUpdater {
+    private final int parquetScale;
+
+    LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
+      super(sparkType);
+      LogicalTypeAnnotation typeAnnotation =
+        descriptor.getPrimitiveType().getLogicalTypeAnnotation();
+      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipLongs(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      BigDecimal decimal = BigDecimal.valueOf(valuesReader.readLong(), parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      BigDecimal decimal =
+        BigDecimal.valueOf(dictionary.decodeToLong(dictionaryIds.getDictId(offset)), parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+  }
+
+  private static class BinaryToDecimalUpdater extends DecimalUpdater {
+    private final int parquetScale;
+
+    BinaryToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
+      super(sparkType);
+      LogicalTypeAnnotation typeAnnotation =
+        descriptor.getPrimitiveType().getLogicalTypeAnnotation();
+      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipBinary(total);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      valuesReader.readBinary(1, values, offset);
+      BigInteger value = new BigInteger(values.getBinary(offset));
+      BigDecimal decimal = new BigDecimal(value, parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      BigInteger value =
+        new BigInteger(dictionary.decodeToBinary(dictionaryIds.getDictId(offset)).getBytes());
+      BigDecimal decimal = new BigDecimal(value, parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+  }
+
+  private static class FixedLenByteArrayToDecimalUpdater extends DecimalUpdater {
+    private final int parquetScale;
+    private final int arrayLen;
+
+    FixedLenByteArrayToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
+      super(sparkType);
+      LogicalTypeAnnotation typeAnnotation =
+        descriptor.getPrimitiveType().getLogicalTypeAnnotation();
+      this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+      this.arrayLen = descriptor.getPrimitiveType().getTypeLength();
+    }
+
+    @Override
+    public void skipValues(int total, VectorizedValuesReader valuesReader) {
+      valuesReader.skipFixedLenByteArray(total, arrayLen);
+    }
+
+    @Override
+    public void readValue(
+        int offset,
+        WritableColumnVector values,
+        VectorizedValuesReader valuesReader) {
+      BigInteger value = new BigInteger(valuesReader.readBinary(arrayLen).getBytes());
+      BigDecimal decimal = new BigDecimal(value, this.parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+
+    @Override
+    public void decodeSingleDictionaryId(
+        int offset,
+        WritableColumnVector values,
+        WritableColumnVector dictionaryIds,
+        Dictionary dictionary) {
+      BigInteger value =
+        new BigInteger(dictionary.decodeToBinary(dictionaryIds.getDictId(offset)).getBytes());
+      BigDecimal decimal = new BigDecimal(value, this.parquetScale);
+      writeDecimal(offset, values, decimal);
+    }
+  }
+
   private static int rebaseDays(int julianDays, final boolean failIfRebase) {
     if (failIfRebase) {
       if (julianDays < RebaseDateTime.lastSwitchJulianDay()) {
@@ -1418,16 +1610,21 @@ private SchemaColumnConvertNotSupportedException constructConvertNotSupportedExc
 
   private static boolean canReadAsIntDecimal(ColumnDescriptor descriptor, DataType dt) {
     if (!DecimalType.is32BitDecimalType(dt)) return false;
-    return isDecimalTypeMatched(descriptor, dt);
+    return isDecimalTypeMatched(descriptor, dt) && isSameDecimalScale(descriptor, dt);
   }
 
   private static boolean canReadAsLongDecimal(ColumnDescriptor descriptor, DataType dt) {
     if (!DecimalType.is64BitDecimalType(dt)) return false;
-    return isDecimalTypeMatched(descriptor, dt);
+    return isDecimalTypeMatched(descriptor, dt) && isSameDecimalScale(descriptor, dt);
   }
 
   private static boolean canReadAsBinaryDecimal(ColumnDescriptor descriptor, DataType dt) {
     if (!DecimalType.isByteArrayDecimalType(dt)) return false;
+    return isDecimalTypeMatched(descriptor, dt) && isSameDecimalScale(descriptor, dt);
+  }
+
+  private static boolean canReadAsDecimal(ColumnDescriptor descriptor, DataType dt) {
+    if (!(dt instanceof DecimalType)) return false;
     return isDecimalTypeMatched(descriptor, dt);
   }
 
@@ -1444,14 +1641,29 @@ private static boolean isDateTypeMatched(ColumnDescriptor descriptor) {
   }
 
   private static boolean isDecimalTypeMatched(ColumnDescriptor descriptor, DataType dt) {
+    DecimalType requestedType = (DecimalType) dt;
+    LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
+    if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
+      DecimalLogicalTypeAnnotation parquetType = (DecimalLogicalTypeAnnotation) typeAnnotation;
+      // If the required scale is larger than or equal to the physical decimal scale in the Parquet
+      // metadata, we can upscale the value as long as the precision also increases by as much, so
+      // that there is no loss of precision.
+      int scaleIncrease = requestedType.scale() - parquetType.getScale();
+      int precisionIncrease = requestedType.precision() - parquetType.getPrecision();
+      return scaleIncrease >= 0 && precisionIncrease >= scaleIncrease;
+    }
+    return false;
+  }
+
+  private static boolean isSameDecimalScale(ColumnDescriptor descriptor, DataType dt) {
     DecimalType d = (DecimalType) dt;
     LogicalTypeAnnotation typeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
-    if (typeAnnotation instanceof DecimalLogicalTypeAnnotation decimalType) {
-      // It's OK if the required decimal precision is larger than or equal to the physical decimal
-      // precision in the Parquet metadata, as long as the decimal scale is the same.
-      return decimalType.getPrecision() <= d.precision() && decimalType.getScale() == d.scale();
+    if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
+      DecimalLogicalTypeAnnotation decimalType = (DecimalLogicalTypeAnnotation) typeAnnotation;
+      return decimalType.getScale() == d.scale();
     }
     return false;
   }
+
 }
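Editor's note: as a reading aid (not from the patch), the new `isDecimalTypeMatched` rule admits exactly the promotions where every integral digit still fits, and it can be checked in isolation. The `Dec` record below is a hypothetical stand-in for the Parquet `DecimalLogicalTypeAnnotation` and Spark `DecimalType` descriptors.

```java
// Hypothetical stand-in for the two decimal descriptors being compared.
record Dec(int precision, int scale) {}

public class WideningRuleSketch {
  // Mirrors the check in isDecimalTypeMatched above.
  static boolean canWiden(Dec parquet, Dec requested) {
    int scaleIncrease = requested.scale() - parquet.scale();
    int precisionIncrease = requested.precision() - parquet.precision();
    return scaleIncrease >= 0 && precisionIncrease >= scaleIncrease;
  }

  public static void main(String[] args) {
    System.out.println(canWiden(new Dec(6, 2), new Dec(8, 4)));  // true: both grow by 2
    System.out.println(canWiden(new Dec(6, 2), new Dec(7, 4)));  // false: integral digits shrink
    System.out.println(canWiden(new Dec(6, 2), new Dec(8, 1)));  // false: scale decreases
  }
}
```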

sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java

Lines changed: 25 additions & 6 deletions
@@ -152,32 +152,51 @@ private boolean isLazyDecodingSupported(
     switch (typeName) {
       case INT32: {
         boolean isDate = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation;
-        boolean needsUpcast = sparkType == LongType || (isDate && sparkType == TimestampNTZType) ||
-          !DecimalType.is32BitDecimalType(sparkType);
+        boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
+        boolean needsUpcast = sparkType == LongType || sparkType == DoubleType ||
+          (isDate && sparkType == TimestampNTZType) ||
+          (isDecimal && !DecimalType.is32BitDecimalType(sparkType));
         boolean needsRebase = logicalTypeAnnotation instanceof DateLogicalTypeAnnotation &&
           !"CORRECTED".equals(datetimeRebaseMode);
-        isSupported = !needsUpcast && !needsRebase;
+        isSupported = !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType);
         break;
       }
       case INT64: {
-        boolean needsUpcast = !DecimalType.is64BitDecimalType(sparkType) ||
+        boolean isDecimal = logicalTypeAnnotation instanceof DecimalLogicalTypeAnnotation;
+        boolean needsUpcast = (isDecimal && !DecimalType.is64BitDecimalType(sparkType)) ||
           updaterFactory.isTimestampTypeMatched(TimeUnit.MILLIS);
         boolean needsRebase = updaterFactory.isTimestampTypeMatched(TimeUnit.MICROS) &&
           !"CORRECTED".equals(datetimeRebaseMode);
-        isSupported = !needsUpcast && !needsRebase;
+        isSupported = !needsUpcast && !needsRebase && !needsDecimalScaleRebase(sparkType);
         break;
       }
       case FLOAT:
         isSupported = sparkType == FloatType;
         break;
       case DOUBLE:
-      case BINARY:
         isSupported = true;
         break;
+      case BINARY:
+        isSupported = !needsDecimalScaleRebase(sparkType);
+        break;
     }
     return isSupported;
   }
 
+  /**
+   * Returns whether the Parquet type of this column and the given spark type are two decimal types
+   * with different scale.
+   */
+  private boolean needsDecimalScaleRebase(DataType sparkType) {
+    LogicalTypeAnnotation typeAnnotation =
+      descriptor.getPrimitiveType().getLogicalTypeAnnotation();
+    if (!(typeAnnotation instanceof DecimalLogicalTypeAnnotation)) return false;
+    if (!(sparkType instanceof DecimalType)) return false;
+    DecimalLogicalTypeAnnotation parquetDecimal = (DecimalLogicalTypeAnnotation) typeAnnotation;
+    DecimalType sparkDecimal = (DecimalType) sparkType;
+    return parquetDecimal.getScale() != sparkDecimal.scale();
+  }
+
   /**
    * Reads `total` rows from this columnReader into column.
    */
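Editor's note: a short illustration (not from the patch) of why lazy dictionary decoding is disabled when scales differ: the unscaled value stored in the file is only valid for the file's own scale, so every value has to be rewritten rather than referenced through the dictionary. The class name is illustrative.

```java
import java.math.BigDecimal;
import java.math.RoundingMode;

// Illustrative only: a file with decimal(4, 2) stores 20.00 as unscaled 2000.
// Read as decimal(6, 4), the same logical value needs unscaled 200000, so the
// raw dictionary entry cannot be reused as-is.
public class ScaleRebaseSketch {
  public static void main(String[] args) {
    BigDecimal stored = BigDecimal.valueOf(2000, 2);  // 20.00 as stored in the file
    BigDecimal widened = stored.setScale(4, RoundingMode.UNNECESSARY);
    System.out.println(widened.unscaledValue());      // prints 200000
  }
}
```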

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala

Lines changed: 3 additions & 1 deletion
@@ -1049,7 +1049,9 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS
     }
 
     withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
-      Seq("a DECIMAL(3, 2)", "b DECIMAL(18, 1)", "c DECIMAL(37, 1)").foreach { schema =>
+      val schema1 = "a DECIMAL(3, 2), b DECIMAL(18, 3), c DECIMAL(37, 3)"
+      checkAnswer(readParquet(schema1, path), df)
+      Seq("a DECIMAL(3, 0)", "b DECIMAL(18, 1)", "c DECIMAL(37, 1)").foreach { schema =>
         val e = intercept[SparkException] {
           readParquet(schema, path).collect()
         }.getCause.getCause
