[SPARK-23819][SQL] Fix InMemoryTableScanExec complex type pruning #20935
Changes from all commits: 83e1e53, 5c95cef, a63eb59, 426374b, 1479bde, 6ea0919
ColumnStats.scala

@@ -18,7 +18,9 @@
 package org.apache.spark.sql.execution.columnar

 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -80,7 +82,7 @@ private[columnar] final class NoopColumnStats extends ColumnStats {
     if (!row.isNullAt(ordinal)) {
       count += 1
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
@@ -96,7 +98,7 @@ private[columnar] final class BooleanColumnStats extends ColumnStats {
       val value = row.getBoolean(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
@@ -120,7 +122,7 @@ private[columnar] final class ByteColumnStats extends ColumnStats {
       val value = row.getByte(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
@@ -144,7 +146,7 @@ private[columnar] final class ShortColumnStats extends ColumnStats {
       val value = row.getShort(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
@@ -168,7 +170,7 @@ private[columnar] final class IntColumnStats extends ColumnStats {
       val value = row.getInt(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
@@ -192,7 +194,7 @@ private[columnar] final class LongColumnStats extends ColumnStats {
       val value = row.getLong(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
@@ -216,7 +218,7 @@ private[columnar] final class FloatColumnStats extends ColumnStats {
       val value = row.getFloat(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
@@ -240,7 +242,7 @@ private[columnar] final class DoubleColumnStats extends ColumnStats {
       val value = row.getDouble(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
@@ -265,7 +267,7 @@ private[columnar] final class StringColumnStats extends ColumnStats {
       val size = STRING.actualSize(row, ordinal)
       gatherValueStats(value, size)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
@@ -287,7 +289,7 @@ private[columnar] final class BinaryColumnStats extends ColumnStats {
       sizeInBytes += size
       count += 1
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
@@ -307,7 +309,7 @@ private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) ext
       // TODO: this is not right for DecimalType with precision > 18
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
@@ -322,19 +324,68 @@ private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) ext
     Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }

-private[columnar] final class ObjectColumnStats(dataType: DataType) extends ColumnStats {
-  val columnType = ColumnType(dataType)
+private abstract class OrderableSafeColumnStats[T](dataType: DataType) extends ColumnStats {
+  protected var upper: T = _
+  protected var lower: T = _
+
+  private val columnType = ColumnType(dataType)
+  private val ordering = dataType match {
+    case x if RowOrdering.isOrderable(dataType) =>
+      Option(TypeUtils.getInterpretedOrdering(x))
+    case _ => None

Member: Since this class is only for "orderable", maybe we don't need optional here and …

Author: This is for DataTypes that could be orderable, since Arrays and Structs may have child data types that aren't.
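
A minimal sketch of the distinction being made here (assuming the semantics of Spark's RowOrdering.isOrderable, which the new class relies on above; map types have no ordering in Spark, so any complex type containing one is unorderable):

import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._

// An array of ints is orderable, so the ordering above becomes Some(...)
// and lower/upper bounds are tracked.
RowOrdering.isOrderable(ArrayType(IntegerType))                       // true

// An array of maps is not orderable, so the ordering stays None and only
// null count, row count, and size statistics are gathered.
RowOrdering.isOrderable(ArrayType(MapType(IntegerType, StringType)))  // false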
+  }
+
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
-      val size = columnType.actualSize(row, ordinal)
-      sizeInBytes += size
+      sizeInBytes += columnType.actualSize(row, ordinal)
       count += 1
+      ordering.foreach { order =>

Member: Do we have more than one element in …
+        val value = getValue(row, ordinal)

Member: nit: Can we move this statement out of …
+        if (upper == null || order.gt(value, upper)) upper = copy(value)
+        if (lower == null || order.lt(value, lower)) lower = copy(value)
+      }
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }

+  def getValue(row: InternalRow, ordinal: Int): T
+
+  def copy(value: T): T
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](lower, upper, nullCount, count, sizeInBytes)
+}
+
+private[columnar] final class ArrayColumnStats(dataType: ArrayType)
+  extends OrderableSafeColumnStats[UnsafeArrayData](dataType) {
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeArrayData =
+    row.getArray(ordinal).asInstanceOf[UnsafeArrayData]
+
+  override def copy(value: UnsafeArrayData): UnsafeArrayData = value.copy()
+}
+
+private[columnar] final class StructColumnStats(dataType: StructType)
+  extends OrderableSafeColumnStats[UnsafeRow](dataType) {
+  private val numFields = dataType.asInstanceOf[StructType].fields.length
+
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeRow =
+    row.getStruct(ordinal, numFields).asInstanceOf[UnsafeRow]
+
+  override def copy(value: UnsafeRow): UnsafeRow = value.copy()
+}
+
+private[columnar] final class MapColumnStats(dataType: MapType)
+  extends OrderableSafeColumnStats[UnsafeMapData](dataType) {
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeMapData =
+    row.getMap(ordinal).asInstanceOf[UnsafeMapData]
+
+  override def copy(value: UnsafeMapData): UnsafeMapData = value.copy()
+}
+
+private[columnar] final class NullColumnStats extends ColumnStats {
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = gatherNullStats()
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](null, null, nullCount, count, sizeInBytes)
+}
ColumnStatsSuite.scala

@@ -18,18 +18,56 @@
 package org.apache.spark.sql.execution.columnar

 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData, UnsafeProjection}
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._

 class ColumnStatsSuite extends SparkFunSuite {
-  testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0))
-  testColumnStats(classOf[ByteColumnStats], BYTE, Array(Byte.MaxValue, Byte.MinValue, 0))
-  testColumnStats(classOf[ShortColumnStats], SHORT, Array(Short.MaxValue, Short.MinValue, 0))
-  testColumnStats(classOf[IntColumnStats], INT, Array(Int.MaxValue, Int.MinValue, 0))
-  testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0))
-  testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0))
-  testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0))
-  testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0))
-  testDecimalColumnStats(Array(null, null, 0))
+  testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0, 0, 0))

Member: Those changes to …

Author: The column statistics have 5 fields in their array, so the zip comparison on the initial stats will drop the final two.
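
In other words (a minimal sketch of the layout; runnable only from within the columnar test package, since the stats classes are private[columnar]):

// collectedStatistics layout: Array(lower, upper, nullCount, count, sizeInBytes)
val actual = (new ByteColumnStats).collectedStatistics
val expected = Array[Any](Byte.MaxValue, Byte.MinValue, 0, 0, 0)

// Array#zip truncates to the shorter side, so a 3-element expected array
// would silently skip checking count and sizeInBytes -- hence 5 elements.
actual.zip(expected).foreach { case (a, e) => assert(a == e) }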
+  testColumnStats(classOf[ByteColumnStats], BYTE, Array(Byte.MaxValue, Byte.MinValue, 0, 0, 0))
+  testColumnStats(classOf[ShortColumnStats], SHORT, Array(Short.MaxValue, Short.MinValue, 0, 0, 0))
+  testColumnStats(classOf[IntColumnStats], INT, Array(Int.MaxValue, Int.MinValue, 0, 0, 0))
+  testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0, 0, 0))
+  testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0, 0, 0))
+  testColumnStats(
+    classOf[DoubleColumnStats], DOUBLE,
+    Array(Double.MaxValue, Double.MinValue, 0, 0, 0)
+  )
+  testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0, 0, 0))
+  testDecimalColumnStats(Array(null, null, 0, 0, 0))
+
+  private val orderableArrayDataType = ArrayType(IntegerType)
+  testOrderableColumnStats(
+    orderableArrayDataType,
+    () => new ArrayColumnStats(orderableArrayDataType),
+    ARRAY(orderableArrayDataType),
+    orderable = true,
+    Array(null, null, 0, 0, 0)
+  )
+
+  private val unorderableArrayDataType = ArrayType(MapType(IntegerType, StringType))
+  testOrderableColumnStats(
+    unorderableArrayDataType,
+    () => new ArrayColumnStats(unorderableArrayDataType),
+    ARRAY(unorderableArrayDataType),
+    orderable = false,
+    Array(null, null, 0, 0, 0)
+  )
+
+  private val structDataType = StructType(Array(StructField("test", DataTypes.StringType)))
+  testOrderableColumnStats(
+    structDataType,
+    () => new StructColumnStats(structDataType),
+    STRUCT(structDataType),
+    orderable = true,
+    Array(null, null, 0, 0, 0)
+  )
+  testMapColumnStats(
+    MapType(IntegerType, StringType),
+    Array(null, null, 0, 0, 0)
+  )

   def testColumnStats[T <: AtomicType, U <: ColumnStats](
       columnStatsClass: Class[U],
@@ -103,4 +141,108 @@ class ColumnStatsSuite extends SparkFunSuite {
       }
     }
   }
+
+  def testOrderableColumnStats[T](
+      dataType: DataType,
+      statsSupplier: () => OrderableSafeColumnStats[T],
+      columnType: ColumnType[T],
+      orderable: Boolean,
+      initialStatistics: Array[Any]): Unit = {
+
+    test(s"${dataType.typeName}, $orderable: empty") {
+      val objectStats = statsSupplier()
+      objectStats.collectedStatistics.zip(initialStatistics).foreach {
+        case (actual, expected) => assert(actual === expected)
+      }
+    }
+
+    test(s"${dataType.typeName}, $orderable: non-empty") {
+      import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
+      val objectStats = statsSupplier()
+      val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
+      rows.foreach(objectStats.gatherStats(_, 0))
+
+      val stats = objectStats.collectedStatistics
+      if (orderable) {
+        val values = rows.take(10).map(_.get(0, columnType.dataType))
+        val ordering = TypeUtils.getInterpretedOrdering(dataType)
+
+        assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+        assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+      } else {
+        assertResult(null, "Wrong lower bound")(stats(0))
+        assertResult(null, "Wrong upper bound")(stats(1))
+      }
+      assertResult(10, "Wrong null count")(stats(2))
+      assertResult(20, "Wrong row count")(stats(3))
+      assertResult(stats(4), "Wrong size in bytes") {
+        rows.map { row =>
+          if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+        }.sum
+      }
+    }
+  }
+
+  def testMapColumnStats(dataType: MapType, initialStatistics: Array[Any]): Unit = {
+    val columnType = ColumnType(dataType)
+
+    test(s"${dataType.typeName}: empty") {
+      val objectStats = new MapColumnStats(dataType)
+      objectStats.collectedStatistics.zip(initialStatistics).foreach {
+        case (actual, expected) => assert(actual === expected)
+      }
+    }
+
+    test(s"${dataType.typeName}: non-empty") {
+      import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
+      val objectStats = new MapColumnStats(dataType)
+      val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
+      rows.foreach(objectStats.gatherStats(_, 0))
+
+      val stats = objectStats.collectedStatistics
+      assertResult(null, "Wrong lower bound")(stats(0))
+      assertResult(null, "Wrong upper bound")(stats(1))
+      assertResult(10, "Wrong null count")(stats(2))
+      assertResult(20, "Wrong row count")(stats(3))
+      assertResult(stats(4), "Wrong size in bytes") {
+        rows.map { row =>
+          if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+        }.sum
+      }
+    }
+  }
+
+  test("Reuse UnsafeArrayData for stats") {
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should also test against UnsafeRow too. |
+    val stats = new ArrayColumnStats(ArrayType(IntegerType))
+    val unsafeData = UnsafeArrayData.fromPrimitiveArray(Array(1))
+    (1 to 10).foreach { value =>
+      val row = new GenericInternalRow(Array[Any](unsafeData))
+      unsafeData.setInt(0, value)
+      stats.gatherStats(row, 0)
+    }
+    val collected = stats.collectedStatistics
+    assertResult(UnsafeArrayData.fromPrimitiveArray(Array(1)))(collected(0))
+    assertResult(UnsafeArrayData.fromPrimitiveArray(Array(10)))(collected(1))
+    assertResult(0)(collected(2))
+    assertResult(10)(collected(3))
+    assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4))
+  }
+
+  test("Reuse UnsafeRow for stats") {
+    val structType = StructType(Array(StructField("int", IntegerType)))
+    val stats = new StructColumnStats(structType)
+    val converter = UnsafeProjection.create(structType)
+    val unsafeData = converter(InternalRow(1))
+    (1 to 10).foreach { value =>
+      val row = new GenericInternalRow(Array[Any](unsafeData))
+      unsafeData.setInt(0, value)
+      stats.gatherStats(row, 0)
+    }
+    val collected = stats.collectedStatistics
+    assertResult(converter(InternalRow(1)))(collected(0))
+    assertResult(converter(InternalRow(10)))(collected(1))
+    assertResult(0)(collected(2))
+    assertResult(10)(collected(3))
+    assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4))
+  }
 }

Member: I don't think the change to `gatherNullStats` is necessary...

Author: Yeah, this was mostly from the Scala style guide, since this method mutates the backing stats: http://docs.scala-lang.org/style/method-invocation.html#arity-0. I don't have a strong opinion though, so happy to swap it back.

Member: Let's just swap it back to make the diff small.
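
For context on the style point (a generic illustration of the linked convention, not code from this PR): the Scala style guide reserves empty parentheses on arity-0 methods for calls that perform side effects.

class Counter {
  private var n = 0

  // Mutates state, so it is declared -- and should be invoked -- with ().
  def increment(): Unit = n += 1

  // Pure accessor: declared and called without parentheses.
  def value: Int = n
}

val c = new Counter
c.increment()  // side effect: parentheses signal mutation
c.value        // pure read: no parentheses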