ColumnBuilder.scala

@@ -85,7 +85,7 @@ private[columnar] class BasicColumnBuilder[JvmType](
 }
 
 private[columnar] class NullColumnBuilder
-  extends BasicColumnBuilder[Any](new ObjectColumnStats(NullType), NULL)
+  extends BasicColumnBuilder[Any](new NullColumnStats, NULL)
   with NullableColumnBuilder
 
 private[columnar] abstract class ComplexColumnBuilder[JvmType](
@@ -132,13 +132,13 @@ private[columnar] class DecimalColumnBuilder(dataType: DecimalType)
   extends ComplexColumnBuilder(new DecimalColumnStats(dataType), LARGE_DECIMAL(dataType))
 
 private[columnar] class StructColumnBuilder(dataType: StructType)
-  extends ComplexColumnBuilder(new ObjectColumnStats(dataType), STRUCT(dataType))
+  extends ComplexColumnBuilder(new StructColumnStats(dataType), STRUCT(dataType))
 
 private[columnar] class ArrayColumnBuilder(dataType: ArrayType)
-  extends ComplexColumnBuilder(new ObjectColumnStats(dataType), ARRAY(dataType))
+  extends ComplexColumnBuilder(new ArrayColumnStats(dataType), ARRAY(dataType))
 
 private[columnar] class MapColumnBuilder(dataType: MapType)
-  extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType))
+  extends ComplexColumnBuilder(new MapColumnStats(dataType), MAP(dataType))
 
 private[columnar] object ColumnBuilder {
   val DEFAULT_INITIAL_BUFFER_SIZE = 128 * 1024
ColumnStats.scala

@@ -18,7 +18,9 @@
 package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -80,7 +82,7 @@ private[columnar] final class NoopColumnStats extends ColumnStats {
     if (!row.isNullAt(ordinal)) {
       count += 1
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }

@@ -96,7 +98,7 @@ private[columnar] final class BooleanColumnStats extends ColumnStats {
       val value = row.getBoolean(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }

@@ -120,7 +122,7 @@ private[columnar] final class ByteColumnStats extends ColumnStats {
       val value = row.getByte(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }

@@ -144,7 +146,7 @@ private[columnar] final class ShortColumnStats extends ColumnStats {
       val value = row.getShort(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }

@@ -168,7 +170,7 @@ private[columnar] final class IntColumnStats extends ColumnStats {
       val value = row.getInt(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }

@@ -192,7 +194,7 @@ private[columnar] final class LongColumnStats extends ColumnStats {
       val value = row.getLong(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }

@@ -216,7 +218,7 @@ private[columnar] final class FloatColumnStats extends ColumnStats {
       val value = row.getFloat(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }

@@ -240,7 +242,7 @@ private[columnar] final class DoubleColumnStats extends ColumnStats {
       val value = row.getDouble(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }

@@ -265,7 +267,7 @@ private[columnar] final class StringColumnStats extends ColumnStats {
       val size = STRING.actualSize(row, ordinal)
       gatherValueStats(value, size)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }

@@ -287,7 +289,7 @@ private[columnar] final class BinaryColumnStats extends ColumnStats {
       sizeInBytes += size
       count += 1
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }

@@ -307,7 +309,7 @@ private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
       // TODO: this is not right for DecimalType with precision > 18
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()

Member: I don't think the change to gatherNullStats is necessary...

Author: Yeah, this was mostly from the Scala style guide, since the method mutates the backing stats: http://docs.scala-lang.org/style/method-invocation.html#arity-0
I don't have a strong opinion though, so I'm happy to swap it back.

Member: Let's just swap it back to keep the diff small.
     }
   }

@@ -322,19 +324,68 @@ private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
     Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] final class ObjectColumnStats(dataType: DataType) extends ColumnStats {
-  val columnType = ColumnType(dataType)
+private abstract class OrderableSafeColumnStats[T](dataType: DataType) extends ColumnStats {

Member: OrderableObjectColumnStats?

+  protected var upper: T = _
+  protected var lower: T = _
+
+  private val columnType = ColumnType(dataType)
+  private val ordering = dataType match {
+    case x if RowOrdering.isOrderable(dataType) =>
+      Option(TypeUtils.getInterpretedOrdering(x))
+    case _ => None

Member: Since this class is only for "orderable", maybe we don't need an Option here and ordering can just be Ordering[T].

Author: This is for DataTypes that could be orderable, since Arrays and Structs may have child data types that aren't.
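
A quick sketch of that distinction, using RowOrdering.isOrderable as this Spark version defines it (maps are not orderable, and container orderability is checked recursively over child types):

    import org.apache.spark.sql.catalyst.expressions.RowOrdering
    import org.apache.spark.sql.types._

    RowOrdering.isOrderable(ArrayType(IntegerType))                       // true: ordering is Some(...)
    RowOrdering.isOrderable(ArrayType(MapType(IntegerType, StringType)))  // false: ordering is None

So an ArrayColumnStats or StructColumnStats instance may legitimately run with ordering == None, in which case it collects only null, count, and size statistics.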

+  }
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
-      val size = columnType.actualSize(row, ordinal)
-      sizeInBytes += size
+      sizeInBytes += columnType.actualSize(row, ordinal)
       count += 1
+      ordering.foreach { order =>

Member (kiszk, Jan 11, 2019): Do we have more than one element in ordering? If not, can we write this without foreach? That could achieve better performance.

+        val value = getValue(row, ordinal)

Member: nit: Can we move this statement out of the foreach, since it is loop-invariant?
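
Taken together, the two suggestions amount to replacing the Option closure with an explicit branch and hoisting the value lookup; a sketch of the proposed shape (behaviorally equivalent, not part of the patch):

    if (ordering.isDefined) {
      val order = ordering.get
      val value = getValue(row, ordinal)
      if (upper == null || order.gt(value, upper)) upper = copy(value)
      if (lower == null || order.lt(value, lower)) lower = copy(value)
    }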

+        if (upper == null || order.gt(value, upper)) upper = copy(value)
+        if (lower == null || order.lt(value, lower)) lower = copy(value)
+      }
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
+  def getValue(row: InternalRow, ordinal: Int): T
+
+  def copy(value: T): T
+
   override def collectedStatistics: Array[Any] =
     Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }

+private[columnar] final class ArrayColumnStats(dataType: ArrayType)
+  extends OrderableSafeColumnStats[UnsafeArrayData](dataType) {
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeArrayData =
+    row.getArray(ordinal).asInstanceOf[UnsafeArrayData]
+
+  override def copy(value: UnsafeArrayData): UnsafeArrayData = value.copy()
+}
+
+private[columnar] final class StructColumnStats(dataType: StructType)
+  extends OrderableSafeColumnStats[UnsafeRow](dataType) {
+  private val numFields = dataType.asInstanceOf[StructType].fields.length
+
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeRow =
+    row.getStruct(ordinal, numFields).asInstanceOf[UnsafeRow]
+
+  override def copy(value: UnsafeRow): UnsafeRow = value.copy()
+}
+
+private[columnar] final class MapColumnStats(dataType: MapType)
+  extends OrderableSafeColumnStats[UnsafeMapData](dataType) {
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeMapData =
+    row.getMap(ordinal).asInstanceOf[UnsafeMapData]
+
+  override def copy(value: UnsafeMapData): UnsafeMapData = value.copy()
+}
+
+private[columnar] final class NullColumnStats extends ColumnStats {
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = gatherNullStats()
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](null, null, nullCount, count, sizeInBytes)
+}
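
A detail worth calling out in these subclasses: the copy(value) calls matter because the UnsafeArrayData, UnsafeRow, and UnsafeMapData objects returned by the InternalRow accessors can be mutable views that the caller reuses across rows, so storing the raw reference would let a later row silently rewrite the recorded bounds. A sketch of the hazard (mirroring the reuse tests added in the suite below):

    import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
    import org.apache.spark.sql.types._

    val reused = UnsafeArrayData.fromPrimitiveArray(Array(1))
    val stats = new ArrayColumnStats(ArrayType(IntegerType))
    stats.gatherStats(new GenericInternalRow(Array[Any](reused)), 0)
    reused.setInt(0, 10)  // without copy() in gatherStats, the stored bounds would now read [10]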
ColumnStatsSuite.scala

@@ -18,18 +18,56 @@
 package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData, UnsafeProjection}
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 class ColumnStatsSuite extends SparkFunSuite {
-  testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0))
-  testColumnStats(classOf[ByteColumnStats], BYTE, Array(Byte.MaxValue, Byte.MinValue, 0))
-  testColumnStats(classOf[ShortColumnStats], SHORT, Array(Short.MaxValue, Short.MinValue, 0))
-  testColumnStats(classOf[IntColumnStats], INT, Array(Int.MaxValue, Int.MinValue, 0))
-  testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0))
-  testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0))
-  testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0))
-  testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0))
-  testDecimalColumnStats(Array(null, null, 0))
+  testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0, 0, 0))

Member: Those changes to testColumnStats seem unnecessary?

Author: The column statistics have 5 fields in their array, so the zip comparison on the initial stats will silently drop the final two.
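
The zip behavior in question: zipping truncates to the shorter collection, so a 3-element expected array would leave the count and sizeInBytes fields unchecked. A one-line illustration:

    Array[Any](null, null, 0, 0, 0).zip(Array[Any](null, null, 0)).length  // 3: the last two stats are never compared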

+  testColumnStats(classOf[ByteColumnStats], BYTE, Array(Byte.MaxValue, Byte.MinValue, 0, 0, 0))
+  testColumnStats(classOf[ShortColumnStats], SHORT, Array(Short.MaxValue, Short.MinValue, 0, 0, 0))
+  testColumnStats(classOf[IntColumnStats], INT, Array(Int.MaxValue, Int.MinValue, 0, 0, 0))
+  testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0, 0, 0))
+  testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0, 0, 0))
+  testColumnStats(
+    classOf[DoubleColumnStats], DOUBLE,
+    Array(Double.MaxValue, Double.MinValue, 0, 0, 0)
+  )
+  testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0, 0, 0))
+  testDecimalColumnStats(Array(null, null, 0, 0, 0))
+
+  private val orderableArrayDataType = ArrayType(IntegerType)
+  testOrderableColumnStats(
+    orderableArrayDataType,
+    () => new ArrayColumnStats(orderableArrayDataType),
+    ARRAY(orderableArrayDataType),
+    orderable = true,
+    Array(null, null, 0, 0, 0)
+  )
+
+  private val unorderableArrayDataType = ArrayType(MapType(IntegerType, StringType))
+  testOrderableColumnStats(
+    unorderableArrayDataType,
+    () => new ArrayColumnStats(unorderableArrayDataType),
+    ARRAY(unorderableArrayDataType),
+    orderable = false,
+    Array(null, null, 0, 0, 0)
+  )
+
+  private val structDataType = StructType(Array(StructField("test", DataTypes.StringType)))
+  testOrderableColumnStats(
+    structDataType,
+    () => new StructColumnStats(structDataType),
+    STRUCT(structDataType),
+    orderable = true,
+    Array(null, null, 0, 0, 0)
+  )
+  testMapColumnStats(
+    MapType(IntegerType, StringType),
+    Array(null, null, 0, 0, 0)
+  )
 
   def testColumnStats[T <: AtomicType, U <: ColumnStats](
       columnStatsClass: Class[U],

@@ -103,4 +141,108 @@ class ColumnStatsSuite extends SparkFunSuite {
       }
     }
   }

+  def testOrderableColumnStats[T](
+      dataType: DataType,
+      statsSupplier: () => OrderableSafeColumnStats[T],
+      columnType: ColumnType[T],
+      orderable: Boolean,
+      initialStatistics: Array[Any]): Unit = {
+
+    test(s"${dataType.typeName}, $orderable: empty") {
+      val objectStats = statsSupplier()
+      objectStats.collectedStatistics.zip(initialStatistics).foreach {
+        case (actual, expected) => assert(actual === expected)
+      }
+    }
+
+    test(s"${dataType.typeName}, $orderable: non-empty") {
+      import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
+      val objectStats = statsSupplier()
+      val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
+      rows.foreach(objectStats.gatherStats(_, 0))
+
+      val stats = objectStats.collectedStatistics
+      if (orderable) {
+        val values = rows.take(10).map(_.get(0, columnType.dataType))
+        val ordering = TypeUtils.getInterpretedOrdering(dataType)
+
+        assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+        assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+      } else {
+        assertResult(null, "Wrong lower bound")(stats(0))
+        assertResult(null, "Wrong upper bound")(stats(1))
+      }
+      assertResult(10, "Wrong null count")(stats(2))
+      assertResult(20, "Wrong row count")(stats(3))
+      assertResult(stats(4), "Wrong size in bytes") {
+        rows.map { row =>
+          if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+        }.sum
+      }
+    }
+  }

+  def testMapColumnStats(dataType: MapType, initialStatistics: Array[Any]): Unit = {
+    val columnType = ColumnType(dataType)
+
+    test(s"${dataType.typeName}: empty") {
+      val objectStats = new MapColumnStats(dataType)
+      objectStats.collectedStatistics.zip(initialStatistics).foreach {
+        case (actual, expected) => assert(actual === expected)
+      }
+    }
+
+    test(s"${dataType.typeName}: non-empty") {
+      import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
+      val objectStats = new MapColumnStats(dataType)
+      val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
+      rows.foreach(objectStats.gatherStats(_, 0))
+
+      val stats = objectStats.collectedStatistics
+      assertResult(null, "Wrong lower bound")(stats(0))
+      assertResult(null, "Wrong upper bound")(stats(1))
+      assertResult(10, "Wrong null count")(stats(2))
+      assertResult(20, "Wrong row count")(stats(3))
+      assertResult(stats(4), "Wrong size in bytes") {
+        rows.map { row =>
+          if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+        }.sum
+      }
+    }
+  }

test("Reuse UnsafeArrayData for stats") {

Member: We should also test against UnsafeRow too.

+    val stats = new ArrayColumnStats(ArrayType(IntegerType))
+    val unsafeData = UnsafeArrayData.fromPrimitiveArray(Array(1))
+    (1 to 10).foreach { value =>
+      val row = new GenericInternalRow(Array[Any](unsafeData))
+      unsafeData.setInt(0, value)
+      stats.gatherStats(row, 0)
+    }
+    val collected = stats.collectedStatistics
+    assertResult(UnsafeArrayData.fromPrimitiveArray(Array(1)))(collected(0))
+    assertResult(UnsafeArrayData.fromPrimitiveArray(Array(10)))(collected(1))
+    assertResult(0)(collected(2))
+    assertResult(10)(collected(3))
+    assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4))
+  }

test("Reuse UnsafeRow for stats") {
val structType = StructType(Array(StructField("int", IntegerType)))
val stats = new StructColumnStats(structType)
val converter = UnsafeProjection.create(structType)
val unsafeData = converter(InternalRow(1))
(1 to 10).foreach { value =>
val row = new GenericInternalRow(Array[Any](unsafeData))
unsafeData.setInt(0, value)
stats.gatherStats(row, 0)
}
val collected = stats.collectedStatistics
assertResult(converter(InternalRow(1)))(collected(0))
assertResult(converter(InternalRow(10)))(collected(1))
assertResult(0)(collected(2))
assertResult(10)(collected(3))
assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4))
}
 }