@@ -340,6 +340,9 @@ object Decimal {
val ROUND_CEILING = BigDecimal.RoundingMode.CEILING
val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR

+  /** Maximum number of decimal digits an Int can represent */
+  val MAX_INT_DIGITS = 9
+
/** Maximum number of decimal digits a Long can represent */
val MAX_LONG_DIGITS = 18

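These two bounds follow from the primitive ranges: Integer.MAX_VALUE is 2147483647 (10 digits), so only 9-digit decimals are guaranteed to fit in an int, and Long.MAX_VALUE is 9223372036854775807 (19 digits), so 18 digits are guaranteed to fit in a long. A minimal sanity check, using nothing beyond the JDK:

```java
public class DecimalDigitBounds {
  public static void main(String[] args) {
    // 2147483647 has 10 digits, so an int can hold every 9-digit unscaled value
    // (999999999 fits) but not every 10-digit one (9999999999 overflows).
    System.out.println(String.valueOf(Integer.MAX_VALUE).length() - 1); // 9
    // 9223372036854775807 has 19 digits, so a long holds every 18-digit value.
    System.out.println(String.valueOf(Long.MAX_VALUE).length() - 1);    // 18
  }
}
```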
@@ -150,6 +150,17 @@ object DecimalType extends AbstractDataType {
}
}

+  /**
+   * Returns whether dt is a DecimalType that fits inside an Int
+   */
+  def is32BitDecimalType(dt: DataType): Boolean = {
+    dt match {
+      case t: DecimalType =>
+        t.precision <= Decimal.MAX_INT_DIGITS
+      case _ => false
+    }
+  }
+
/**
* Returns whether dt is a DecimalType that fits inside a Long
*/
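The new predicate mirrors the existing is64BitDecimalType check one tier down. A standalone sketch of the same logic (hypothetical names, not Spark's API) shows how a reader can pick the narrowest physical type for a given precision:

```java
// Hypothetical helper mirroring DecimalType.is32BitDecimalType /
// is64BitDecimalType: choose the narrowest primitive able to hold the
// unscaled value of a decimal with the given precision.
final class DecimalStorage {
  static final int MAX_INT_DIGITS = 9;   // mirrors Decimal.MAX_INT_DIGITS
  static final int MAX_LONG_DIGITS = 18; // mirrors Decimal.MAX_LONG_DIGITS

  static String physicalTypeFor(int precision) {
    if (precision <= MAX_INT_DIGITS) return "int";    // 32-bit path
    if (precision <= MAX_LONG_DIGITS) return "long";  // 64-bit path
    return "binary";                                  // arbitrary precision
  }
}
```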
@@ -257,8 +257,7 @@ private void initializeInternal() throws IOException {
throw new IOException("Unsupported type: " + t);
}
if (originalTypes[i] == OriginalType.DECIMAL &&
-          primitiveType.getDecimalMetadata().getPrecision() >
-            CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) {
+          primitiveType.getDecimalMetadata().getPrecision() > Decimal.MAX_LONG_DIGITS()) {
throw new IOException("Decimal with high precision is not supported.");
}
if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) {
@@ -439,7 +438,7 @@ private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOException {
PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType();
int precision = type.getDecimalMetadata().getPrecision();
int scale = type.getDecimalMetadata().getScale();
-    Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(),
+    Preconditions.checkState(precision <= Decimal.MAX_LONG_DIGITS(),
"Unsupported precision.");

for (int n = 0; n < num; ++n) {
@@ -480,11 +479,6 @@ private final class ColumnReader {
*/
private boolean useDictionary;

-    /**
-     * If useDictionary is true, the staging vector used to decode the ids.
-     */
-    private ColumnVector dictionaryIds;

/**
* Maximum definition level for this column.
*/
@@ -620,18 +614,13 @@ private void readBatch(int total, ColumnVector column) throws IOException {
}
int num = Math.min(total, leftInPage);
if (useDictionary) {
-        // Data is dictionary encoded. We will vector decode the ids and then resolve the values.
-        if (dictionaryIds == null) {
-          dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP);
-        } else {
-          dictionaryIds.reset();
-          dictionaryIds.reserve(total);
-        }
+        // Read and decode dictionary ids.
+        ColumnVector dictionaryIds = column.reserveDictionaryIds(total);
defColumn.readIntegers(
num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
-        decodeDictionaryIds(rowId, num, column);
+        decodeDictionaryIds(rowId, num, column, dictionaryIds);
} else {
+        column.setDictionary(null);
switch (descriptor.getType()) {
Review comment (Contributor): Remove dictionaryIds from this class.

case BOOLEAN:
readBooleanBatch(rowId, num, column);
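The staging vector for dictionary ids now lives on the ColumnVector and is handed out by reserveDictionaryIds, so it is allocated once per column and reused across batches rather than per reader. A plausible shape for that helper, reconstructed from the allocation logic removed above (only the method name appears in this diff; the body is an assumption):

```java
// Assumed shape of ColumnVector.reserveDictionaryIds: lazily allocate the
// ids vector on first use, then reset and regrow it for each new batch.
public ColumnVector reserveDictionaryIds(int capacity) {
  if (dictionaryIds == null) {
    dictionaryIds = ColumnVector.allocate(capacity, DataTypes.IntegerType, MemoryMode.ON_HEAP);
  } else {
    dictionaryIds.reset();
    dictionaryIds.reserve(capacity);
  }
  return dictionaryIds;
}
```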
@@ -668,55 +657,25 @@ private void readBatch(int total, ColumnVector column) throws IOException {
/**
* Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`.
*/
-    private void decodeDictionaryIds(int rowId, int num, ColumnVector column) {
+    private void decodeDictionaryIds(int rowId, int num, ColumnVector column,
+        ColumnVector dictionaryIds) {
switch (descriptor.getType()) {
case INT32:
-          if (column.dataType() == DataTypes.IntegerType) {
-            for (int i = rowId; i < rowId + num; ++i) {
-              column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
-            }
-          } else if (column.dataType() == DataTypes.ByteType) {
-            for (int i = rowId; i < rowId + num; ++i) {
-              column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i)));
-            }
-          } else if (column.dataType() == DataTypes.ShortType) {
-            for (int i = rowId; i < rowId + num; ++i) {
-              column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i)));
-            }
-          } else if (DecimalType.is64BitDecimalType(column.dataType())) {
-            for (int i = rowId; i < rowId + num; ++i) {
-              column.putLong(i, dictionary.decodeToInt(dictionaryIds.getInt(i)));
-            }
-          } else {
-            throw new NotImplementedException("Unimplemented type: " + column.dataType());
-          }
-          break;

case INT64:
-          if (column.dataType() == DataTypes.LongType ||
-              DecimalType.is64BitDecimalType(column.dataType())) {
-            for (int i = rowId; i < rowId + num; ++i) {
-              column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i)));
-            }
-          } else {
-            throw new NotImplementedException("Unimplemented type: " + column.dataType());
-          }
-          break;

case FLOAT:
-          for (int i = rowId; i < rowId + num; ++i) {
-            column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i)));
-          }
-          break;

case DOUBLE:
-          for (int i = rowId; i < rowId + num; ++i) {
-            column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i)));
-          }
+        case BINARY:
+          column.setDictionary(dictionary);
break;

case FIXED_LEN_BYTE_ARRAY:
-          if (DecimalType.is64BitDecimalType(column.dataType())) {
+          // DecimalType written in the legacy mode
+          if (DecimalType.is32BitDecimalType(column.dataType())) {
+            for (int i = rowId; i < rowId + num; ++i) {
+              Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
+              column.putInt(i, (int) CatalystRowConverter.binaryToUnscaledLong(v));
+            }
+          } else if (DecimalType.is64BitDecimalType(column.dataType())) {
for (int i = rowId; i < rowId + num; ++i) {
Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v));
@@ -726,17 +685,6 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) {
}
break;

-        case BINARY:
-          // TODO: this is incredibly inefficient as it blows up the dictionary right here. We
-          // need to do this better. We should probably add the dictionary data to the ColumnVector
-          // and reuse it across batches. This should mean adding a ByteArray would just update
-          // the length and offset.
-          for (int i = rowId; i < rowId + num; ++i) {
-            Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i));
-            column.putByteArray(i, v.getBytes());
-          }
-          break;

default:
throw new NotImplementedException("Unsupported type: " + descriptor.getType());
}
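With the INT32/INT64/FLOAT/DOUBLE/BINARY cases collapsed into column.setDictionary(dictionary), those values are no longer eagerly expanded here; the column keeps the raw ids plus a reference to the dictionary and resolves them on access. A hedged sketch of the read side (the accessor shape and the longData field are assumptions, not the exact Spark API):

```java
// Assumed accessor behavior once a dictionary is attached: the slot holds
// a dictionary id, decoded lazily instead of at batch-decode time.
public long getLong(int rowId) {
  if (dictionary != null) {
    return dictionary.decodeToLong(dictionaryIds.getInt(rowId));
  }
  return longData[rowId]; // hypothetical backing array for non-dictionary data
}
```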
@@ -756,15 +704,13 @@ private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IOException {
private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException {
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
-      if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType) {
+      if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType ||
+          DecimalType.is32BitDecimalType(column.dataType())) {
defColumn.readIntegers(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else if (column.dataType() == DataTypes.ByteType) {
defColumn.readBytes(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
-      } else if (DecimalType.is64BitDecimalType(column.dataType())) {
-        defColumn.readIntsAsLongs(
-            num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
} else if (column.dataType() == DataTypes.ShortType) {
defColumn.readShorts(
num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn);
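With is32BitDecimalType routed through readIntegers above, a decimal of precision at most 9 is stored as its raw unscaled int; the scale lives in the column metadata, not per value. Recovering the logical value is just unscaled * 10^-scale, as this small self-contained illustration shows:

```java
import java.math.BigDecimal;
import java.math.BigInteger;

public class UnscaledIntDecimal {
  public static void main(String[] args) {
    // A DecimalType(9, 2) value such as 12345.67 is stored as the
    // unscaled int 1234567; the scale of 2 comes from the schema.
    int unscaled = 1234567;
    int scale = 2;
    System.out.println(new BigDecimal(BigInteger.valueOf(unscaled), scale)); // 12345.67
  }
}
```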
@@ -822,7 +768,16 @@ private void readFixedLenByteArrayBatch(int rowId, int num,
VectorizedValuesReader data = (VectorizedValuesReader) dataColumn;
// This is where we implement support for the valid type conversions.
// TODO: implement remaining type conversions
-      if (DecimalType.is64BitDecimalType(column.dataType())) {
+      if (DecimalType.is32BitDecimalType(column.dataType())) {
+        for (int i = 0; i < num; i++) {
+          if (defColumn.readInteger() == maxDefLevel) {
+            column.putInt(rowId + i,
+              (int) CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen)));
+          } else {
+            column.putNull(rowId + i);
+          }
+        }
+      } else if (DecimalType.is64BitDecimalType(column.dataType())) {
} else if (DecimalType.is64BitDecimalType(column.dataType())) {
for (int i = 0; i < num; i++) {
if (defColumn.readInteger() == maxDefLevel) {
column.putLong(rowId + i,
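CatalystRowConverter.binaryToUnscaledLong converts the fixed-length, big-endian, two's-complement byte array Parquet uses for legacy decimals into an unscaled long, which the 32-bit path above then narrows to an int. A sketch of that conversion, assuming the encoded width is at most 8 bytes:

```java
// Sketch: fold big-endian bytes into a long, then sign-extend from the
// most significant encoded bit via an arithmetic right shift.
static long binaryToUnscaledLong(byte[] bytes) {
  long unscaled = 0L;
  for (byte b : bytes) {
    unscaled = (unscaled << 8) | (b & 0xff);
  }
  int bits = 8 * bytes.length; // encoded width in bits, at most 64 here
  return (unscaled << (64 - bits)) >> (64 - bits);
}
```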
@@ -25,7 +25,6 @@
import org.apache.parquet.io.ParquetDecodingException;
import org.apache.parquet.io.api.Binary;

-import org.apache.spark.sql.Column;
import org.apache.spark.sql.execution.vectorized.ColumnVector;

/**
@@ -239,38 +238,6 @@ public void readBooleans(int total, ColumnVector c,
}
}

-  public void readIntsAsLongs(int total, ColumnVector c,
-      int rowId, int level, VectorizedValuesReader data) {
-    int left = total;
-    while (left > 0) {
-      if (this.currentCount == 0) this.readNextGroup();
-      int n = Math.min(left, this.currentCount);
-      switch (mode) {
-        case RLE:
-          if (currentValue == level) {
-            for (int i = 0; i < n; i++) {
-              c.putLong(rowId + i, data.readInteger());
-            }
-          } else {
-            c.putNulls(rowId, n);
-          }
-          break;
-        case PACKED:
-          for (int i = 0; i < n; ++i) {
-            if (currentBuffer[currentBufferIdx++] == level) {
-              c.putLong(rowId + i, data.readInteger());
-            } else {
-              c.putNull(rowId + i);
-            }
-          }
-          break;
-      }
-      rowId += n;
-      left -= n;
-      currentCount -= n;
-    }
-  }

public void readBytes(int total, ColumnVector c,
int rowId, int level, VectorizedValuesReader data) {
int left = total;
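readIntsAsLongs is removed because 64-bit decimals encoded as INT32 no longer need widening at read time, but the loop it used is the template every read* method in this class follows: drain runs from the RLE/bit-packed hybrid definition-level stream, storing a value where the level marks the slot as defined and a null otherwise. A stripped-down skeleton of that control flow (field names follow the class above; the store helpers are hypothetical):

```java
// Skeleton of the shared decode loop. In RLE mode one definition level
// covers the whole run; in PACKED mode each slot carries its own level.
void readValues(int total, int rowId, int maxDefLevel) {
  int left = total;
  while (left > 0) {
    if (currentCount == 0) readNextGroup(); // refill the current run
    int n = Math.min(left, currentCount);
    if (mode == Mode.RLE) {
      if (currentValue == maxDefLevel) {
        storeValues(rowId, n);              // hypothetical: read n defined values
      } else {
        putNulls(rowId, n);                 // the whole run is null
      }
    } else { // PACKED
      for (int i = 0; i < n; ++i) {
        if (currentBuffer[currentBufferIdx++] == maxDefLevel) {
          storeValue(rowId + i);            // hypothetical: read one defined value
        } else {
          putNull(rowId + i);
        }
      }
    }
    rowId += n;
    left -= n;
    currentCount -= n;
  }
}
```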