diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
index d30655e0c4a2..48434d8d86a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
@@ -85,7 +85,7 @@ private[columnar] class BasicColumnBuilder[JvmType](
 }
 
 private[columnar] class NullColumnBuilder
-  extends BasicColumnBuilder[Any](new ObjectColumnStats(NullType), NULL)
+  extends BasicColumnBuilder[Any](new NullColumnStats, NULL)
   with NullableColumnBuilder
 
 private[columnar] abstract class ComplexColumnBuilder[JvmType](
@@ -132,13 +132,13 @@ private[columnar] class DecimalColumnBuilder(dataType: DecimalType)
   extends ComplexColumnBuilder(new DecimalColumnStats(dataType), LARGE_DECIMAL(dataType))
 
 private[columnar] class StructColumnBuilder(dataType: StructType)
-  extends ComplexColumnBuilder(new ObjectColumnStats(dataType), STRUCT(dataType))
+  extends ComplexColumnBuilder(new StructColumnStats(dataType), STRUCT(dataType))
 
 private[columnar] class ArrayColumnBuilder(dataType: ArrayType)
-  extends ComplexColumnBuilder(new ObjectColumnStats(dataType), ARRAY(dataType))
+  extends ComplexColumnBuilder(new ArrayColumnStats(dataType), ARRAY(dataType))
 
 private[columnar] class MapColumnBuilder(dataType: MapType)
-  extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType))
+  extends ComplexColumnBuilder(new MapColumnStats(dataType), MAP(dataType))
 
 private[columnar] object ColumnBuilder {
   val DEFAULT_INITIAL_BUFFER_SIZE = 128 * 1024
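Background sketch (illustrative, not part of the patch): each builder above owns a `ColumnStats` instance and feeds it one row at a time while a batch is built, and `collectedStatistics` is always the five-element array `(lower, upper, nullCount, count, sizeInBytes)`. A minimal sketch of that contract, assuming it is run from inside the `org.apache.spark.sql.execution.columnar` package so the package-private stats classes are visible:

```scala
// Sketch only: IntColumnStats and the statistics layout come from ColumnStats.scala below.
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow

val stats = new IntColumnStats
Seq(3, 1, 2).foreach { v =>
  stats.gatherStats(new GenericInternalRow(Array[Any](v)), 0)  // one non-null value per row
}
stats.gatherStats(new GenericInternalRow(Array[Any](null)), 0) // a null row still bumps count

// Layout is (lower, upper, nullCount, count, sizeInBytes).
val Array(lower, upper, nullCount, count, sizeInBytes) = stats.collectedStatistics
assert(lower == 1 && upper == 3 && nullCount == 1 && count == 4)
```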
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
index bc7e73ae1ba8..f054d21860f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
@@ -18,7 +18,9 @@
 package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData, UnsafeRow}
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -80,7 +82,7 @@ private[columnar] final class NoopColumnStats extends ColumnStats {
     if (!row.isNullAt(ordinal)) {
       count += 1
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -96,7 +98,7 @@ private[columnar] final class BooleanColumnStats extends ColumnStats {
       val value = row.getBoolean(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -120,7 +122,7 @@ private[columnar] final class ByteColumnStats extends ColumnStats {
       val value = row.getByte(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -144,7 +146,7 @@ private[columnar] final class ShortColumnStats extends ColumnStats {
       val value = row.getShort(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -168,7 +170,7 @@ private[columnar] final class IntColumnStats extends ColumnStats {
       val value = row.getInt(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -192,7 +194,7 @@ private[columnar] final class LongColumnStats extends ColumnStats {
       val value = row.getLong(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -216,7 +218,7 @@ private[columnar] final class FloatColumnStats extends ColumnStats {
       val value = row.getFloat(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -240,7 +242,7 @@ private[columnar] final class DoubleColumnStats extends ColumnStats {
       val value = row.getDouble(ordinal)
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -265,7 +267,7 @@ private[columnar] final class StringColumnStats extends ColumnStats {
       val size = STRING.actualSize(row, ordinal)
       gatherValueStats(value, size)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -287,7 +289,7 @@ private[columnar] final class BinaryColumnStats extends ColumnStats {
       sizeInBytes += size
       count += 1
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -307,7 +309,7 @@ private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) ext
       // TODO: this is not right for DecimalType with precision > 18
       gatherValueStats(value)
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
@@ -322,19 +324,68 @@ private[columnar] final class DecimalColumnStats(precision: Int, scale: Int) ext
     Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] final class ObjectColumnStats(dataType: DataType) extends ColumnStats {
-  val columnType = ColumnType(dataType)
+private abstract class OrderableSafeColumnStats[T](dataType: DataType) extends ColumnStats {
+  protected var upper: T = _
+  protected var lower: T = _
+
+  private val columnType = ColumnType(dataType)
+  private val ordering = dataType match {
+    case x if RowOrdering.isOrderable(dataType) =>
+      Option(TypeUtils.getInterpretedOrdering(x))
+    case _ => None
+  }
 
   override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
     if (!row.isNullAt(ordinal)) {
-      val size = columnType.actualSize(row, ordinal)
-      sizeInBytes += size
+      sizeInBytes += columnType.actualSize(row, ordinal)
       count += 1
+      ordering.foreach { order =>
+        val value = getValue(row, ordinal)
+        if (upper == null || order.gt(value, upper)) upper = copy(value)
+        if (lower == null || order.lt(value, lower)) lower = copy(value)
+      }
     } else {
-      gatherNullStats
+      gatherNullStats()
     }
   }
 
+  def getValue(row: InternalRow, ordinal: Int): T
+
+  def copy(value: T): T
+
+  override def collectedStatistics: Array[Any] =
+    Array[Any](lower, upper, nullCount, count, sizeInBytes)
+}
+
+private[columnar] final class ArrayColumnStats(dataType: ArrayType)
+  extends OrderableSafeColumnStats[UnsafeArrayData](dataType) {
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeArrayData =
+    row.getArray(ordinal).asInstanceOf[UnsafeArrayData]
+
+  override def copy(value: UnsafeArrayData): UnsafeArrayData = value.copy()
+}
+
+private[columnar] final class StructColumnStats(dataType: StructType)
+  extends OrderableSafeColumnStats[UnsafeRow](dataType) {
+  private val numFields = dataType.asInstanceOf[StructType].fields.length
+
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeRow =
+    row.getStruct(ordinal, numFields).asInstanceOf[UnsafeRow]
+
+  override def copy(value: UnsafeRow): UnsafeRow = value.copy()
+}
+
+private[columnar] final class MapColumnStats(dataType: MapType)
+  extends OrderableSafeColumnStats[UnsafeMapData](dataType) {
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeMapData =
+    row.getMap(ordinal).asInstanceOf[UnsafeMapData]
+
+  override def copy(value: UnsafeMapData): UnsafeMapData = value.copy()
+}
+
+private[columnar] final class NullColumnStats extends ColumnStats {
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = gatherNullStats()
+
   override def collectedStatistics: Array[Any] =
     Array[Any](null, null, nullCount, count, sizeInBytes)
 }
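Background sketch (illustrative, not part of the patch): `OrderableSafeColumnStats` only tracks lower/upper bounds when `RowOrdering.isOrderable` says an interpreted ordering exists for the type, which is why `MapColumnStats` (and arrays whose elements are maps) always report `null` bounds. A small sketch using the same catalyst helpers the new class relies on; the concrete values are only examples:

```scala
import org.apache.spark.sql.catalyst.expressions.{RowOrdering, UnsafeArrayData}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._

// Orderability decides whether bounds are gathered at all.
assert(RowOrdering.isOrderable(ArrayType(IntegerType)))                       // arrays of ints compare element-wise
assert(!RowOrdering.isOrderable(MapType(IntegerType, StringType)))            // maps are never orderable
assert(!RowOrdering.isOrderable(ArrayType(MapType(IntegerType, StringType)))) // element type is unorderable

// For orderable types the bounds are maintained with the interpreted ordering,
// the same ordering the test suite below uses to compute expected min/max.
val ordering = TypeUtils.getInterpretedOrdering(ArrayType(IntegerType))
val smaller = UnsafeArrayData.fromPrimitiveArray(Array(1, 2))
val larger = UnsafeArrayData.fromPrimitiveArray(Array(1, 3))
assert(ordering.lt(smaller, larger))
```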
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
index d4e7e362c6c8..398f00b9395b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala
@@ -18,18 +18,56 @@
 package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData, UnsafeProjection}
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 class ColumnStatsSuite extends SparkFunSuite {
-  testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0))
-  testColumnStats(classOf[ByteColumnStats], BYTE, Array(Byte.MaxValue, Byte.MinValue, 0))
-  testColumnStats(classOf[ShortColumnStats], SHORT, Array(Short.MaxValue, Short.MinValue, 0))
-  testColumnStats(classOf[IntColumnStats], INT, Array(Int.MaxValue, Int.MinValue, 0))
-  testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0))
-  testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0))
-  testColumnStats(classOf[DoubleColumnStats], DOUBLE, Array(Double.MaxValue, Double.MinValue, 0))
-  testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0))
-  testDecimalColumnStats(Array(null, null, 0))
+  testColumnStats(classOf[BooleanColumnStats], BOOLEAN, Array(true, false, 0, 0, 0))
+  testColumnStats(classOf[ByteColumnStats], BYTE, Array(Byte.MaxValue, Byte.MinValue, 0, 0, 0))
+  testColumnStats(classOf[ShortColumnStats], SHORT, Array(Short.MaxValue, Short.MinValue, 0, 0, 0))
+  testColumnStats(classOf[IntColumnStats], INT, Array(Int.MaxValue, Int.MinValue, 0, 0, 0))
+  testColumnStats(classOf[LongColumnStats], LONG, Array(Long.MaxValue, Long.MinValue, 0, 0, 0))
+  testColumnStats(classOf[FloatColumnStats], FLOAT, Array(Float.MaxValue, Float.MinValue, 0, 0, 0))
+  testColumnStats(
+    classOf[DoubleColumnStats], DOUBLE,
+    Array(Double.MaxValue, Double.MinValue, 0, 0, 0)
+  )
+  testColumnStats(classOf[StringColumnStats], STRING, Array(null, null, 0, 0, 0))
+  testDecimalColumnStats(Array(null, null, 0, 0, 0))
+
+  private val orderableArrayDataType = ArrayType(IntegerType)
+  testOrderableColumnStats(
+    orderableArrayDataType,
+    () => new ArrayColumnStats(orderableArrayDataType),
+    ARRAY(orderableArrayDataType),
+    orderable = true,
+    Array(null, null, 0, 0, 0)
+  )
+
+  private val unorderableArrayDataType = ArrayType(MapType(IntegerType, StringType))
+  testOrderableColumnStats(
+    unorderableArrayDataType,
+    () => new ArrayColumnStats(unorderableArrayDataType),
+    ARRAY(unorderableArrayDataType),
+    orderable = false,
+    Array(null, null, 0, 0, 0)
+  )
+
+  private val structDataType = StructType(Array(StructField("test", DataTypes.StringType)))
+  testOrderableColumnStats(
+    structDataType,
+    () => new StructColumnStats(structDataType),
+    STRUCT(structDataType),
+    orderable = true,
+    Array(null, null, 0, 0, 0)
+  )
 
+  testMapColumnStats(
+    MapType(IntegerType, StringType),
+    Array(null, null, 0, 0, 0)
+  )
+
   def testColumnStats[T <: AtomicType, U <: ColumnStats](
       columnStatsClass: Class[U],
@@ -103,4 +141,108 @@ class ColumnStatsSuite extends SparkFunSuite {
       }
     }
   }
+
+  def testOrderableColumnStats[T](
+      dataType: DataType,
+      statsSupplier: () => OrderableSafeColumnStats[T],
+      columnType: ColumnType[T],
+      orderable: Boolean,
+      initialStatistics: Array[Any]): Unit = {
+
+    test(s"${dataType.typeName}, $orderable: empty") {
+      val objectStats = statsSupplier()
+      objectStats.collectedStatistics.zip(initialStatistics).foreach {
+        case (actual, expected) => assert(actual === expected)
+      }
+    }
+
+    test(s"${dataType.typeName}, $orderable: non-empty") {
+      import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
+      val objectStats = statsSupplier()
+      val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
+      rows.foreach(objectStats.gatherStats(_, 0))
+
+      val stats = objectStats.collectedStatistics
+      if (orderable) {
+        val values = rows.take(10).map(_.get(0, columnType.dataType))
+        val ordering = TypeUtils.getInterpretedOrdering(dataType)
+
+        assertResult(values.min(ordering), "Wrong lower bound")(stats(0))
+        assertResult(values.max(ordering), "Wrong upper bound")(stats(1))
+      } else {
+        assertResult(null, "Wrong lower bound")(stats(0))
+        assertResult(null, "Wrong upper bound")(stats(1))
+      }
+      assertResult(10, "Wrong null count")(stats(2))
+      assertResult(20, "Wrong row count")(stats(3))
+      assertResult(stats(4), "Wrong size in bytes") {
+        rows.map { row =>
+          if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+        }.sum
+      }
+    }
+  }
+
+  def testMapColumnStats(dataType: MapType, initialStatistics: Array[Any]): Unit = {
+    val columnType = ColumnType(dataType)
+
+    test(s"${dataType.typeName}: empty") {
+      val objectStats = new MapColumnStats(dataType)
+      objectStats.collectedStatistics.zip(initialStatistics).foreach {
+        case (actual, expected) => assert(actual === expected)
+      }
+    }
+
+    test(s"${dataType.typeName}: non-empty") {
+      import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._
+      val objectStats = new MapColumnStats(dataType)
+      val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1))
+      rows.foreach(objectStats.gatherStats(_, 0))
+
+      val stats = objectStats.collectedStatistics
+      assertResult(null, "Wrong lower bound")(stats(0))
+      assertResult(null, "Wrong upper bound")(stats(1))
+      assertResult(10, "Wrong null count")(stats(2))
+      assertResult(20, "Wrong row count")(stats(3))
+      assertResult(stats(4), "Wrong size in bytes") {
+        rows.map { row =>
+          if (row.isNullAt(0)) 4 else columnType.actualSize(row, 0)
+        }.sum
+      }
+    }
+  }
+
+  test("Reuse UnsafeArrayData for stats") {
+    val stats = new ArrayColumnStats(ArrayType(IntegerType))
+    val unsafeData = UnsafeArrayData.fromPrimitiveArray(Array(1))
+    (1 to 10).foreach { value =>
+      val row = new GenericInternalRow(Array[Any](unsafeData))
+      unsafeData.setInt(0, value)
+      stats.gatherStats(row, 0)
+    }
+    val collected = stats.collectedStatistics
+    assertResult(UnsafeArrayData.fromPrimitiveArray(Array(1)))(collected(0))
+    assertResult(UnsafeArrayData.fromPrimitiveArray(Array(10)))(collected(1))
+    assertResult(0)(collected(2))
+    assertResult(10)(collected(3))
+    assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4))
+  }
+
+  test("Reuse UnsafeRow for stats") {
+    val structType = StructType(Array(StructField("int", IntegerType)))
+    val stats = new StructColumnStats(structType)
+    val converter = UnsafeProjection.create(structType)
+    val unsafeData = converter(InternalRow(1))
+    (1 to 10).foreach { value =>
+      val row = new GenericInternalRow(Array[Any](unsafeData))
+      unsafeData.setInt(0, value)
+      stats.gatherStats(row, 0)
+    }
+    val collected = stats.collectedStatistics
+    assertResult(converter(InternalRow(1)))(collected(0))
+    assertResult(converter(InternalRow(10)))(collected(1))
+    assertResult(0)(collected(2))
+    assertResult(10)(collected(3))
+    assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala
index 686c8fa6f5fa..2647331c3f72 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala
@@ -21,9 +21,9 @@ import scala.collection.immutable.HashSet
 import scala.util.Random
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
-import org.apache.spark.sql.types.{AtomicType, Decimal}
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData, UnsafeMapData, UnsafeProjection}
+import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
+import org.apache.spark.sql.types.{AtomicType, DataType, Decimal, IntegerType, MapType, StringType, StructField, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 
 object ColumnarTestUtils {
@@ -54,12 +54,22 @@
     case COMPACT_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale)
     case LARGE_DECIMAL(precision, scale) => Decimal(Random.nextLong(), precision, scale)
     case STRUCT(_) =>
-      new GenericInternalRow(Array[Any](UTF8String.fromString(Random.nextString(10))))
+      val schema = StructType(Array(StructField("test", StringType)))
+      val converter = UnsafeProjection.create(schema)
+      converter(InternalRow(Array(UTF8String.fromString(Random.nextString(10))): _*))
     case ARRAY(_) =>
-      new GenericArrayData(Array[Any](Random.nextInt(), Random.nextInt()))
+      UnsafeArrayData.fromPrimitiveArray(Array(Random.nextInt(), Random.nextInt()))
     case MAP(_) =>
-      ArrayBasedMapData(
-        Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32)))))
+      val unsafeConverter =
+        UnsafeProjection.create(Array[DataType](MapType(IntegerType, StringType)))
+      val row = new GenericInternalRow(1)
+      def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
+        row.update(0, map)
+        val unsafeRow = unsafeConverter.apply(row)
+        unsafeRow.getMap(0).copy
+      }
+      toUnsafeMap(ArrayBasedMapData(
+        Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32))))))
     case _ => throw new IllegalArgumentException(s"Unknown column type $columnType")
   }).asInstanceOf[JvmType]
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
index 26b63e8e8490..bc5ab3dcd7ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala
@@ -30,7 +30,6 @@ import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel._
-import org.apache.spark.util.Utils
 
 class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -527,4 +526,10 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-23819: Complex type pruning should utilize proper statistics") {
+    val df = Seq((Array(1), (1, 1))).toDF("arr", "struct").cache()
+    assert(df.where("arr <=> array(1)").count() === 1)
+    assert(df.where("struct <=> named_struct('_1', 1, '_2', 1)").count() === 1)
+  }
 }
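End-to-end sketch (illustrative, not part of the patch) of the behaviour the new regression test covers: once a DataFrame with array/struct columns is cached, the per-batch statistics now carry real bounds for orderable complex types instead of the always-null `ObjectColumnStats` bounds, which complex-type pruning over the cached relation can consult. The session setup and `show()` calls below are assumptions for a standalone run, not part of the test:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("complex-type-column-stats")
  .getOrCreate()
import spark.implicits._

// Same shape of data as the new regression test: one array column, one struct column.
val df = Seq((Array(1), (1, 1)), (Array(2), (2, 2))).toDF("arr", "struct").cache()
df.count()  // materialize the cached batches so column statistics are collected

// Null-safe equality filters on complex columns, as in the added test.
df.where("arr <=> array(1)").show()
df.where("struct <=> named_struct('_1', 1, '_2', 1)").show()

spark.stop()
```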