Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion common/utils/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -4070,7 +4070,7 @@
"subClass" : {
"MULTI_GENERATOR" : {
"message" : [
"only one generator allowed per <clause> clause but found <num>: <generators>."
"only one generator allowed per SELECT clause but found <num>: <generators>."
]
},
"NESTED_IN_EXPRESSIONS" : {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ This error class has the following derived error classes:

## MULTI_GENERATOR

only one generator allowed per `<clause>` clause but found `<num>`: `<generators>`.
only one generator allowed per SELECT clause but found `<num>`: `<generators>`.

## NESTED_IN_EXPRESSIONS

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2876,28 +2876,36 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}

// We must wait until all expressions except for generator functions are resolved before
// rewriting generator functions in Project/Aggregate. This is necessary to make this rule
// stable for different execution orders of analyzer rules. See also SPARK-47241.
private def canRewriteGenerator(namedExprs: Seq[NamedExpression]): Boolean = {
namedExprs.forall { ne =>
ne.resolved || {
trimNonTopLevelAliases(ne) match {
case AliasedGenerator(_, _, _) => true
case _ => false
}
}
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
_.containsPattern(GENERATOR), ruleId) {
case Project(projectList, _) if projectList.exists(hasNestedGenerator) =>
val nestedGenerator = projectList.find(hasNestedGenerator).get
throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator))

case Project(projectList, _) if projectList.count(hasGenerator) > 1 =>
val generators = projectList.filter(hasGenerator).map(trimAlias)
throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "SELECT")

case Aggregate(_, aggList, _) if aggList.exists(hasNestedGenerator) =>
val nestedGenerator = aggList.find(hasNestedGenerator).get
throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator))

case Aggregate(_, aggList, _) if aggList.count(hasGenerator) > 1 =>
val generators = aggList.filter(hasGenerator).map(trimAlias)
throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "aggregate")
throw QueryCompilationErrors.moreThanOneGeneratorError(generators)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still for aggregate, but I saw you remove the clause field in the error message?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find aggregate clause confusing, as what end users write is a SELECT query with GROUP BY or aggregate functions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, okay.

Copy link
Contributor Author

@cloud-fan cloud-fan Mar 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another reason is we can't always figure out if it's aggregate or not. If there is no GROUP BY, the plan is still Project and we may fail before analyzer rewrite it to Aggregate, then we report SELECT clause anyway.


case agg @ Aggregate(groupList, aggList, child) if aggList.forall {
case AliasedGenerator(_, _, _) => true
case other => other.resolved
} && aggList.exists(hasGenerator) =>
case Aggregate(groupList, aggList, child) if canRewriteGenerator(aggList) &&
aggList.exists(hasGenerator) =>
// If generator in the aggregate list was visited, set the boolean flag true.
var generatorVisited = false

Expand Down Expand Up @@ -2942,24 +2950,23 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
// first for replacing `Project` with `Aggregate`.
p

case p @ Project(projectList, child) =>
case p @ Project(projectList, child) if canRewriteGenerator(projectList) &&
projectList.exists(hasGenerator) =>
val (resolvedGenerator, newProjectList) = projectList
.map(trimNonTopLevelAliases)
.foldLeft((None: Option[Generate], Nil: Seq[NamedExpression])) { (res, e) =>
e match {
case AliasedGenerator(generator, names, outer) if generator.childrenResolved =>
// It's a sanity check, this should not happen as the previous case will throw
// exception earlier.
assert(res._1.isEmpty, "More than one generator found in SELECT.")

// If there are more than one generator, we only rewrite the first one and wait for
// the next analyzer iteration to rewrite the next one.
case AliasedGenerator(generator, names, outer) if res._1.isEmpty &&
generator.childrenResolved =>
val g = Generate(
generator,
unrequiredChildIndex = Nil,
outer = outer,
qualifier = None,
generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names),
child)

(Some(g), res._2 ++ g.nullableOutput)
case other =>
(res._1, res._2 :+ other)
Expand All @@ -2979,6 +2986,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor

case u: UnresolvedTableValuedFunction => u

case p: Project => p

case a: Aggregate => a

case p if p.expressions.exists(hasGenerator) =>
throw QueryCompilationErrors.generatorOutsideSelectError(p)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
messageParameters = messageParameters)
}

protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
exprs.flatMap(_.collect {
case e: Generator => e
}).length > 1
}

protected def hasMapType(dt: DataType): Boolean = {
dt.existsRecursively(_.isInstanceOf[MapType])
}
Expand Down Expand Up @@ -669,10 +663,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
))
}

case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
val generators = exprs.filter(expr => expr.exists(_.isInstanceOf[Generator]))
throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "SELECT")

case p @ Project(projectList, _) =>
projectList.foreach(_.transformDownWithPruning(
_.containsPattern(UNRESOLVED_WINDOW_EXPRESSION)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
messageParameters = Map("expression" -> toSQLExpr(trimmedNestedGenerator)))
}

