
Commit c27ba0d

Davies Liu (davies) authored and committed
[SPARK-13582] [SQL] defer dictionary decoding in parquet reader
## What changes were proposed in this pull request?

This PR defers the resolution of a dictionary id to its value until the column is actually accessed (inside getInt/getLong), which is very useful for columns and rows that are filtered out. It also helps the binary type, since we no longer need to copy all the byte arrays. This PR also changes the underlying type for small decimals that fit within an Int, so that getInt() can be used to look up the value in an IntDictionary.

## How was this patch tested?

Manually tested TPCDS Q7 with scale factor 10; saw about 30% improvement (after PR #11274).

Author: Davies Liu <[email protected]>

Closes #11437 from davies/decode_dict.
1 parent c37bbb3 commit c27ba0d
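At a high level, the reader now keeps the raw dictionary ids in the column vector together with a reference to the Parquet `Dictionary`, and only resolves ids to values when a row is actually read. A minimal sketch of that idea follows; the class and field names here are illustrative only, not Spark's actual `ColumnVector` implementation (which the diffs below extend with `setDictionary` and `reserveDictionaryIds`):

```java
// Minimal sketch of deferred dictionary decoding (illustrative names only).
// The reader stores one dictionary id per row plus a shared decode table;
// values are resolved only when a surviving row calls getInt()/getLong().
import org.apache.parquet.column.Dictionary;

final class LazyDictionaryColumn {
  private final int[] dictionaryIds;    // one id per row, cheap to fill
  private final Dictionary dictionary;  // shared decode table, never copied

  LazyDictionaryColumn(int[] dictionaryIds, Dictionary dictionary) {
    this.dictionaryIds = dictionaryIds;
    this.dictionary = dictionary;
  }

  // Decoding happens here, per accessed row, instead of for the whole batch.
  int getInt(int rowId) {
    return dictionary.decodeToInt(dictionaryIds[rowId]);
  }

  long getLong(int rowId) {
    return dictionary.decodeToLong(dictionaryIds[rowId]);
  }
}
```

Rows that a filter skips never pay the decode cost, and for binary columns the dictionary's byte arrays are never copied into the vector.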

File tree

15 files changed (+221, -203 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 3 additions & 0 deletions
```diff
@@ -340,6 +340,9 @@ object Decimal {
   val ROUND_CEILING = BigDecimal.RoundingMode.CEILING
   val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR
 
+  /** Maximum number of decimal digits an Int can represent */
+  val MAX_INT_DIGITS = 9
+
   /** Maximum number of decimal digits a Long can represent */
   val MAX_LONG_DIGITS = 18
```
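Why 9: `Integer.MAX_VALUE` is 2147483647, which has ten digits, so a ten-digit unscaled value can overflow an int while any nine-digit value always fits. A quick standalone check of the arithmetic:

```java
// Integer.MAX_VALUE has 10 digits, so 9 decimal digits always fit in an int.
public class MaxIntDigits {
  public static void main(String[] args) {
    System.out.println(Integer.MAX_VALUE);                          // 2147483647
    System.out.println(String.valueOf(Integer.MAX_VALUE).length()); // 10
    System.out.println(999_999_999 <= Integer.MAX_VALUE);           // true: 9 digits fit
    System.out.println(9_999_999_999L > Integer.MAX_VALUE);         // true: 10 digits may not
  }
}
```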

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala

Lines changed: 11 additions & 0 deletions
```diff
@@ -150,6 +150,17 @@ object DecimalType extends AbstractDataType {
     }
   }
 
+  /**
+   * Returns if dt is a DecimalType that fits inside an int
+   */
+  def is32BitDecimalType(dt: DataType): Boolean = {
+    dt match {
+      case t: DecimalType =>
+        t.precision <= Decimal.MAX_INT_DIGITS
+      case _ => false
+    }
+  }
+
   /**
    * Returns if dt is a DecimalType that fits inside a long
    */
```
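As a usage sketch (calling the Scala object through its static forwarders, the same way the Java reader below does), a `DecimalType(9, 2)` qualifies for int storage while a `DecimalType(10, 0)` does not:

```java
// Hedged illustration of the new predicate; mirrors how the reader uses it.
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DecimalType;

public class DecimalWidth {
  public static void main(String[] args) {
    DataType d9 = new DecimalType(9, 2);   // precision 9 <= Decimal.MAX_INT_DIGITS
    DataType d10 = new DecimalType(10, 0); // too wide for an int
    System.out.println(DecimalType.is32BitDecimalType(d9));  // true
    System.out.println(DecimalType.is32BitDecimalType(d10)); // false
    System.out.println(DecimalType.is64BitDecimalType(d10)); // true
  }
}
```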

sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java

Lines changed: 28 additions & 73 deletions
```diff
@@ -257,8 +257,7 @@ private void initializeInternal() throws IOException {
         throw new IOException("Unsupported type: " + t);
       }
       if (originalTypes[i] == OriginalType.DECIMAL &&
-          primitiveType.getDecimalMetadata().getPrecision() >
-          CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) {
+          primitiveType.getDecimalMetadata().getPrecision() > Decimal.MAX_LONG_DIGITS()) {
         throw new IOException("Decimal with high precision is not supported.");
       }
       if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) {
```
```diff
@@ -439,7 +438,7 @@ private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOException {
     PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType();
     int precision = type.getDecimalMetadata().getPrecision();
     int scale = type.getDecimalMetadata().getScale();
-    Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(),
+    Preconditions.checkState(precision <= Decimal.MAX_LONG_DIGITS(),
         "Unsupported precision.");
 
     for (int n = 0; n < num; ++n) {
```
```diff
@@ -480,11 +479,6 @@ private final class ColumnReader {
      */
     private boolean useDictionary;
 
-    /**
-     * If useDictionary is true, the staging vector used to decode the ids.
-     */
-    private ColumnVector dictionaryIds;
-
     /**
      * Maximum definition level for this column.
      */
```
```diff
@@ -620,18 +614,13 @@ private void readBatch(int total, ColumnVector column) throws IOException {
       }
       int num = Math.min(total, leftInPage);
       if (useDictionary) {
-        // Data is dictionary encoded. We will vector decode the ids and then resolve the values.
-        if (dictionaryIds == null) {
-          dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP);
-        } else {
-          dictionaryIds.reset();
-          dictionaryIds.reserve(total);
-        }
         // Read and decode dictionary ids.
+        ColumnVector dictionaryIds = column.reserveDictionaryIds(total);
         defColumn.readIntegers(
             num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
-        decodeDictionaryIds(rowId, num, column);
+        decodeDictionaryIds(rowId, num, column, dictionaryIds);
       } else {
+        column.setDictionary(null);
         switch (descriptor.getType()) {
           case BOOLEAN:
             readBooleanBatch(rowId, num, column);
```
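The allocate-or-reuse staging logic deleted above moves behind `column.reserveDictionaryIds`. A hedged sketch of what that method plausibly does, reconstructed from the removed lines (the real `ColumnVector` implementation may differ):

```java
// Hedged sketch, not the actual Spark source: reserveDictionaryIds as
// reconstructed from the staging logic this hunk removes from the reader.
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.execution.vectorized.ColumnVector;
import org.apache.spark.sql.types.DataTypes;

abstract class ColumnVectorSketch {
  // Child vector holding one dictionary id per row, reused across batches.
  protected ColumnVector dictionaryIds;

  public ColumnVector reserveDictionaryIds(int capacity) {
    if (dictionaryIds == null) {
      // Allocate lazily the first time a dictionary-encoded page is seen.
      dictionaryIds = ColumnVector.allocate(capacity, DataTypes.IntegerType, MemoryMode.ON_HEAP);
    } else {
      // Reuse the existing vector: clear contents and ensure capacity.
      dictionaryIds.reset();
      dictionaryIds.reserve(capacity);
    }
    return dictionaryIds;
  }
}
```

Keeping the ids vector inside the `ColumnVector` (rather than as reader state) lets it travel with the column and be reused across batches.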
```diff
@@ -668,55 +657,25 @@ private void readBatch(int total, ColumnVector column) throws IOException {
     /**
      * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
      */
-    private void decodeDictionaryIds(int rowId, int num, ColumnVector column) {
+    private void decodeDictionaryIds(int rowId, int num, ColumnVector column,
+        ColumnVector dictionaryIds) {
       switch (descriptor.getType()) {
         case INT32:
-          if (column.dataType() == DataTypes.IntegerType) {
-            for (int i = rowId; i < rowId + num; ++i) {
-              column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
-            }
-          } else if (column.dataType() == DataTypes.ByteType) {
-            for (int i = rowId; i < rowId + num; ++i) {
-              column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i)));
-            }
-          } else if (column.dataType() == DataTypes.ShortType) {
-            for (int i = rowId; i < rowId + num; ++i) {
-              column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i)));
-            }
-          } else if (DecimalType.is64BitDecimalType(column.dataType())) {
-            for (int i = rowId; i < rowId + num; ++i) {
-              column.putLong(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
-            }
-          } else {
-            throw new NotImplementedException("Unimplemented type: " + column.dataType());
-          }
-          break;
-
         case INT64:
-          if (column.dataType() == DataTypes.LongType ||
-              DecimalType.is64BitDecimalType(column.dataType())) {
-            for (int i = rowId; i < rowId + num; ++i) {
-              column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i)));
-            }
-          } else {
-            throw new NotImplementedException("Unimplemented type: " + column.dataType());
-          }
-          break;
-
         case FLOAT:
-          for (int i = rowId; i < rowId + num; ++i) {
-            column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i)));
-          }
-          break;
-
         case DOUBLE:
-          for (int i = rowId; i < rowId + num; ++i) {
-            column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i)));
-          }
+        case BINARY:
+          column.setDictionary(dictionary);
           break;
 
         case FIXED_LEN_BYTE_ARRAY:
-          if (DecimalType.is64BitDecimalType(column.dataType())) {
+          // DecimalType written in the legacy mode
+          if (DecimalType.is32BitDecimalType(column.dataType())) {
+            for (int i = rowId; i < rowId + num; ++i) {
+              Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
+              column.putInt(i, (int) CatalystRowConverter.binaryToUnscaledLong(v));
+            }
+          } else if (DecimalType.is64BitDecimalType(column.dataType())) {
             for (int i = rowId; i < rowId + num; ++i) {
               Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
               column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v));
```
```diff
@@ -726,17 +685,6 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) {
           }
           break;
 
-        case BINARY:
-          // TODO: this is incredibly inefficient as it blows up the dictionary right here. We
-          // need to do this better. We should probably add the dictionary data to the ColumnVector
-          // and reuse it across batches. This should mean adding a ByteArray would just update
-          // the length and offset.
-          for (int i = rowId; i < rowId + num; ++i) {
-            Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
-            column.putByteArray(i, v.getBytes());
-          }
-          break;
-
         default:
           throw new NotImplementedException("Unsupported type: " + descriptor.getType());
       }
```
```diff
@@ -756,15 +704,13 @@ private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IOException {
     private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
       // This is where we implement support for the valid type conversions.
       // TODO: implement remaining type conversions
-      if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType) {
+      if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType ||
+          DecimalType.is32BitDecimalType(column.dataType())) {
         defColumn.readIntegers(
             num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
       } else if (column.dataType() == DataTypes.ByteType) {
         defColumn.readBytes(
             num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
-      } else if (DecimalType.is64BitDecimalType(column.dataType())) {
-        defColumn.readIntsAsLongs(
-            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
       } else if (column.dataType() == DataTypes.ShortType) {
         defColumn.readShorts(
             num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
```
```diff
@@ -822,7 +768,16 @@ private void readFixedLenByteArrayBatch(int rowId, int num,
       VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
       // This is where we implement support for the valid type conversions.
       // TODO: implement remaining type conversions
-      if (DecimalType.is64BitDecimalType(column.dataType())) {
+      if (DecimalType.is32BitDecimalType(column.dataType())) {
+        for (int i = 0; i < num; i++) {
+          if (defColumn.readInteger() == maxDefLevel) {
+            column.putInt(rowId + i,
+                (int) CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen)));
+          } else {
+            column.putNull(rowId + i);
+          }
+        }
+      } else if (DecimalType.is64BitDecimalType(column.dataType())) {
        for (int i = 0; i < num; i++) {
          if (defColumn.readInteger() == maxDefLevel) {
            column.putLong(rowId + i,
```
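Both fixed-length decimal paths funnel through `CatalystRowConverter.binaryToUnscaledLong`. A hedged sketch of the conversion it performs, treating the fixed-length bytes as the big-endian two's-complement unscaled value (an illustrative reimplementation, not the actual method body):

```java
// Hedged sketch of a binary-to-unscaled-long conversion for legacy
// FIXED_LEN_BYTE_ARRAY decimals: accumulate bytes big-endian, then
// sign-extend from the value's top bit. Illustrative names only.
public class BinaryToUnscaled {
  static long binaryToUnscaledLong(byte[] bytes) {
    long unscaled = 0L;
    for (byte b : bytes) {
      unscaled = (unscaled << 8) | (b & 0xffL);
    }
    // Sign-extend from the top bit of the fixed-length value.
    int bits = 8 * bytes.length;
    if (bits < 64) {
      unscaled = (unscaled << (64 - bits)) >> (64 - bits);
    }
    return unscaled;
  }

  public static void main(String[] args) {
    System.out.println(binaryToUnscaledLong(new byte[] {0x01, 0x00}));        // 256
    System.out.println(binaryToUnscaledLong(new byte[] {(byte) 0xff, 0x00})); // -256
  }
}
```

On the 32-bit path the result is then narrowed with a plain `(int)` cast, which is safe because precision at most 9 guarantees the unscaled value fits in an int.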

sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java

Lines changed: 0 additions & 33 deletions
```diff
@@ -25,7 +25,6 @@
 import org.apache.parquet.io.ParquetDecodingException;
 import org.apache.parquet.io.api.Binary;
 
-import org.apache.spark.sql.Column;
 import org.apache.spark.sql.execution.vectorized.ColumnVector;
 
 /**
```
```diff
@@ -239,38 +238,6 @@ public void readBooleans(int total, ColumnVector c,
     }
   }
 
-  public void readIntsAsLongs(int total, ColumnVector c,
-      int rowId, int level, VectorizedValuesReader data) {
-    int left = total;
-    while (left > 0) {
-      if (this.currentCount == 0) this.readNextGroup();
-      int n = Math.min(left, this.currentCount);
-      switch (mode) {
-        case RLE:
-          if (currentValue == level) {
-            for (int i = 0; i < n; i++) {
-              c.putLong(rowId + i, data.readInteger());
-            }
-          } else {
-            c.putNulls(rowId, n);
-          }
-          break;
-        case PACKED:
-          for (int i = 0; i < n; ++i) {
-            if (currentBuffer[currentBufferIdx++] == level) {
-              c.putLong(rowId + i, data.readInteger());
-            } else {
-              c.putNull(rowId + i);
-            }
-          }
-          break;
-      }
-      rowId += n;
-      left -= n;
-      currentCount -= n;
-    }
-  }
-
   public void readBytes(int total, ColumnVector c,
       int rowId, int level, VectorizedValuesReader data) {
     int left = total;
```
