Commit 6ea0919

Author: Patrick Woody
extra test make Map orderable safe
1 parent 1479bde

2 files changed: +29 −18 lines

sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala

Lines changed: 8 additions & 16 deletions
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, RowOrdering}
-import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -357,15 +357,15 @@ private abstract class OrderableSafeColumnStats[T](dataType: DataType) extends C
     Array[Any](lower, upper, nullCount, count, sizeInBytes)
 }
 
-private[columnar] final class ArrayColumnStats(dataType: DataType)
+private[columnar] final class ArrayColumnStats(dataType: ArrayType)
   extends OrderableSafeColumnStats[UnsafeArrayData](dataType) {
   override def getValue(row: InternalRow, ordinal: Int): UnsafeArrayData =
     row.getArray(ordinal).asInstanceOf[UnsafeArrayData]
 
   override def copy(value: UnsafeArrayData): UnsafeArrayData = value.copy()
 }
 
-private[columnar] final class StructColumnStats(dataType: DataType)
+private[columnar] final class StructColumnStats(dataType: StructType)
   extends OrderableSafeColumnStats[UnsafeRow](dataType) {
   private val numFields = dataType.asInstanceOf[StructType].fields.length
 
@@ -375,20 +375,12 @@ private[columnar] final class StructColumnStats(dataType: DataType)
   override def copy(value: UnsafeRow): UnsafeRow = value.copy()
 }
 
-private[columnar] final class MapColumnStats(dataType: DataType) extends ColumnStats {
-  private val columnType = ColumnType(dataType)
-
-  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
-    if (!row.isNullAt(ordinal)) {
-      sizeInBytes += columnType.actualSize(row, ordinal)
-      count += 1
-    } else {
-      gatherNullStats()
-    }
-  }
+private[columnar] final class MapColumnStats(dataType: MapType)
+  extends OrderableSafeColumnStats[UnsafeMapData](dataType) {
+  override def getValue(row: InternalRow, ordinal: Int): UnsafeMapData =
+    row.getMap(ordinal).asInstanceOf[UnsafeMapData]
 
-  override def collectedStatistics: Array[Any] =
-    Array[Any](null, null, nullCount, count, sizeInBytes)
+  override def copy(value: UnsafeMapData): UnsafeMapData = value.copy()
 }
 
 private[columnar] final class NullColumnStats extends ColumnStats {
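
Note: the body of OrderableSafeColumnStats lies outside this diff's context. As a rough sketch only, the hunks above are consistent with a base class shaped like the following; the RowOrdering.isOrderable guard, the interpreted ordering, and the field layout are inferred from the imports and the collectedStatistics context line, not copied from the source:

private abstract class OrderableSafeColumnStats[T](dataType: DataType)
  extends ColumnStats {
  // Typed bounds; subclasses copy() values so a reused buffer cannot mutate them.
  protected var lower: T = _
  protected var upper: T = _

  // Maps are not orderable in Spark SQL, so only gather bounds when the type
  // supports ordering -- plausibly the "orderable safe" in the commit title.
  private val orderable = RowOrdering.isOrderable(dataType)
  private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
  private val columnType = ColumnType(dataType)

  def getValue(row: InternalRow, ordinal: Int): T  // read the unsafe value
  def copy(value: T): T                            // detach it from shared buffers

  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
    if (!row.isNullAt(ordinal)) {
      if (orderable) {
        val value = getValue(row, ordinal)
        if (upper == null || ordering.gt(value, upper)) upper = copy(value)
        if (lower == null || ordering.lt(value, lower)) lower = copy(value)
      }
      sizeInBytes += columnType.actualSize(row, ordinal)
      count += 1
    } else {
      gatherNullStats()
    }
  }

  override def collectedStatistics: Array[Any] =
    Array[Any](lower, upper, nullCount, count, sizeInBytes)
}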

sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnStatsSuite.scala

Lines changed: 21 additions & 2 deletions
@@ -18,7 +18,8 @@
 package org.apache.spark.sql.execution.columnar
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeArrayData, UnsafeProjection}
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
@@ -182,7 +183,7 @@ class ColumnStatsSuite extends SparkFunSuite {
     }
   }
 
-  def testMapColumnStats(dataType: DataType, initialStatistics: Array[Any]): Unit = {
+  def testMapColumnStats(dataType: MapType, initialStatistics: Array[Any]): Unit = {
     val columnType = ColumnType(dataType)
 
     test(s"${dataType.typeName}: empty") {
@@ -226,4 +227,22 @@
     assertResult(10)(collected(3))
     assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4))
   }
+
+  test("Reuse UnsafeRow for stats") {
+    val structType = StructType(Array(StructField("int", IntegerType)))
+    val stats = new StructColumnStats(structType)
+    val converter = UnsafeProjection.create(structType)
+    val unsafeData = converter(InternalRow(1))
+    (1 to 10).foreach { value =>
+      val row = new GenericInternalRow(Array[Any](unsafeData))
+      unsafeData.setInt(0, value)
+      stats.gatherStats(row, 0)
+    }
+    val collected = stats.collectedStatistics
+    assertResult(converter(InternalRow(1)))(collected(0))
+    assertResult(converter(InternalRow(10)))(collected(1))
+    assertResult(0)(collected(2))
+    assertResult(10)(collected(3))
+    assertResult(10 * (4 + unsafeData.getSizeInBytes))(collected(4))
+  }
 }
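
The new test pins down the buffer-reuse hazard these stats classes must handle: UnsafeProjection returns the same backing UnsafeRow on every call and rewrites it in place, so a stats implementation that stored the row without copy() would watch its saved bounds change underneath it. A minimal standalone sketch of that behavior (ReuseDemo is a hypothetical name; it assumes only Spark's catalyst classes on the classpath):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object ReuseDemo extends App {
  val schema = StructType(Array(StructField("int", IntegerType)))
  val project = UnsafeProjection.create(schema)

  val first = project(InternalRow(1))    // returns the projection's result buffer
  val second = project(InternalRow(100)) // rewrites that same buffer in place

  println(first eq second)  // true: both names point at one shared UnsafeRow
  println(first.getInt(0))  // 100: the earlier result was overwritten
  println(project(InternalRow(7)).copy().getInt(0)) // 7: copy() detaches the value
}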
