Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import scala.collection
import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -258,6 +259,13 @@ object NestedColumnAliasing {
.filter(!_.references.subsetOf(exclusiveAttrSet))
.groupBy(_.references.head.canonicalized.asInstanceOf[Attribute])
.flatMap { case (attr: Attribute, nestedFields: collection.Seq[ExtractValue]) =>

// An [[ExtractValue]] whose expression tree holds an aggregate function must not get an
// alias generated for it, because aliasing it can lead to the aggregate being pushed down
// into a projection. This helper reports whether `ev` has an [[AggregateFunction]] anywhere
// in its tree.
def containsAggregateFunction(ev: ExtractValue): Boolean =
  ev.find {
    case _: AggregateFunction => true
    case _ => false
  }.nonEmpty

// Remove redundant [[ExtractValue]]s if they share the same parent nested field.
// For example, when `a.b` and `a.b.c` are in the project list, we only need to alias `a.b`.
// Because `a.b` requires all of the inner fields of `b`, we cannot prune `a.b.c`.
Expand All @@ -268,15 +276,18 @@ object NestedColumnAliasing {
val child = e.children.head
nestedFields.forall(f => child.find(_.semanticEquals(f)).isEmpty)
case _ => true
}.distinct
}
.distinct
// Discard [[ExtractValue]]s that contain aggregate functions.
.filterNot(containsAggregateFunction)

// If all nested fields of `attr` are used, we don't need to introduce new aliases.
// By default, the [[ColumnPruning]] rule uses `attr` already.
// Note that we need to remove cosmetic variations first, so we only count a
// nested field once.
val numUsedNestedFields = dedupNestedFields.map(_.canonicalized).distinct
.map { nestedField => totalFieldNum(nestedField.dataType) }.sum
if (numUsedNestedFields < totalFieldNum(attr.dataType)) {
if (dedupNestedFields.nonEmpty && numUsedNestedFields < totalFieldNum(attr.dataType)) {
Some((attr, dedupNestedFields.toSeq))
} else {
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.SchemaPruningTest
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
Expand Down Expand Up @@ -763,6 +764,32 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
$"_extract_search_params.col2".as("col2")).analyze
comparePlans(optimized, query)
}

test("SPARK-36677: NestedColumnAliasing should not push down aggregate functions into " +
  "projections") {
  // Relation with a doubly nested struct column `a` (a.c.e, a.d) and a plain string column `b`.
  val relation = LocalRelation(
    'a.struct(
      'c.struct('e.string),
      'd.string),
    'b.string)

  // Nested field extraction applied on top of an aggregate: max(a).c.e
  def maxThenExtract = max($"a").getField("c").getField("e")

  // The aggregate sits above an explicit projection of `a` and `b`.
  val input = relation
    .select($"a", $"b")
    .groupBy($"b")(maxThenExtract)
    .analyze

  val actual = Optimize.execute(input)

  // Run the analysis checks on the optimized plan: it should not contain aggregate
  // functions inside a projection.
  SimpleAnalyzer.checkAnalysis(actual)

  // Expected result: the aggregate applied directly to the relation, with the aggregate
  // expression left unaliased.
  val wanted = relation
    .groupBy($"b")(maxThenExtract)
    .analyze

  comparePlans(actual, wanted)
}
}

object NestedColumnAliasingSuite {
Expand Down