==== ColumnAccessor.scala ====

@@ -21,7 +21,7 @@ import java.nio.{ByteBuffer, ByteOrder}

import org.apache.spark.sql.catalyst.expressions.MutableRow
import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor
-import org.apache.spark.sql.types.{BinaryType, DataType, NativeType}
+import org.apache.spark.sql.types._

/**
 * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is
@@ -89,6 +89,9 @@ private[sql] class DoubleColumnAccessor(buffer: ByteBuffer)
private[sql] class FloatColumnAccessor(buffer: ByteBuffer)
  extends NativeColumnAccessor(buffer, FLOAT)

+private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int)
+  extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale))

private[sql] class StringColumnAccessor(buffer: ByteBuffer)
  extends NativeColumnAccessor(buffer, STRING)

@@ -107,24 +110,28 @@ private[sql] class GenericColumnAccessor(buffer: ByteBuffer)
  with NullableColumnAccessor

private[sql] object ColumnAccessor {
-  def apply(buffer: ByteBuffer): ColumnAccessor = {
+  def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = {
    val dup = buffer.duplicate().order(ByteOrder.nativeOrder)
-    // The first 4 bytes in the buffer indicate the column type.
-    val columnTypeId = dup.getInt()
-
-    columnTypeId match {
-      case INT.typeId => new IntColumnAccessor(dup)
-      case LONG.typeId => new LongColumnAccessor(dup)
-      case FLOAT.typeId => new FloatColumnAccessor(dup)
-      case DOUBLE.typeId => new DoubleColumnAccessor(dup)
-      case BOOLEAN.typeId => new BooleanColumnAccessor(dup)
-      case BYTE.typeId => new ByteColumnAccessor(dup)
-      case SHORT.typeId => new ShortColumnAccessor(dup)
-      case STRING.typeId => new StringColumnAccessor(dup)
-      case DATE.typeId => new DateColumnAccessor(dup)
-      case TIMESTAMP.typeId => new TimestampColumnAccessor(dup)
-      case BINARY.typeId => new BinaryColumnAccessor(dup)
-      case GENERIC.typeId => new GenericColumnAccessor(dup)
+
+    // The first 4 bytes in the buffer indicate the column type. This field is not used now,
+    // because we always know the data type of the column ahead of time.
+    dup.getInt()
Review (Contributor): Maybe this line is not necessary any more.

Reply (Contributor Author): This call has a side effect: it still has to run so that the 4 header bytes are consumed.

Reply (Contributor Author): However, we can remove this line once the whole column type ID scheme is removed.


+    dataType match {
+      case IntegerType => new IntColumnAccessor(dup)
+      case LongType => new LongColumnAccessor(dup)
+      case FloatType => new FloatColumnAccessor(dup)
+      case DoubleType => new DoubleColumnAccessor(dup)
+      case BooleanType => new BooleanColumnAccessor(dup)
+      case ByteType => new ByteColumnAccessor(dup)
+      case ShortType => new ShortColumnAccessor(dup)
+      case StringType => new StringColumnAccessor(dup)
+      case BinaryType => new BinaryColumnAccessor(dup)
+      case DateType => new DateColumnAccessor(dup)
+      case TimestampType => new TimestampColumnAccessor(dup)
+      case DecimalType.Fixed(precision, scale) if precision < 19 =>
+        new FixedDecimalColumnAccessor(dup, precision, scale)
+      case _ => new GenericColumnAccessor(dup)
    }
  }
}
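The review exchange above hinges on the fact that `ByteBuffer.getInt()` advances the buffer's read position even when its result is discarded; dropping the call would leave the accessor treating the 4 header bytes as column data. A minimal standalone sketch of that behavior (plain JDK API only; the object name is invented for illustration):

import java.nio.{ByteBuffer, ByteOrder}

object GetIntSideEffectSketch {
  def main(args: Array[String]): Unit = {
    val buffer = ByteBuffer.allocate(12).order(ByteOrder.nativeOrder)
    buffer.putInt(42)    // stand-in for the now-unused column type ID header
    buffer.putLong(123L) // stand-in for the column payload
    buffer.flip()

    // getInt() moves the position from 0 to 4 even though we ignore the value.
    buffer.getInt()
    assert(buffer.position() == 4)

    // Only because the header was consumed does the next read see the payload.
    assert(buffer.getLong() == 123L)
  }
}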
==== ColumnBuilder.scala ====

@@ -106,6 +106,13 @@ private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleCol

private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT)

+private[sql] class FixedDecimalColumnBuilder(
+    precision: Int,
+    scale: Int)
+  extends NativeColumnBuilder(
+    new FixedDecimalColumnStats,
+    FIXED_DECIMAL(precision, scale))

private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)

private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE)
@@ -139,25 +146,25 @@ private[sql] object ColumnBuilder {
  }

  def apply(
-      typeId: Int,
+      dataType: DataType,
      initialSize: Int = 0,
      columnName: String = "",
      useCompression: Boolean = false): ColumnBuilder = {

-    val builder = (typeId match {
-      case INT.typeId => new IntColumnBuilder
-      case LONG.typeId => new LongColumnBuilder
-      case FLOAT.typeId => new FloatColumnBuilder
-      case DOUBLE.typeId => new DoubleColumnBuilder
-      case BOOLEAN.typeId => new BooleanColumnBuilder
-      case BYTE.typeId => new ByteColumnBuilder
-      case SHORT.typeId => new ShortColumnBuilder
-      case STRING.typeId => new StringColumnBuilder
-      case BINARY.typeId => new BinaryColumnBuilder
-      case GENERIC.typeId => new GenericColumnBuilder
-      case DATE.typeId => new DateColumnBuilder
-      case TIMESTAMP.typeId => new TimestampColumnBuilder
-    }).asInstanceOf[ColumnBuilder]
+    val builder: ColumnBuilder = dataType match {
+      case IntegerType => new IntColumnBuilder
+      case LongType => new LongColumnBuilder
+      case DoubleType => new DoubleColumnBuilder
+      case BooleanType => new BooleanColumnBuilder
+      case ByteType => new ByteColumnBuilder
+      case ShortType => new ShortColumnBuilder
+      case StringType => new StringColumnBuilder
+      case BinaryType => new BinaryColumnBuilder
+      case DateType => new DateColumnBuilder
+      case TimestampType => new TimestampColumnBuilder
+      case DecimalType.Fixed(precision, scale) if precision < 19 =>
+        new FixedDecimalColumnBuilder(precision, scale)
+      case _ => new GenericColumnBuilder
+    }

    builder.initialize(initialSize, columnName, useCompression)
    builder
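With the factory keyed on `DataType`, callers no longer translate a type ID. Below is a minimal usage sketch, not taken from the PR: it assumes it is compiled inside the `org.apache.spark.sql.columnar` package so the `private[sql]` members resolve, and it uses `appendFrom` and `build`, the builder's existing append and finish methods in this codebase (treat exact signatures as assumptions):

package org.apache.spark.sql.columnar

import java.nio.ByteBuffer

import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types.{Decimal, DecimalType}

object FixedDecimalBuilderSketch {
  def main(args: Array[String]): Unit = {
    // precision 15 < 19, so this resolves to FixedDecimalColumnBuilder.
    val builder = ColumnBuilder(DecimalType(15, 10), initialSize = 4, columnName = "price")

    val row = new GenericMutableRow(1)
    row(0) = Decimal(12345L, 15, 10) // unscaled 12345 at scale 10, i.e. 0.0000012345
    builder.appendFrom(row, 0)

    // build() yields the columnar ByteBuffer that ColumnAccessor consumes later.
    val buffer: ByteBuffer = builder.build()
  }
}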
==== ColumnStats.scala ====

@@ -181,6 +181,23 @@ private[sql] class FloatColumnStats extends ColumnStats {
  def collectedStatistics = Row(lower, upper, nullCount, count, sizeInBytes)
}

+private[sql] class FixedDecimalColumnStats extends ColumnStats {
+  protected var upper: Decimal = null
+  protected var lower: Decimal = null
+
+  override def gatherStats(row: Row, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row(ordinal).asInstanceOf[Decimal]
+      if (upper == null || value.compareTo(upper) > 0) upper = value
+      if (lower == null || value.compareTo(lower) < 0) lower = value
+      sizeInBytes += FIXED_DECIMAL.defaultSize
+    }
+  }
+
+  override def collectedStatistics: Row = Row(lower, upper, nullCount, count, sizeInBytes)
+}

private[sql] class IntColumnStats extends ColumnStats {
  protected var upper = Int.MinValue
  protected var lower = Int.MaxValue
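Note that unlike the primitive stats classes, which start from `MaxValue`/`MinValue` sentinels, the decimal bounds start as `null` and are only filled by the first non-null value; that is why the test suite further down expects `Row(null, null, 0)` for an empty column. A short sketch of what gets collected, under the same package assumption as the previous example:

package org.apache.spark.sql.columnar

import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types.Decimal

object FixedDecimalStatsSketch {
  def main(args: Array[String]): Unit = {
    val stats = new FixedDecimalColumnStats
    val row = new GenericMutableRow(1)

    Seq(-5L, 3L, 7L).foreach { unscaled =>
      row(0) = Decimal(unscaled, 15, 10)
      stats.gatherStats(row, 0)
    }

    // Layout is (lower, upper, nullCount, count, sizeInBytes):
    // lower = -5E-10, upper = 7E-10, nullCount = 0, count = 3,
    // sizeInBytes = 3 * FIXED_DECIMAL.defaultSize = 24.
    println(stats.collectedStatistics)
  }
}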
==== ColumnType.scala ====

@@ -373,6 +373,33 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 12) {
  }
}

+private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int)
+  extends NativeColumnType(
+    DecimalType(Some(PrecisionInfo(precision, scale))),
+    10,
+    FIXED_DECIMAL.defaultSize) {
+
+  override def extract(buffer: ByteBuffer): Decimal = {
+    Decimal(buffer.getLong(), precision, scale)
+  }
+
+  override def append(v: Decimal, buffer: ByteBuffer): Unit = {
+    buffer.putLong(v.toUnscaledLong)
+  }
+
+  override def getField(row: Row, ordinal: Int): Decimal = {
+    row(ordinal).asInstanceOf[Decimal]
+  }
+
+  override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = {
+    row(ordinal) = value
+  }
+}
+
+private[sql] object FIXED_DECIMAL {
+  val defaultSize = 8
+}

private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
    typeId: Int,
    defaultSize: Int)
@@ -394,7 +421,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType](
  }
}

-private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](10, 16) {
+private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) {
  override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
    row(ordinal) = value
  }
@@ -405,7 +432,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](10, 16)
// Used to process generic objects (all types other than those listed above). Objects should be
// serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized
// byte array.
-private[sql] object GENERIC extends ByteArrayColumnType[DataType](11, 16) {
+private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) {
  override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = {
    row(ordinal) = SparkSqlSerializer.deserialize[Any](value)
  }
@@ -416,18 +443,20 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](11, 16) {
private[sql] object ColumnType {
  def apply(dataType: DataType): ColumnType[_, _] = {
    dataType match {
-      case IntegerType => INT
-      case LongType => LONG
-      case FloatType => FLOAT
-      case DoubleType => DOUBLE
-      case BooleanType => BOOLEAN
-      case ByteType => BYTE
-      case ShortType => SHORT
-      case StringType => STRING
-      case BinaryType => BINARY
-      case DateType => DATE
+      case IntegerType   => INT
+      case LongType      => LONG
+      case FloatType     => FLOAT
+      case DoubleType    => DOUBLE
+      case BooleanType   => BOOLEAN
+      case ByteType      => BYTE
+      case ShortType     => SHORT
+      case StringType    => STRING
+      case BinaryType    => BINARY
+      case DateType      => DATE
      case TimestampType => TIMESTAMP
-      case _ => GENERIC
+      case DecimalType.Fixed(precision, scale) if precision < 19 =>
+        FIXED_DECIMAL(precision, scale)
+      case _ => GENERIC
    }
  }
}
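Two things are worth spelling out here. First, `FIXED_DECIMAL` claims type ID 10, which is why `BINARY` and `GENERIC` shift to 11 and 12 above. Second, the `precision < 19` guard exists because the column stores only the unscaled value in a single 8-byte long: an unscaled value with at most 18 digits is at most 999999999999999999, which fits comfortably under `Long.MaxValue` (9223372036854775807, a 19-digit number), while a 19-digit decimal can overflow, so those fall back to `GENERIC` serialization. A round-trip sketch of the `append`/`extract` pair defined above (same package assumption as earlier; `Decimal(unscaled, precision, scale)` and `toUnscaledLong` are the codebase's own API):

package org.apache.spark.sql.columnar

import java.nio.{ByteBuffer, ByteOrder}

import org.apache.spark.sql.types.Decimal

object FixedDecimalRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val columnType = FIXED_DECIMAL(15, 10)
    val original = Decimal(12345L, 15, 10)

    // append writes exactly defaultSize = 8 bytes: just the unscaled long.
    val buffer = ByteBuffer.allocate(FIXED_DECIMAL.defaultSize).order(ByteOrder.nativeOrder)
    columnType.append(original, buffer)
    buffer.rewind()

    // extract reads the long back and reattaches the column's precision and scale.
    val restored = columnType.extract(buffer)
    assert(restored == original)
  }
}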
==== InMemoryColumnarTableScan.scala ====

@@ -113,7 +113,7 @@ private[sql] case class InMemoryRelation(
    val columnBuilders = output.map { attribute =>
      val columnType = ColumnType(attribute.dataType)
      val initialBufferSize = columnType.defaultSize * batchSize
-      ColumnBuilder(columnType.typeId, initialBufferSize, attribute.name, useCompression)
+      ColumnBuilder(attribute.dataType, initialBufferSize, attribute.name, useCompression)
    }.toArray

    var rowCount = 0
@@ -274,8 +274,10 @@ private[sql] case class InMemoryColumnarTableScan(
    def cachedBatchesToRows(cacheBatches: Iterator[CachedBatch]) = {
      val rows = cacheBatches.flatMap { cachedBatch =>
        // Build column accessors
-        val columnAccessors = requestedColumnIndices.map { batch =>
-          ColumnAccessor(ByteBuffer.wrap(cachedBatch.buffers(batch)))
+        val columnAccessors = requestedColumnIndices.map { batchColumnIndex =>
+          ColumnAccessor(
+            relation.output(batchColumnIndex).dataType,
+            ByteBuffer.wrap(cachedBatch.buffers(batchColumnIndex)))
        }

        // Extract rows via column accessors
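Putting the two call sites together: `InMemoryRelation` keys each builder on the attribute's `DataType` when a batch is built, and `cachedBatchesToRows` must supply the same `DataType` again when reading the batch back, because the 4-byte type ID header is no longer interpreted. A round-trip sketch, again assuming the `org.apache.spark.sql.columnar` package and the accessor's iterator-style `hasNext`/`extractTo` methods from this codebase (exact signatures assumed):

package org.apache.spark.sql.columnar

import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types.{Decimal, DecimalType}

object BuildThenScanSketch {
  def main(args: Array[String]): Unit = {
    val dataType = DecimalType(15, 10)

    // Write side, as in InMemoryRelation: one builder per output attribute.
    val builder = ColumnBuilder(dataType, initialSize = 16, columnName = "price")
    val row = new GenericMutableRow(1)
    (1 to 3).foreach { i =>
      row(0) = Decimal(i.toLong, 15, 10)
      builder.appendFrom(row, 0)
    }

    // Read side, as in cachedBatchesToRows: the data type is passed in explicitly.
    val accessor = ColumnAccessor(dataType, builder.build())
    val out = new GenericMutableRow(1)
    while (accessor.hasNext) {
      accessor.extractTo(out, 0)
      println(out(0)) // the three appended decimals, in insertion order
    }
  }
}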
==== ColumnStatsSuite.scala ====

@@ -29,6 +29,7 @@ class ColumnStatsSuite extends FunSuite {
  testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0))
  testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0))
  testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0))
+  testColumnStats(classOf[FixedDecimalColumnStats], FIXED_DECIMAL(15, 10), Row(null, null, 0))
  testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0))
  testColumnStats(classOf[DateColumnStats], DATE, Row(Int.MaxValue, Int.MinValue, 0))
  testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0))
==== ColumnTypeSuite.scala ====

@@ -33,8 +33,9 @@ class ColumnTypeSuite extends FunSuite with Logging {

test("defaultSize") {
val checks = Map(
INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, BOOLEAN -> 1,
STRING -> 8, DATE -> 4, TIMESTAMP -> 12, BINARY -> 16, GENERIC -> 16)
INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4,
FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 12,
BINARY -> 16, GENERIC -> 16)

    checks.foreach { case (columnType, expectedSize) =>
      assertResult(expectedSize, s"Wrong defaultSize for $columnType") {
@@ -56,15 +57,16 @@
    }
  }

-    checkActualSize(INT, Int.MaxValue, 4)
-    checkActualSize(SHORT, Short.MaxValue, 2)
-    checkActualSize(LONG, Long.MaxValue, 8)
-    checkActualSize(BYTE, Byte.MaxValue, 1)
-    checkActualSize(DOUBLE, Double.MaxValue, 8)
-    checkActualSize(FLOAT, Float.MaxValue, 4)
-    checkActualSize(BOOLEAN, true, 1)
-    checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
-    checkActualSize(DATE, 0, 4)
+    checkActualSize(INT, Int.MaxValue, 4)
+    checkActualSize(SHORT, Short.MaxValue, 2)
+    checkActualSize(LONG, Long.MaxValue, 8)
+    checkActualSize(BYTE, Byte.MaxValue, 1)
+    checkActualSize(DOUBLE, Double.MaxValue, 8)
+    checkActualSize(FLOAT, Float.MaxValue, 4)
+    checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8)
+    checkActualSize(BOOLEAN, true, 1)
+    checkActualSize(STRING, "hello", 4 + "hello".getBytes("utf-8").length)
+    checkActualSize(DATE, 0, 4)
    checkActualSize(TIMESTAMP, new Timestamp(0L), 12)

    val binary = Array.fill[Byte](4)(0: Byte)

@@ -93,12 +95,20 @@

  testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble)

+  testNativeColumnType[DecimalType](
+    FIXED_DECIMAL(15, 10),
+    (buffer: ByteBuffer, decimal: Decimal) => {
+      buffer.putLong(decimal.toUnscaledLong)
+    },
+    (buffer: ByteBuffer) => {
+      Decimal(buffer.getLong(), 15, 10)
+    })

  testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat)

  testNativeColumnType[StringType.type](
    STRING,
    (buffer: ByteBuffer, string: String) => {
-
      val bytes = string.getBytes("utf-8")
      buffer.putInt(bytes.length)
      buffer.put(bytes)
@@ -206,4 +216,16 @@ class ColumnTypeSuite extends FunSuite with Logging {
    if (sb.nonEmpty) sb.setLength(sb.length - 1)
    sb.toString()
  }

test("column type for decimal types with different precision") {
(1 to 18).foreach { i =>
assertResult(FIXED_DECIMAL(i, 0)) {
ColumnType(DecimalType(i, 0))
}
}

assertResult(GENERIC) {
ColumnType(DecimalType(19, 0))
}
}
}
==== ColumnarTestUtils.scala ====

@@ -24,7 +24,7 @@ import scala.util.Random

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.types.{DataType, NativeType}
+import org.apache.spark.sql.types.{Decimal, DataType, NativeType}

object ColumnarTestUtils {
  def makeNullRow(length: Int) = {
@@ -41,16 +41,17 @@
    }

    (columnType match {
-      case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
-      case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
-      case INT => Random.nextInt()
-      case LONG => Random.nextLong()
-      case FLOAT => Random.nextFloat()
-      case DOUBLE => Random.nextDouble()
-      case STRING => Random.nextString(Random.nextInt(32))
-      case BOOLEAN => Random.nextBoolean()
-      case BINARY => randomBytes(Random.nextInt(32))
-      case DATE => Random.nextInt()
+      case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte
+      case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort
+      case INT => Random.nextInt()
+      case LONG => Random.nextLong()
+      case FLOAT => Random.nextFloat()
+      case DOUBLE => Random.nextDouble()
+      case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
+      case STRING => Random.nextString(Random.nextInt(32))
+      case BOOLEAN => Random.nextBoolean()
+      case BINARY => randomBytes(Random.nextInt(32))
+      case DATE => Random.nextInt()
      case TIMESTAMP =>
        val timestamp = new Timestamp(Random.nextLong())
        timestamp.setNanos(Random.nextInt(999999999))