
Commit b9a0c56

huaxingao authored and viirya committed
[SPARK-36646][SQL] Push down group by partition column for aggregate
### What changes were proposed in this pull request?
Lift the restriction on aggregate push down for Parquet and ORC when the group by columns are the same as the partition columns.

### Why are the changes needed?
Previously, if there were any group by columns, we did not push down aggregates to the data source. After this change, if the group by columns are the same as the partition columns, the aggregates are pushed down.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
New tests.

Closes #34445 from huaxingao/group_by.

Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
1 parent 6450f6b commit b9a0c56
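
For illustration, a hedged usage sketch (not part of the commit): the table `t`, partition column `p`, and data column `c` are made up, and the config keys are the ones referenced by the changed files. The first query groups by exactly the partition column, so after this change its MAX can be answered from file footers; the second still needs a regular scan.

```scala
// Hypothetical session/table; assumes the DS v2 Parquet reader and aggregate push down are enabled.
spark.conf.set("spark.sql.sources.useV1SourceList", "")        // use the v2 Parquet source
spark.conf.set("spark.sql.parquet.aggregatePushdown", "true")  // PARQUET_AGGREGATE_PUSHDOWN_ENABLED

// Pushed down after this change: group by columns == partition columns.
spark.sql("SELECT p, MAX(c) FROM t GROUP BY p").explain()
// Still not pushed down: grouping by a data column (footers carry no per-group statistics).
spark.sql("SELECT c, COUNT(*) FROM t GROUP BY c").explain()
```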

6 files changed: +187 -37 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala

Lines changed: 62 additions & 4 deletions
```diff
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.datasources
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow}
 import org.apache.spark.sql.connector.expressions.NamedReference
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min}
 import org.apache.spark.sql.execution.RowToColumnConverter
@@ -81,19 +81,37 @@ object AggregatePushDownUtils {
       }
     }
 
-    if (aggregation.groupByColumns.nonEmpty || dataFilters.nonEmpty) {
+    if (dataFilters.nonEmpty) {
       // Parquet/ORC footer has max/min/count for columns
       // e.g. SELECT COUNT(col1) FROM t
       // but footer doesn't have max/min/count for a column if max/min/count
       // are combined with filter or group by
       // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8
       //      SELECT COUNT(col1) FROM t GROUP BY col2
       // However, if the filter is on partition column, max/min/count can still be pushed down
-      // Todo: add support if groupby column is partition col
-      // (https://issues.apache.org/jira/browse/SPARK-36646)
       return None
     }
 
+    if (aggregation.groupByColumns.nonEmpty &&
+      partitionNames.size != aggregation.groupByColumns.length) {
+      // If there are group by columns, we only push down if the group by columns are the same as
+      // the partition columns. In theory, if group by columns are a subset of partition columns,
+      // we should still be able to push down. e.g. if table t has partition columns p1, p2, and p3,
+      // SELECT MAX(c) FROM t GROUP BY p1, p2 should still be able to push down. However, the
+      // partial aggregation pushed down to data source needs to be
+      // SELECT p1, p2, p3, MAX(c) FROM t GROUP BY p1, p2, p3, and Spark layer
+      // needs to have a final aggregation such as SELECT MAX(c) FROM t GROUP BY p1, p2, then the
+      // pushed down query schema is different from the query schema at Spark. We will keep
+      // aggregate push down simple and don't handle this complicate case for now.
+      return None
+    }
+    aggregation.groupByColumns.foreach { col =>
+      // don't push down if the group by columns are not the same as the partition columns (orders
+      // doesn't matter because reorder can be done at data source layer)
+      if (col.fieldNames.length != 1 || !isPartitionCol(col)) return None
+      finalSchema = finalSchema.add(getStructFieldForCol(col))
+    }
+
     aggregation.aggregateExpressions.foreach {
       case max: Max =>
         if (!processMinOrMax(max)) return None
@@ -138,4 +156,44 @@ object AggregatePushDownUtils {
     converter.convert(aggregatesAsRow, columnVectors.toArray)
     new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1)
   }
+
+  /**
+   * Return the schema for aggregates only (exclude group by columns)
+   */
+  def getSchemaWithoutGroupingExpression(
+      aggSchema: StructType,
+      aggregation: Aggregation): StructType = {
+    val numOfGroupByColumns = aggregation.groupByColumns.length
+    if (numOfGroupByColumns > 0) {
+      new StructType(aggSchema.fields.drop(numOfGroupByColumns))
+    } else {
+      aggSchema
+    }
+  }
+
+  /**
+   * Reorder partition cols if they are not in the same order as group by columns
+   */
+  def reOrderPartitionCol(
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      partitionValues: InternalRow): InternalRow = {
+    val groupByColNames = aggregation.groupByColumns.map(_.fieldNames.head)
+    assert(groupByColNames.length == partitionSchema.length &&
+      groupByColNames.length == partitionValues.numFields, "The number of group by columns " +
+      s"${groupByColNames.length} should be the same as partition schema length " +
+      s"${partitionSchema.length} and the number of fields ${partitionValues.numFields} " +
+      s"in partitionValues")
+    var reorderedPartColValues = Array.empty[Any]
+    if (!partitionSchema.names.sameElements(groupByColNames)) {
+      groupByColNames.foreach { col =>
+        val index = partitionSchema.names.indexOf(col)
+        val v = partitionValues.asInstanceOf[GenericInternalRow].values(index)
+        reorderedPartColValues = reorderedPartColValues :+ v
+      }
+      new GenericInternalRow(reorderedPartColValues)
+    } else {
+      partitionValues
+    }
+  }
 }
```
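
A minimal, standalone sketch of the reordering idea behind `reOrderPartitionCol` above, under the assumption that the group by columns are exactly the partition columns in some order. The object name, schema, and values are made up for illustration; this is not the Spark source.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types.{IntegerType, StructType}

object ReorderSketch {
  // Partition values arrive in partition-schema order; the pushed-down group by may list the
  // same columns in a different order, so pick each value at the index of its group by column.
  def reorder(
      partitionSchema: StructType,
      groupByColNames: Seq[String],
      partitionValues: InternalRow): InternalRow = {
    if (partitionSchema.names.sameElements(groupByColNames)) return partitionValues
    val reordered = groupByColNames.map { name =>
      val idx = partitionSchema.names.indexOf(name)
      partitionValues.asInstanceOf[GenericInternalRow].values(idx)
    }.toArray
    new GenericInternalRow(reordered)
  }

  def main(args: Array[String]): Unit = {
    val partitionSchema = new StructType().add("p1", IntegerType).add("p2", IntegerType)
    val partitionValues = new GenericInternalRow(Array[Any](1, 2))        // p1 = 1, p2 = 2
    val row = reorder(partitionSchema, Seq("p2", "p1"), partitionValues)  // GROUP BY p2, p1
    println(s"${row.getInt(0)}, ${row.getInt(1)}")                        // prints "2, 1"
  }
}
```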

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala

Lines changed: 21 additions & 8 deletions
```diff
@@ -35,11 +35,12 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
+import org.apache.spark.sql.catalyst.expressions.JoinedRow
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.util.{quoteIdentifier, CharVarcharUtils}
 import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.datasources.SchemaMergeUtils
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, SchemaMergeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.{ThreadUtils, Utils}
 
@@ -396,9 +397,8 @@ object OrcUtils extends Logging {
       dataSchema: StructType,
       partitionSchema: StructType,
       aggregation: Aggregation,
-      aggSchema: StructType): InternalRow = {
-    require(aggregation.groupByColumns.length == 0,
-      s"aggregate $aggregation with group-by column shouldn't be pushed down")
+      aggSchema: StructType,
+      partitionValues: InternalRow): InternalRow = {
     var columnsStatistics: OrcColumnStatistics = null
     try {
       columnsStatistics = OrcFooterReader.readStatistics(reader)
@@ -457,17 +457,22 @@
       }
     }
 
+    // if there are group by columns, we will build result row first,
+    // and then append group by columns values (partition columns values) to the result row.
+    val schemaWithoutGroupBy =
+      AggregatePushDownUtils.getSchemaWithoutGroupingExpression(aggSchema, aggregation)
+
     val aggORCValues: Seq[WritableComparable[_]] =
       aggregation.aggregateExpressions.zipWithIndex.map {
         case (max: Max, index) =>
           val columnName = max.column.fieldNames.head
           val statistics = getColumnStatistics(columnName)
-          val dataType = aggSchema(index).dataType
+          val dataType = schemaWithoutGroupBy(index).dataType
           getMinMaxFromColumnStatistics(statistics, dataType, isMax = true)
         case (min: Min, index) =>
           val columnName = min.column.fieldNames.head
           val statistics = getColumnStatistics(columnName)
-          val dataType = aggSchema.apply(index).dataType
+          val dataType = schemaWithoutGroupBy.apply(index).dataType
           getMinMaxFromColumnStatistics(statistics, dataType, isMax = false)
         case (count: Count, _) =>
           val columnName = count.column.fieldNames.head
@@ -490,7 +495,15 @@
           s"createAggInternalRowFromFooter should not take $x as the aggregate expression")
       }
 
-    val orcValuesDeserializer = new OrcDeserializer(aggSchema, (0 until aggSchema.length).toArray)
-    orcValuesDeserializer.deserializeFromValues(aggORCValues)
+    val orcValuesDeserializer = new OrcDeserializer(schemaWithoutGroupBy,
+      (0 until schemaWithoutGroupBy.length).toArray)
+    val resultRow = orcValuesDeserializer.deserializeFromValues(aggORCValues)
+    if (aggregation.groupByColumns.nonEmpty) {
+      val reOrderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol(
+        partitionSchema, aggregation, partitionValues)
+      new JoinedRow(reOrderedPartitionValues, resultRow)
+    } else {
+      resultRow
+    }
   }
 }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala

Lines changed: 21 additions & 11 deletions
```diff
@@ -31,8 +31,9 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
 import org.apache.spark.SparkException
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.JoinedRow
 import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
-import org.apache.spark.sql.execution.datasources.PartitioningUtils
+import org.apache.spark.sql.execution.datasources.AggregatePushDownUtils
 import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED}
 import org.apache.spark.sql.types.StructType
 
@@ -157,17 +158,22 @@
       partitionSchema: StructType,
       aggregation: Aggregation,
       aggSchema: StructType,
-      datetimeRebaseMode: LegacyBehaviorPolicy.Value,
-      isCaseSensitive: Boolean): InternalRow = {
+      partitionValues: InternalRow,
+      datetimeRebaseMode: LegacyBehaviorPolicy.Value): InternalRow = {
     val (primitiveTypes, values) = getPushedDownAggResult(
-      footer, filePath, dataSchema, partitionSchema, aggregation, isCaseSensitive)
+      footer, filePath, dataSchema, partitionSchema, aggregation)
 
     val builder = Types.buildMessage
     primitiveTypes.foreach(t => builder.addField(t))
     val parquetSchema = builder.named("root")
 
+    // if there are group by columns, we will build result row first,
+    // and then append group by columns values (partition columns values) to the result row.
+    val schemaWithoutGroupBy =
+      AggregatePushDownUtils.getSchemaWithoutGroupingExpression(aggSchema, aggregation)
+
     val schemaConverter = new ParquetToSparkSchemaConverter
-    val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema,
+    val converter = new ParquetRowConverter(schemaConverter, parquetSchema, schemaWithoutGroupBy,
       None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater)
     val primitiveTypeNames = primitiveTypes.map(_.getPrimitiveTypeName)
     primitiveTypeNames.zipWithIndex.foreach {
@@ -195,7 +201,14 @@
       case (_, i) =>
         throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i))
     }
-    converter.currentRecord
+
+    if (aggregation.groupByColumns.nonEmpty) {
+      val reorderedPartitionValues = AggregatePushDownUtils.reOrderPartitionCol(
+        partitionSchema, aggregation, partitionValues)
+      new JoinedRow(reorderedPartitionValues, converter.currentRecord)
+    } else {
+      converter.currentRecord
+    }
   }
 
   /**
@@ -211,16 +224,14 @@
       filePath: String,
       dataSchema: StructType,
       partitionSchema: StructType,
-      aggregation: Aggregation,
-      isCaseSensitive: Boolean)
+      aggregation: Aggregation)
     : (Array[PrimitiveType], Array[Any]) = {
     val footerFileMetaData = footer.getFileMetaData
     val fields = footerFileMetaData.getSchema.getFields
     val blocks = footer.getBlocks
     val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType]
     val valuesBuilder = mutable.ArrayBuilder.make[Any]
 
-    assert(aggregation.groupByColumns.length == 0, "group by shouldn't be pushed down")
     aggregation.aggregateExpressions.foreach { agg =>
       var value: Any = None
       var rowCount = 0L
@@ -250,8 +261,7 @@
             schemaName = "count(" + count.column.fieldNames.head + ")"
             rowCount += block.getRowCount
             var isPartitionCol = false
-            if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive))
-              .toSet.contains(count.column.fieldNames.head)) {
+            if (partitionSchema.fields.map(_.name).toSet.contains(count.column.fieldNames.head)) {
              isPartitionCol = true
            }
            isCount = true
```
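
A hedged sketch of the footer-only evaluation that `getPushedDownAggResult` relies on for count push down: every Parquet row group records its row count in the footer metadata, so `count(*)` — and `count(partitionCol)`, since a partition column is never null within a file — can be summed without reading data pages. The helper name is made up; `footer` is assumed to be the `ParquetMetadata` already obtained by the reader factory.

```scala
import scala.collection.JavaConverters._

import org.apache.parquet.hadoop.metadata.ParquetMetadata

object FooterCountSketch {
  // Sum the per-row-group row counts recorded in the footer; no data pages are touched.
  def countFromFooter(footer: ParquetMetadata): Long =
    footer.getBlocks.asScala.map(_.getRowCount).sum
}
```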

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala

Lines changed: 11 additions & 10 deletions
```diff
@@ -83,11 +83,10 @@ case class OrcPartitionReaderFactory(
 
   override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
     val conf = broadcastedConf.value.value
-    val filePath = new Path(new URI(file.filePath))
-
     if (aggregation.nonEmpty) {
-      return buildReaderWithAggregates(filePath, conf)
+      return buildReaderWithAggregates(file, conf)
     }
+    val filePath = new Path(new URI(file.filePath))
 
     val resultedColPruneInfo =
       Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
@@ -127,11 +126,10 @@
 
   override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
     val conf = broadcastedConf.value.value
-    val filePath = new Path(new URI(file.filePath))
-
     if (aggregation.nonEmpty) {
-      return buildColumnarReaderWithAggregates(filePath, conf)
+      return buildColumnarReaderWithAggregates(file, conf)
     }
+    val filePath = new Path(new URI(file.filePath))
 
     val resultedColPruneInfo =
       Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
@@ -183,14 +181,16 @@
    * Build reader with aggregate push down.
    */
   private def buildReaderWithAggregates(
-      filePath: Path,
+      file: PartitionedFile,
       conf: Configuration): PartitionReader[InternalRow] = {
+    val filePath = new Path(new URI(file.filePath))
     new PartitionReader[InternalRow] {
       private var hasNext = true
       private lazy val row: InternalRow = {
         Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
           OrcUtils.createAggInternalRowFromFooter(
-            reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, readDataSchema)
+            reader, filePath.toString, dataSchema, partitionSchema, aggregation.get,
+            readDataSchema, file.partitionValues)
         }
       }
 
@@ -209,15 +209,16 @@
    * Build columnar reader with aggregate push down.
    */
   private def buildColumnarReaderWithAggregates(
-      filePath: Path,
+      file: PartitionedFile,
       conf: Configuration): PartitionReader[ColumnarBatch] = {
+    val filePath = new Path(new URI(file.filePath))
    new PartitionReader[ColumnarBatch] {
      private var hasNext = true
      private lazy val batch: ColumnarBatch = {
        Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
          val row = OrcUtils.createAggInternalRowFromFooter(
            reader, filePath.toString, dataSchema, partitionSchema, aggregation.get,
-            readDataSchema)
+            readDataSchema, file.partitionValues)
          AggregatePushDownUtils.convertAggregatesRowToBatch(row, readDataSchema, offHeap = false)
        }
      }
```
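
The readers built above follow a single-row pattern: each file contributes exactly one aggregate row computed from its footer. A simplified sketch of that pattern, assuming the row has already been computed elsewhere (this is not the Spark source):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.PartitionReader

// Yields one precomputed row, then reports exhaustion.
class SingleRowPartitionReader(row: InternalRow) extends PartitionReader[InternalRow] {
  private var hasNextRow = true

  override def next(): Boolean = {
    val result = hasNextRow
    hasNextRow = false
    result
  }

  override def get(): InternalRow = row

  override def close(): Unit = () // nothing to release in this sketch
}
```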

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala

Lines changed: 5 additions & 4 deletions
```diff
@@ -129,10 +129,11 @@ case class ParquetPartitionReaderFactory(
       private var hasNext = true
       private lazy val row: InternalRow = {
         val footer = getFooter(file)
+
         if (footer != null && footer.getBlocks.size > 0) {
           ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, dataSchema,
-            partitionSchema, aggregation.get, readDataSchema,
-            getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
+            partitionSchema, aggregation.get, readDataSchema, file.partitionValues,
+            getDatetimeRebaseMode(footer.getFileMetaData))
         } else {
           null
         }
@@ -174,8 +175,8 @@
         val footer = getFooter(file)
         if (footer != null && footer.getBlocks.size > 0) {
           val row = ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath,
-            dataSchema, partitionSchema, aggregation.get, readDataSchema,
-            getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
+            dataSchema, partitionSchema, aggregation.get, readDataSchema, file.partitionValues,
+            getDatetimeRebaseMode(footer.getFileMetaData))
           AggregatePushDownUtils.convertAggregatesRowToBatch(
             row, readDataSchema, enableOffHeapColumnVector && Option(TaskContext.get()).isDefined)
         } else {
```
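
One way to check whether the new path is taken (path and column names hypothetical): the DS v2 scan node in the physical plan lists the pushed aggregation and grouping columns when push down applies; if they are absent, Spark fell back to a regular scan plus aggregate.

```scala
// Hypothetical path; the table is assumed to be partitioned by column "p".
val df = spark.read.parquet("/path/to/partitioned/table").groupBy("p").max("c")
df.explain() // inspect the scan node for the pushed aggregation / group by columns
```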