def moreThanOneGeneratorError(generators: Seq[Expression], clause: String): Throwable = {
def moreThanOneGeneratorError(generators: Seq[Expression]): Throwable = {
new AnalysisException(
errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
messageParameters = Map(
"clause" -> clause,
"num" -> generators.size.toString,
"generators" -> generators.map(toSQLExpr).mkString(", ")))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,11 +394,6 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase {
"inputType" -> "\"BOOLEAN\"",
"requiredType" -> "\"INT\""))

errorTest(
"too many generators",
listRelation.select(Explode($"list").as("a"), Explode($"list").as("b")),
"only one generator" :: "explode" :: Nil)

errorClassTest(
"unresolved attributes",
testRelation.select($"abcd"),
Expand Down Expand Up @@ -804,18 +799,11 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase {
"SUM_OF_LIMIT_AND_OFFSET_EXCEEDS_MAX_INT",
Map("limit" -> "1000000000", "offset" -> "2000000000"))

errorTest(
"more than one generators in SELECT",
listRelation.select(Explode($"list"), Explode($"list")),
"The generator is not supported: only one generator allowed per select clause but found 2: " +
""""explode(list)", "explode(list)"""" :: Nil
)

errorTest(
"more than one generators for aggregates in SELECT",
testRelation.select(Explode(CreateArray(min($"a") :: Nil)),
Explode(CreateArray(max($"a") :: Nil))),
"The generator is not supported: only one generator allowed per select clause but found 2: " +
"The generator is not supported: only one generator allowed per SELECT clause but found 2: " +
""""explode(array(min(a)))", "explode(array(max(a)))"""" :: Nil
)

Expand Down
14 changes: 0 additions & 14 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -380,20 +380,6 @@ class DataFrameSuite extends QueryTest
Row("a", Seq("a"), 1) :: Nil)
}

test("more than one generator in SELECT clause") {
val df = Seq((Array("a"), 1)).toDF("a", "b")

checkError(
exception = intercept[AnalysisException] {
df.select(explode($"a").as("a"), explode($"a").as("b"))
},
errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
parameters = Map(
"clause" -> "SELECT",
"num" -> "2",
"generators" -> "\"explode(a)\", \"explode(a)\""))
}

test("sort after generate with join=true") {
val df = Seq((Array("a"), 1)).toDF("a", "b")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,6 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
},
errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
parameters = Map(
"clause" -> "aggregate",
"num" -> "2",
"generators" -> ("\"explode(array(min(c2), max(c2)))\", " +
"\"posexplode(array(min(c2), max(c2)))\"")))
Expand Down Expand Up @@ -553,6 +552,32 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession {
checkAnswer(df, Row(0.7604953758285915d))
}
}

test("SPARK-47241: two generator functions in SELECT") {
def testTwoGenerators(needImplicitCast: Boolean): Unit = {
val df = sql(
s"""
|SELECT
|explode(array('a', 'b')) as c1,
|explode(array(0L, ${if (needImplicitCast) "0L + 1" else "1L"})) as c2
|""".stripMargin)
checkAnswer(df, Seq(Row("a", 0L), Row("a", 1L), Row("b", 0L), Row("b", 1L)))
}
testTwoGenerators(needImplicitCast = true)
testTwoGenerators(needImplicitCast = false)
}

test("SPARK-47241: generator function after wildcard in SELECT") {
val df = sql(
s"""
|SELECT *, explode(array('a', 'b')) as c1
|FROM
|(
| SELECT id FROM range(1) GROUP BY 1
|)
|""".stripMargin)
checkAnswer(df, Seq(Row(0, "a"), Row(0, "b")))
}
}

case class EmptyGenerator() extends Generator with LeafLike[Expression] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -655,18 +655,6 @@ class QueryCompilationErrorsSuite
parameters = Map("expression" -> "\"(explode(array(1, 2, 3)) + 1)\""))
}

test("UNSUPPORTED_GENERATOR: only one generator allowed") {
val e = intercept[AnalysisException](
sql("""select explode(Array(1, 2, 3)), explode(Array(1, 2, 3))""").collect()
)

checkError(
exception = e,
errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
parameters = Map("clause" -> "SELECT", "num" -> "2",
"generators" -> "\"explode(array(1, 2, 3))\", \"explode(array(1, 2, 3))\""))
}

test("UNSUPPORTED_GENERATOR: generators are not supported outside the SELECT clause") {
val e = intercept[AnalysisException](
sql("""select 1 from t order by explode(Array(1, 2, 3))""").collect()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,28 +160,6 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd
| SELECT key FROM gen_tmp ORDER BY key ASC;
""".stripMargin)

test("multiple generators in projection") {
checkError(
exception = intercept[AnalysisException] {
sql("SELECT explode(array(key, key)), explode(array(key, key)) FROM src").collect()
},
errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
parameters = Map(
"clause" -> "SELECT",
"num" -> "2",
"generators" -> "\"explode(array(key, key))\", \"explode(array(key, key))\""))

checkError(
exception = intercept[AnalysisException] {
sql("SELECT explode(array(key, key)) as k1, explode(array(key, key)) FROM src").collect()
},
errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR",
parameters = Map(
"clause" -> "SELECT",
"num" -> "2",
"generators" -> "\"explode(array(key, key))\", \"explode(array(key, key))\""))
}

createQueryTest("! operator",
"""
|SELECT a FROM (
Expand Down