diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 38ce1604b1ede..6a59e9728a9f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -340,6 +340,9 @@ object Decimal { val ROUND_CEILING = BigDecimal.RoundingMode.CEILING val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR + /** Maximum number of decimal digits an Int can represent */ + val MAX_INT_DIGITS = 9 + /** Maximum number of decimal digits a Long can represent */ val MAX_LONG_DIGITS = 18 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 2e03ddae760b2..9c1319c1c5e6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -150,6 +150,17 @@ object DecimalType extends AbstractDataType { } } + /** + * Returns if dt is a DecimalType that fits inside an int + */ + def is32BitDecimalType(dt: DataType): Boolean = { + dt match { + case t: DecimalType => + t.precision <= Decimal.MAX_INT_DIGITS + case _ => false + } + } + /** * Returns if dt is a DecimalType that fits inside a long */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index e7f0ec2e77895..57dbd7c2ff56f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -257,8 +257,7 @@ private void initializeInternal() throws IOException { throw new IOException("Unsupported type: " + t); } if (originalTypes[i] == OriginalType.DECIMAL && - primitiveType.getDecimalMetadata().getPrecision() > - CatalystSchemaConverter.MAX_PRECISION_FOR_INT64()) { + primitiveType.getDecimalMetadata().getPrecision() > Decimal.MAX_LONG_DIGITS()) { throw new IOException("Decimal with high precision is not supported."); } if (primitiveType.getPrimitiveTypeName() == PrimitiveType.PrimitiveTypeName.INT96) { @@ -439,7 +438,7 @@ private void decodeFixedLenArrayAsDecimalBatch(int col, int num) throws IOExcept PrimitiveType type = requestedSchema.getFields().get(col).asPrimitiveType(); int precision = type.getDecimalMetadata().getPrecision(); int scale = type.getDecimalMetadata().getScale(); - Preconditions.checkState(precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64(), + Preconditions.checkState(precision <= Decimal.MAX_LONG_DIGITS(), "Unsupported precision."); for (int n = 0; n < num; ++n) { @@ -480,11 +479,6 @@ private final class ColumnReader { */ private boolean useDictionary; - /** - * If useDictionary is true, the staging vector used to decode the ids. - */ - private ColumnVector dictionaryIds; - /** * Maximum definition level for this column. */ @@ -620,18 +614,13 @@ private void readBatch(int total, ColumnVector column) throws IOException { } int num = Math.min(total, leftInPage); if (useDictionary) { - // Data is dictionary encoded. We will vector decode the ids and then resolve the values.
- if (dictionaryIds == null) { - dictionaryIds = ColumnVector.allocate(total, DataTypes.IntegerType, MemoryMode.ON_HEAP); - } else { - dictionaryIds.reset(); - dictionaryIds.reserve(total); - } // Read and decode dictionary ids. + ColumnVector dictionaryIds = column.reserveDictionaryIds(total); defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - decodeDictionaryIds(rowId, num, column); + decodeDictionaryIds(rowId, num, column, dictionaryIds); } else { + column.setDictionary(null); switch (descriptor.getType()) { case BOOLEAN: readBooleanBatch(rowId, num, column); @@ -668,55 +657,25 @@ private void readBatch(int total, ColumnVector column) throws IOException { /** * Reads `num` values into column, decoding the values from `dictionaryIds` and `dictionary`. */ - private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { + private void decodeDictionaryIds(int rowId, int num, ColumnVector column, + ColumnVector dictionaryIds) { switch (descriptor.getType()) { case INT32: - if (column.dataType() == DataTypes.IntegerType) { - for (int i = rowId; i < rowId + num; ++i) { - column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else if (column.dataType() == DataTypes.ByteType) { - for (int i = rowId; i < rowId + num; ++i) { - column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else if (column.dataType() == DataTypes.ShortType) { - for (int i = rowId; i < rowId + num; ++i) { - column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else if (DecimalType.is64BitDecimalType(column.dataType())) { - for (int i = rowId; i < rowId + num; ++i) { - column.putLong(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); - } - } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); - } - break; - case INT64: - if (column.dataType() == DataTypes.LongType || - DecimalType.is64BitDecimalType(column.dataType())) { - for (int i = rowId; i < rowId + num; ++i) { - column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); - } - } else { - throw new NotImplementedException("Unimplemented type: " + column.dataType()); - } - break; - case FLOAT: - for (int i = rowId; i < rowId + num; ++i) { - column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i))); - } - break; - case DOUBLE: - for (int i = rowId; i < rowId + num; ++i) { - column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i))); - } + case BINARY: + column.setDictionary(dictionary); break; case FIXED_LEN_BYTE_ARRAY: - if (DecimalType.is64BitDecimalType(column.dataType())) { + // DecimalType written in legacy mode + if (DecimalType.is32BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putInt(i, (int) CatalystRowConverter.binaryToUnscaledLong(v)); + } + } else if (DecimalType.is64BitDecimalType(column.dataType())) { for (int i = rowId; i < rowId + num; ++i) { Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); column.putLong(i, CatalystRowConverter.binaryToUnscaledLong(v)); @@ -726,17 +685,6 @@ private void decodeDictionaryIds(int rowId, int num, ColumnVector column) { } break; - case BINARY: - // TODO: this is incredibly inefficient as it blows up the dictionary right here. We - // need to do this better. We should probably add the dictionary data to the ColumnVector - // and reuse it across batches.
This should mean adding a ByteArray would just update - // the length and offset. - for (int i = rowId; i < rowId + num; ++i) { - Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); - column.putByteArray(i, v.getBytes()); - } - break; - default: throw new NotImplementedException("Unsupported type: " + descriptor.getType()); } @@ -756,15 +704,13 @@ private void readBooleanBatch(int rowId, int num, ColumnVector column) throws IO private void readIntBatch(int rowId, int num, ColumnVector column) throws IOException { // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions - if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType) { + if (column.dataType() == DataTypes.IntegerType || column.dataType() == DataTypes.DateType || + DecimalType.is32BitDecimalType(column.dataType())) { defColumn.readIntegers( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else if (column.dataType() == DataTypes.ByteType) { defColumn.readBytes( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - } else if (DecimalType.is64BitDecimalType(column.dataType())) { - defColumn.readIntsAsLongs( - num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); } else if (column.dataType() == DataTypes.ShortType) { defColumn.readShorts( num, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); @@ -822,7 +768,16 @@ private void readFixedLenByteArrayBatch(int rowId, int num, VectorizedValuesReader data = (VectorizedValuesReader) dataColumn; // This is where we implement support for the valid type conversions. // TODO: implement remaining type conversions - if (DecimalType.is64BitDecimalType(column.dataType())) { + if (DecimalType.is32BitDecimalType(column.dataType())) { + for (int i = 0; i < num; i++) { + if (defColumn.readInteger() == maxDefLevel) { + column.putInt(rowId + i, + (int) CatalystRowConverter.binaryToUnscaledLong(data.readBinary(arrayLen))); + } else { + column.putNull(rowId + i); + } + } + } else if (DecimalType.is64BitDecimalType(column.dataType())) { for (int i = 0; i < num; i++) { if (defColumn.readInteger() == maxDefLevel) { column.putLong(rowId + i, diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 8613fcae0b805..62157389013bb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -25,7 +25,6 @@ import org.apache.parquet.io.ParquetDecodingException; import org.apache.parquet.io.api.Binary; -import org.apache.spark.sql.Column; import org.apache.spark.sql.execution.vectorized.ColumnVector; /** @@ -239,38 +238,6 @@ public void readBooleans(int total, ColumnVector c, } } - public void readIntsAsLongs(int total, ColumnVector c, - int rowId, int level, VectorizedValuesReader data) { - int left = total; - while (left > 0) { - if (this.currentCount == 0) this.readNextGroup(); - int n = Math.min(left, this.currentCount); - switch (mode) { - case RLE: - if (currentValue == level) { - for (int i = 0; i < n; i++) { - c.putLong(rowId + i, data.readInteger()); - } - } else { - c.putNulls(rowId, n); - } - break; - case PACKED: - for (int i = 0; i < n; ++i) { - if (currentBuffer[currentBufferIdx++] == 
level) { - c.putLong(rowId + i, data.readInteger()); - } else { - c.putNull(rowId + i); - } - } - break; - } - rowId += n; - left -= n; - currentCount -= n; - } - } - public void readBytes(int total, ColumnVector c, int rowId, int level, VectorizedValuesReader data) { int left = total; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 0514252a8e53d..bb0247c2fbedf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -19,6 +19,10 @@ import java.math.BigDecimal; import java.math.BigInteger; +import org.apache.commons.lang.NotImplementedException; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.io.api.Binary; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ArrayData; @@ -27,8 +31,6 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import org.apache.commons.lang.NotImplementedException; - /** * This class represents a column of values and provides the main APIs to access the data * values. It supports all the types and contains get/put APIs as well as their batched versions. @@ -157,7 +159,7 @@ public Object[] array() { } else if (dt instanceof StringType) { for (int i = 0; i < length; i++) { if (!data.getIsNull(offset + i)) { - list[i] = ColumnVectorUtils.toString(data.getByteArray(offset + i)); + list[i] = getUTF8String(i).toString(); } } } else if (dt instanceof CalendarIntervalType) { @@ -204,28 +206,17 @@ public float getFloat(int ordinal) { @Override public Decimal getDecimal(int ordinal, int precision, int scale) { - if (precision <= Decimal.MAX_LONG_DIGITS()) { - return Decimal.apply(getLong(ordinal), precision, scale); - } else { - byte[] bytes = getBinary(ordinal); - BigInteger bigInteger = new BigInteger(bytes); - BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(javaDecimal, precision, scale); - } + return data.getDecimal(offset + ordinal, precision, scale); } @Override public UTF8String getUTF8String(int ordinal) { - Array child = data.getByteArray(offset + ordinal); - return UTF8String.fromBytes(child.byteArray, child.byteArrayOffset, child.length); + return data.getUTF8String(offset + ordinal); } @Override public byte[] getBinary(int ordinal) { - ColumnVector.Array array = data.getByteArray(offset + ordinal); - byte[] bytes = new byte[array.length]; - System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); - return bytes; + return data.getBinary(offset + ordinal); } @Override @@ -534,12 +525,57 @@ public final int putByteArray(int rowId, byte[] value) { /** * Returns the value for rowId. */ - public final Array getByteArray(int rowId) { + private Array getByteArray(int rowId) { Array array = getArray(rowId); array.data.loadBytes(array); return array; } + /** + * Returns the decimal for rowId. + */ + public final Decimal getDecimal(int rowId, int precision, int scale) { + if (precision <= Decimal.MAX_INT_DIGITS()) { + return Decimal.apply(getInt(rowId), precision, scale); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + return Decimal.apply(getLong(rowId), precision, scale); + } else { + // TODO: best perf? 
+ byte[] bytes = getBinary(rowId); + BigInteger bigInteger = new BigInteger(bytes); + BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); + return Decimal.apply(javaDecimal, precision, scale); + } + } + + /** + * Returns the UTF8String for rowId. + */ + public final UTF8String getUTF8String(int rowId) { + if (dictionary == null) { + ColumnVector.Array a = getByteArray(rowId); + return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); + } else { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId)); + return UTF8String.fromBytes(v.getBytes()); + } + } + + /** + * Returns the byte array for rowId. + */ + public final byte[] getBinary(int rowId) { + if (dictionary == null) { + ColumnVector.Array array = getByteArray(rowId); + byte[] bytes = new byte[array.length]; + System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); + return bytes; + } else { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(rowId)); + return v.getBytes(); + } + } + /** * Append APIs. These APIs all behave similarly and will append data to the current vector. It * is not valid to mix the put and append APIs. The append APIs are slower and should only be @@ -816,6 +852,39 @@ public final int appendStruct(boolean isNull) { */ protected final ColumnarBatch.Row resultStruct; + /** + * The Dictionary for this column. + * + * If it's not null, it will be used to decode the values in getXXX(). + */ + protected Dictionary dictionary; + + /** + * Reusable column for dictionary ids. + */ + protected ColumnVector dictionaryIds; + + /** + * Update the dictionary. + */ + public void setDictionary(Dictionary dictionary) { + this.dictionary = dictionary; + } + + /** + * Reserve an integer column for dictionary ids. + */ + public ColumnVector reserveDictionaryIds(int capacity) { + if (dictionaryIds == null) { + dictionaryIds = allocate(capacity, DataTypes.IntegerType, + this instanceof OnHeapColumnVector ? MemoryMode.ON_HEAP : MemoryMode.OFF_HEAP); + } else { + dictionaryIds.reset(); + dictionaryIds.reserve(capacity); + } + return dictionaryIds; + } + /** * Sets up the common state and also handles creating the child columns if this is a nested type. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index 2aeef7f2f90fe..681ace3387139 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -22,24 +22,20 @@ import java.util.Iterator; import java.util.List; +import org.apache.commons.lang.NotImplementedException; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.commons.lang.NotImplementedException; - /** * Utilities to help manipulate data associate with ColumnVectors. These should be used mostly * for debugging or other non-performance critical paths. * These utilities are mostly used to convert ColumnVectors into other formats. */ public class ColumnVectorUtils { - public static String toString(ColumnVector.Array a) { - return new String(a.byteArray, a.byteArrayOffset, a.length); - } - /** * Returns the array data as the java primitive array.
* For example, an array of IntegerType will return an int[]. diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 070d897a7158c..8a0d7f8b12379 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -16,11 +16,11 @@ */ package org.apache.spark.sql.execution.vectorized; -import java.math.BigDecimal; -import java.math.BigInteger; import java.util.Arrays; import java.util.Iterator; +import org.apache.commons.lang.NotImplementedException; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericMutableRow; @@ -31,8 +31,6 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import org.apache.commons.lang.NotImplementedException; - /** * This class is the in memory representation of rows as they are streamed through operators. It * is designed to maximize CPU efficiency and not storage footprint. Since it is expected that @@ -193,29 +191,17 @@ public final boolean anyNull() { @Override public final Decimal getDecimal(int ordinal, int precision, int scale) { - if (precision <= Decimal.MAX_LONG_DIGITS()) { - return Decimal.apply(getLong(ordinal), precision, scale); - } else { - // TODO: best perf? - byte[] bytes = getBinary(ordinal); - BigInteger bigInteger = new BigInteger(bytes); - BigDecimal javaDecimal = new BigDecimal(bigInteger, scale); - return Decimal.apply(javaDecimal, precision, scale); - } + return columns[ordinal].getDecimal(rowId, precision, scale); } @Override public final UTF8String getUTF8String(int ordinal) { - ColumnVector.Array a = columns[ordinal].getByteArray(rowId); - return UTF8String.fromBytes(a.byteArray, a.byteArrayOffset, a.length); + return columns[ordinal].getUTF8String(rowId); } @Override public final byte[] getBinary(int ordinal) { - ColumnVector.Array array = columns[ordinal].getByteArray(rowId); - byte[] bytes = new byte[array.length]; - System.arraycopy(array.byteArray, array.byteArrayOffset, bytes, 0, bytes.length); - return bytes; + return columns[ordinal].getBinary(rowId); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index e38ed051219b7..b06b7f2457b54 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -18,25 +18,11 @@ import java.nio.ByteOrder; -import org.apache.spark.memory.MemoryMode; -import org.apache.spark.sql.execution.vectorized.ColumnVector.Array; -import org.apache.spark.sql.types.BooleanType; -import org.apache.spark.sql.types.ByteType; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DateType; -import org.apache.spark.sql.types.DecimalType; -import org.apache.spark.sql.types.DoubleType; -import org.apache.spark.sql.types.FloatType; -import org.apache.spark.sql.types.IntegerType; -import org.apache.spark.sql.types.LongType; -import org.apache.spark.sql.types.ShortType; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.types.UTF8String; - - import 
org.apache.commons.lang.NotImplementedException; -import org.apache.commons.lang.NotImplementedException; +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.Platform; /** * Column data backed using offheap memory. @@ -171,7 +157,11 @@ public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { @Override public final byte getByte(int rowId) { - return Platform.getByte(null, data + rowId); + if (dictionary == null) { + return Platform.getByte(null, data + rowId); + } else { + return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -199,7 +189,11 @@ public final void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public final short getShort(int rowId) { - return Platform.getShort(null, data + 2 * rowId); + if (dictionary == null) { + return Platform.getShort(null, data + 2 * rowId); + } else { + return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -233,7 +227,11 @@ public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcI @Override public final int getInt(int rowId) { - return Platform.getInt(null, data + 4 * rowId); + if (dictionary == null) { + return Platform.getInt(null, data + 4 * rowId); + } else { + return dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -267,7 +265,11 @@ public final void putLongsLittleEndian(int rowId, int count, byte[] src, int src @Override public final long getLong(int rowId) { - return Platform.getLong(null, data + 8 * rowId); + if (dictionary == null) { + return Platform.getLong(null, data + 8 * rowId); + } else { + return dictionary.decodeToLong(dictionaryIds.getInt(rowId)); + } } // @@ -301,7 +303,11 @@ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { @Override public final float getFloat(int rowId) { - return Platform.getFloat(null, data + rowId * 4); + if (dictionary == null) { + return Platform.getFloat(null, data + rowId * 4); + } else { + return dictionary.decodeToFloat(dictionaryIds.getInt(rowId)); + } } @@ -336,7 +342,11 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public final double getDouble(int rowId) { - return Platform.getDouble(null, data + rowId * 8); + if (dictionary == null) { + return Platform.getDouble(null, data + rowId * 8); + } else { + return dictionary.decodeToDouble(dictionaryIds.getInt(rowId)); + } } // @@ -394,7 +404,7 @@ private final void reserveInternal(int newCapacity) { } else if (type instanceof ShortType) { this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); } else if (type instanceof IntegerType || type instanceof FloatType || - type instanceof DateType) { + type instanceof DateType || DecimalType.is32BitDecimalType(type)) { this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type)) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 3502d31bd1dfa..305e84a86bdc7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -16,13 +16,12 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.util.Arrays; + 
import org.apache.spark.memory.MemoryMode; -import org.apache.spark.sql.execution.vectorized.ColumnVector.Array; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; -import java.util.Arrays; - /** * A column backed by an in memory JVM array. This stores the NULLs as a byte per value * and a java array for the values. @@ -68,7 +67,6 @@ public final void close() { doubleData = null; } - // // APIs dealing with nulls // @@ -154,7 +152,11 @@ public final void putBytes(int rowId, int count, byte[] src, int srcIndex) { @Override public final byte getByte(int rowId) { - return byteData[rowId]; + if (dictionary == null) { + return byteData[rowId]; + } else { + return (byte) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -180,7 +182,11 @@ public final void putShorts(int rowId, int count, short[] src, int srcIndex) { @Override public final short getShort(int rowId) { - return shortData[rowId]; + if (dictionary == null) { + return shortData[rowId]; + } else { + return (short) dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } @@ -217,7 +223,11 @@ public final void putIntsLittleEndian(int rowId, int count, byte[] src, int srcI @Override public final int getInt(int rowId) { - return intData[rowId]; + if (dictionary == null) { + return intData[rowId]; + } else { + return dictionary.decodeToInt(dictionaryIds.getInt(rowId)); + } } // @@ -253,7 +263,11 @@ public final void putLongsLittleEndian(int rowId, int count, byte[] src, int src @Override public final long getLong(int rowId) { - return longData[rowId]; + if (dictionary == null) { + return longData[rowId]; + } else { + return dictionary.decodeToLong(dictionaryIds.getInt(rowId)); + } } // @@ -280,7 +294,13 @@ public final void putFloats(int rowId, int count, byte[] src, int srcIndex) { } @Override - public final float getFloat(int rowId) { return floatData[rowId]; } + public final float getFloat(int rowId) { + if (dictionary == null) { + return floatData[rowId]; + } else { + return dictionary.decodeToFloat(dictionaryIds.getInt(rowId)); + } + } // // APIs dealing with doubles @@ -309,7 +329,11 @@ public final void putDoubles(int rowId, int count, byte[] src, int srcIndex) { @Override public final double getDouble(int rowId) { - return doubleData[rowId]; + if (dictionary == null) { + return doubleData[rowId]; + } else { + return dictionary.decodeToDouble(dictionaryIds.getInt(rowId)); + } } // @@ -377,7 +401,8 @@ private final void reserveInternal(int newCapacity) { short[] newData = new short[newCapacity]; if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); shortData = newData; - } else if (type instanceof IntegerType || type instanceof DateType) { + } else if (type instanceof IntegerType || type instanceof DateType || + DecimalType.is32BitDecimalType(type)) { int[] newData = new int[newCapacity]; if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); intData = newData; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 42d89f4bf81d6..8a128b4b61769 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -368,7 +368,7 @@ private[parquet] class CatalystRowConverter( } protected def decimalFromBinary(value: 
Binary): Decimal = { - if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) { + if (precision <= Decimal.MAX_LONG_DIGITS) { // Constructs a `Decimal` with an unscaled `Long` value if possible. val unscaled = CatalystRowConverter.binaryToUnscaledLong(value) Decimal(unscaled, precision, scale) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index ab4250d0adbae..6f6340f541ada 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -26,7 +26,7 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{maxPrecisionForBytes, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.maxPrecisionForBytes import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -145,7 +145,7 @@ private[parquet] class CatalystSchemaConverter( case INT_16 => ShortType case INT_32 | null => IntegerType case DATE => DateType - case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT32) + case DECIMAL => makeDecimalType(Decimal.MAX_INT_DIGITS) case UINT_8 => typeNotSupported() case UINT_16 => typeNotSupported() case UINT_32 => typeNotSupported() @@ -156,7 +156,7 @@ private[parquet] class CatalystSchemaConverter( case INT64 => originalType match { case INT_64 | null => LongType - case DECIMAL => makeDecimalType(MAX_PRECISION_FOR_INT64) + case DECIMAL => makeDecimalType(Decimal.MAX_LONG_DIGITS) case UINT_64 => typeNotSupported() case TIMESTAMP_MILLIS => typeNotImplemented() case _ => illegalType() @@ -403,7 +403,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT32 for 1 <= precision <= 9 case DecimalType.Fixed(precision, scale) - if precision <= MAX_PRECISION_FOR_INT32 && !writeLegacyParquetFormat => + if precision <= Decimal.MAX_INT_DIGITS && !writeLegacyParquetFormat => Types .primitive(INT32, repetition) .as(DECIMAL) @@ -413,7 +413,7 @@ private[parquet] class CatalystSchemaConverter( // Uses INT64 for 1 <= precision <= 18 case DecimalType.Fixed(precision, scale) - if precision <= MAX_PRECISION_FOR_INT64 && !writeLegacyParquetFormat => + if precision <= Decimal.MAX_LONG_DIGITS && !writeLegacyParquetFormat => Types .primitive(INT64, repetition) .as(DECIMAL) @@ -569,10 +569,6 @@ private[parquet] object CatalystSchemaConverter { // Returns the minimum number of bytes needed to store a decimal with a given `precision`. 
val minBytesForPrecision = Array.tabulate[Int](39)(computeMinBytesForPrecision) - val MAX_PRECISION_FOR_INT32 = maxPrecisionForBytes(4) /* 9 */ - - val MAX_PRECISION_FOR_INT64 = maxPrecisionForBytes(8) /* 18 */ - // Max precision of a decimal value stored in `numBytes` bytes def maxPrecisionForBytes(numBytes: Int): Int = { Math.round( // convert double to long diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala index 3508220c9541f..0252c79d8e143 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala @@ -33,7 +33,7 @@ import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{minBytesForPrecision, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.minBytesForPrecision import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -253,13 +253,13 @@ private[parquet] class CatalystWriteSupport extends WriteSupport[InternalRow] wi writeLegacyParquetFormat match { // Standard mode, 1 <= precision <= 9, writes as INT32 - case false if precision <= MAX_PRECISION_FOR_INT32 => int32Writer + case false if precision <= Decimal.MAX_INT_DIGITS => int32Writer // Standard mode, 10 <= precision <= 18, writes as INT64 - case false if precision <= MAX_PRECISION_FOR_INT64 => int64Writer + case false if precision <= Decimal.MAX_LONG_DIGITS => int64Writer // Legacy mode, 1 <= precision <= 18, writes as FIXED_LEN_BYTE_ARRAY - case true if precision <= MAX_PRECISION_FOR_INT64 => binaryWriterUsingUnscaledLong + case true if precision <= Decimal.MAX_LONG_DIGITS => binaryWriterUsingUnscaledLong // Either standard or legacy mode, 19 <= precision <= 38, writes as FIXED_LEN_BYTE_ARRAY case _ => binaryWriterUsingUnscaledBytes diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala index cef6b79a094d1..281a2cffa894a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetEncodingSuite.scala @@ -47,7 +47,7 @@ class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContex assert(batch.column(0).getByte(i) == 1) assert(batch.column(1).getInt(i) == 2) assert(batch.column(2).getLong(i) == 3) - assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(i)) == "abc") + assert(batch.column(3).getUTF8String(i).toString == "abc") i += 1 } reader.close() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala index 8efdf8adb042a..97638a66ab473 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchBenchmark.scala @@ -370,7 +370,7 @@ object ColumnarBatchBenchmark { } i = 0 while (i < count) { - sum += column.getByteArray(i).length + sum += column.getUTF8String(i).numBytes() i += 1 } column.reset() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 445f311107e33..b3c3e66fbcbd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -360,7 +360,7 @@ class ColumnarBatchSuite extends SparkFunSuite { reference.zipWithIndex.foreach { v => assert(v._1.length == column.getArrayLength(v._2), "MemoryMode=" + memMode) - assert(v._1 == ColumnVectorUtils.toString(column.getByteArray(v._2)), + assert(v._1 == column.getUTF8String(v._2).toString, "MemoryMode" + memMode) } @@ -488,7 +488,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(batch.column(1).getDouble(0) == 1.1) assert(batch.column(1).getIsNull(0) == false) assert(batch.column(2).getIsNull(0) == true) - assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello") + assert(batch.column(3).getUTF8String(0).toString == "Hello") // Verify the iterator works correctly. val it = batch.rowIterator() @@ -499,7 +499,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(row.getDouble(1) == 1.1) assert(row.isNullAt(1) == false) assert(row.isNullAt(2) == true) - assert(ColumnVectorUtils.toString(batch.column(3).getByteArray(0)) == "Hello") + assert(batch.column(3).getUTF8String(0).toString == "Hello") assert(it.hasNext == false) assert(it.hasNext == false)
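
For reference, the precision thresholds this patch centralizes (`Decimal.MAX_INT_DIGITS` = 9, `Decimal.MAX_LONG_DIGITS` = 18) fall directly out of the integer ranges: the largest int (2147483647) has 10 digits, so any 9-digit unscaled value fits in an int, and the largest long has 19 digits, so any 18-digit unscaled value fits in a long. A minimal sketch of the resulting storage selection follows; `DecimalStorageSketch` and `storageFor` are illustrative names, not part of the patch:

```java
// Sketch: how a decimal's precision picks its physical representation.
public class DecimalStorageSketch {
  static final int MAX_INT_DIGITS = 9;   // mirrors Decimal.MAX_INT_DIGITS
  static final int MAX_LONG_DIGITS = 18; // mirrors Decimal.MAX_LONG_DIGITS

  // Mirrors the checks in DecimalType.is32BitDecimalType / is64BitDecimalType.
  static String storageFor(int precision) {
    if (precision <= MAX_INT_DIGITS) return "INT32";   // unscaled value in an int
    if (precision <= MAX_LONG_DIGITS) return "INT64";  // unscaled value in a long
    return "FIXED_LEN_BYTE_ARRAY";                     // unscaled BigInteger bytes
  }

  public static void main(String[] args) {
    for (int p : new int[] {5, 9, 10, 18, 19, 38}) {
      System.out.println("precision " + p + " -> " + storageFor(p));
    }
  }
}
```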
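The legacy FIXED_LEN_BYTE_ARRAY paths above decode through `CatalystRowConverter.binaryToUnscaledLong`, which treats the fixed-length bytes as a big-endian two's-complement integer; the new 32-bit path then simply narrows the result with an `(int)` cast. The sketch below reimplements that idea for illustration and is not the Spark method itself:

```java
// Sketch: accumulate big-endian two's-complement bytes into a long, then
// sign-extend from the highest occupied bit so short inputs keep their sign.
public class UnscaledLongSketch {
  static long binaryToUnscaledLong(byte[] bytes) {
    long unscaled = 0L;
    for (byte b : bytes) {
      unscaled = (unscaled << 8) | (b & 0xff); // big-endian accumulation
    }
    // Shift left then arithmetic-shift right to sign-extend.
    int bits = 8 * bytes.length;
    return (unscaled << (64 - bits)) >> (64 - bits);
  }

  public static void main(String[] args) {
    // 2-byte two's complement 0xFF38 is -200; 0x075BCD15 is 123456789,
    // a 9-digit value that the new is32BitDecimalType path narrows to an int.
    System.out.println(binaryToUnscaledLong(new byte[] {(byte) 0xff, 0x38}));
    System.out.println((int) binaryToUnscaledLong(
        new byte[] {0x07, 0x5b, (byte) 0xcd, 0x15}));
  }
}
```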
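The other half of the patch makes dictionary decoding lazy: instead of materializing decoded values in `readBatch`, a ColumnVector now keeps the parquet `Dictionary` plus a reusable vector of int ids (`reserveDictionaryIds`) and decodes inside each getter. A simplified sketch of the scheme; `IntDictionary` and `IdBackedVector` are stand-ins for parquet's `Dictionary` and Spark's `ColumnVector`, not real APIs:

```java
// Sketch: store ids once, decode through the dictionary only on access.
public class LazyDictionarySketch {
  interface IntDictionary { int decodeToInt(int id); }

  static final class IdBackedVector {
    private int[] plainData = new int[16]; // used when no dictionary is set
    private int[] dictionaryIds;           // reusable ids column
    private IntDictionary dictionary;

    void setDictionary(IntDictionary dict) { this.dictionary = dict; }

    // Mirrors reserveDictionaryIds: allocate once, reuse across batches.
    int[] reserveDictionaryIds(int capacity) {
      if (dictionaryIds == null || dictionaryIds.length < capacity) {
        dictionaryIds = new int[capacity];
      }
      return dictionaryIds;
    }

    // Mirrors the new getInt: decode lazily when a dictionary is present.
    int getInt(int rowId) {
      return dictionary == null
          ? plainData[rowId]
          : dictionary.decodeToInt(dictionaryIds[rowId]);
    }
  }

  public static void main(String[] args) {
    IdBackedVector vector = new IdBackedVector();
    vector.setDictionary(id -> new int[] {7, 42}[id]); // two-entry dictionary
    int[] ids = vector.reserveDictionaryIds(3);
    ids[0] = 1; ids[1] = 0; ids[2] = 1;
    System.out.println(vector.getInt(0)); // prints 42, decoded on access
  }
}
```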