From b576e81c65a481eb7d6ae56572578c60f14e72a3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 1 Mar 2024 13:52:10 +0800 Subject: [PATCH 1/2] fix rule order issues for ExtractGenerator --- .../main/resources/error/error-classes.json | 2 +- ...tions-unsupported-generator-error-class.md | 2 +- .../sql/catalyst/analysis/Analyzer.scala | 43 ++++++++++++------- .../sql/catalyst/analysis/CheckAnalysis.scala | 10 ----- .../sql/errors/QueryCompilationErrors.scala | 3 +- .../analysis/AnalysisErrorSuite.scala | 14 +----- .../org/apache/spark/sql/DataFrameSuite.scala | 14 ------ .../spark/sql/GeneratorFunctionSuite.scala | 27 +++++++++++- .../errors/QueryCompilationErrorsSuite.scala | 12 ------ .../sql/hive/execution/HiveQuerySuite.scala | 22 ---------- 10 files changed, 57 insertions(+), 92 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 98b151a28fdcf..e82da1adc48f8 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -4070,7 +4070,7 @@ "subClass" : { "MULTI_GENERATOR" : { "message" : [ - "only one generator allowed per <clause> clause but found <num>: <generators>." + "only one generator allowed per SELECT clause but found <num>: <generators>." ] }, "NESTED_IN_EXPRESSIONS" : { diff --git a/docs/sql-error-conditions-unsupported-generator-error-class.md b/docs/sql-error-conditions-unsupported-generator-error-class.md index 5b0687c7d03d1..4e42d6b43bca4 100644 --- a/docs/sql-error-conditions-unsupported-generator-error-class.md +++ b/docs/sql-error-conditions-unsupported-generator-error-class.md @@ -32,7 +32,7 @@ This error class has the following derived error classes: ## MULTI_GENERATOR -only one generator allowed per `<clause>` clause but found `<num>`: `<generators>`. +only one generator allowed per SELECT clause but found `<num>`: `<generators>`. 
## NESTED_IN_EXPRESSIONS diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c5b01a312664a..720f09fcc75c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2876,28 +2876,36 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } + // We must wait until all expressions except for generator functions are resolved before + // rewriting generator functions in Project/Aggregate. This is necessary to make this rule + // stable for different execution orders of analyzer rules. See also SPARK-47241. + private def canRewriteGenerator(namedExprs: Seq[NamedExpression]): Boolean = { + namedExprs.forall { ne => + ne.resolved || { + trimNonTopLevelAliases(ne) match { + case AliasedGenerator(_, _, _) => true + case _ => false + } + } + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( _.containsPattern(GENERATOR), ruleId) { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator)) - case Project(projectList, _) if projectList.count(hasGenerator) > 1 => - val generators = projectList.filter(hasGenerator).map(trimAlias) - throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "SELECT") - case Aggregate(_, aggList, _) if aggList.exists(hasNestedGenerator) => val nestedGenerator = aggList.find(hasNestedGenerator).get throw QueryCompilationErrors.nestedGeneratorError(trimAlias(nestedGenerator)) case Aggregate(_, aggList, _) if aggList.count(hasGenerator) > 1 => val generators = aggList.filter(hasGenerator).map(trimAlias) - throw 
QueryCompilationErrors.moreThanOneGeneratorError(generators, "aggregate") + throw QueryCompilationErrors.moreThanOneGeneratorError(generators) - case agg @ Aggregate(groupList, aggList, child) if aggList.forall { - case AliasedGenerator(_, _, _) => true - case other => other.resolved - } && aggList.exists(hasGenerator) => + case Aggregate(groupList, aggList, child) if canRewriteGenerator(aggList) && + aggList.exists(hasGenerator) => // If generator in the aggregate list was visited, set the boolean flag true. var generatorVisited = false @@ -2942,16 +2950,16 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // first for replacing `Project` with `Aggregate`. p - case p @ Project(projectList, child) => + case p @ Project(projectList, child) if canRewriteGenerator(projectList) && + projectList.exists(hasGenerator) => val (resolvedGenerator, newProjectList) = projectList .map(trimNonTopLevelAliases) .foldLeft((None: Option[Generate], Nil: Seq[NamedExpression])) { (res, e) => e match { - case AliasedGenerator(generator, names, outer) if generator.childrenResolved => - // It's a sanity check, this should not happen as the previous case will throw - // exception earlier. - assert(res._1.isEmpty, "More than one generator found in SELECT.") - + // If there are more than one generator, we only rewrite the first one and wait for + // the next analyzer iteration to rewrite the next one. 
+ case AliasedGenerator(generator, names, outer) if res._1.isEmpty && + generator.childrenResolved => val g = Generate( generator, unrequiredChildIndex = Nil, @@ -2959,7 +2967,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor qualifier = None, generatorOutput = ResolveGenerate.makeGeneratorOutput(generator, names), child) - (Some(g), res._2 ++ g.nullableOutput) case other => (res._1, res._2 :+ other) @@ -2979,6 +2986,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case u: UnresolvedTableValuedFunction => u + case p: Project => p + + case a: Aggregate => a + case p if p.expressions.exists(hasGenerator) => throw QueryCompilationErrors.generatorOutsideSelectError(p) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 89bed0518027e..4a979fd214aba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -66,12 +66,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB messageParameters = messageParameters) } - protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { - exprs.flatMap(_.collect { - case e: Generator => e - }).length > 1 - } - protected def hasMapType(dt: DataType): Boolean = { dt.existsRecursively(_.isInstanceOf[MapType]) } @@ -669,10 +663,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB )) } - case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => - val generators = exprs.filter(expr => expr.exists(_.isInstanceOf[Generator])) - throw QueryCompilationErrors.moreThanOneGeneratorError(generators, "SELECT") - case p @ Project(projectList, _) => projectList.foreach(_.transformDownWithPruning( 
_.containsPattern(UNRESOLVED_WINDOW_EXPRESSION)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 3ab6c22e5fda4..24d54d1e6b0b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -266,11 +266,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map("expression" -> toSQLExpr(trimmedNestedGenerator))) } - def moreThanOneGeneratorError(generators: Seq[Expression], clause: String): Throwable = { + def moreThanOneGeneratorError(generators: Seq[Expression]): Throwable = { new AnalysisException( errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", messageParameters = Map( - "clause" -> clause, "num" -> generators.size.toString, "generators" -> generators.map(toSQLExpr).mkString(", "))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 46d7261a747dc..50d6cad6b6914 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -394,11 +394,6 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { "inputType" -> "\"BOOLEAN\"", "requiredType" -> "\"INT\"")) - errorTest( - "too many generators", - listRelation.select(Explode($"list").as("a"), Explode($"list").as("b")), - "only one generator" :: "explode" :: Nil) - errorClassTest( "unresolved attributes", testRelation.select($"abcd"), @@ -804,18 +799,11 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { "SUM_OF_LIMIT_AND_OFFSET_EXCEEDS_MAX_INT", Map("limit" 
-> "1000000000", "offset" -> "2000000000")) - errorTest( - "more than one generators in SELECT", - listRelation.select(Explode($"list"), Explode($"list")), - "The generator is not supported: only one generator allowed per select clause but found 2: " + - """"explode(list)", "explode(list)"""" :: Nil - ) - errorTest( "more than one generators for aggregates in SELECT", testRelation.select(Explode(CreateArray(min($"a") :: Nil)), Explode(CreateArray(max($"a") :: Nil))), - "The generator is not supported: only one generator allowed per select clause but found 2: " + + "The generator is not supported: only one generator allowed per SELECT clause but found 2: " + """"explode(array(min(a)))", "explode(array(max(a)))"""" :: Nil ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8ef95e6fd129b..10453b6009a18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -380,20 +380,6 @@ class DataFrameSuite extends QueryTest Row("a", Seq("a"), 1) :: Nil) } - test("more than one generator in SELECT clause") { - val df = Seq((Array("a"), 1)).toDF("a", "b") - - checkError( - exception = intercept[AnalysisException] { - df.select(explode($"a").as("a"), explode($"a").as("b")) - }, - errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", - parameters = Map( - "clause" -> "SELECT", - "num" -> "2", - "generators" -> "\"explode(a)\", \"explode(a)\"")) - } - test("sort after generate with join=true") { val df = Seq((Array("a"), 1)).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index da93c70a5b074..96c030258d36e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -442,7 +442,6 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { }, errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", parameters = Map( - "clause" -> "aggregate", "num" -> "2", "generators" -> ("\"explode(array(min(c2), max(c2)))\", " + "\"posexplode(array(min(c2), max(c2)))\""))) @@ -553,6 +552,32 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { checkAnswer(df, Row(0.7604953758285915d)) } } + + test("SPARK-47241: two generator functions in SELECT") { + def testTwoGenerators(needImplicitCast: Boolean): Unit = { + val df = sql( + s""" + |SELECT + |explode(array('a', 'b')) as c1, + |explode(array(0L, ${if (needImplicitCast) "0L + 1" else "1L"})) as c2 + |""".stripMargin) + checkAnswer(df, Seq(Row("a", 0L), Row("a", 1L), Row("b", 0L), Row("b", 1L))) + } + testTwoGenerators(needImplicitCast = true) + testTwoGenerators(needImplicitCast = false) + } + + test("SPARK-47241: generator function after SELECT *") { + val df = sql( + s""" + |SELECT *, explode(array('a', 'b')) as c1 + |FROM + |( + | SELECT id FROM range(1) GROUP BY 1 + |) + |""".stripMargin) + checkAnswer(df, Seq(Row(0, "a"), Row(0, "b"))) + } } case class EmptyGenerator() extends Generator with LeafLike[Expression] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index d4e4a41155eaf..228b0cc02e42d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -655,18 +655,6 @@ class QueryCompilationErrorsSuite parameters = Map("expression" -> "\"(explode(array(1, 2, 3)) + 1)\"")) } - test("UNSUPPORTED_GENERATOR: only one generator allowed") { - val e = intercept[AnalysisException]( - sql("""select 
explode(Array(1, 2, 3)), explode(Array(1, 2, 3))""").collect() - ) - - checkError( - exception = e, - errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", - parameters = Map("clause" -> "SELECT", "num" -> "2", - "generators" -> "\"explode(array(1, 2, 3))\", \"explode(array(1, 2, 3))\"")) - } - test("UNSUPPORTED_GENERATOR: generators are not supported outside the SELECT clause") { val e = intercept[AnalysisException]( sql("""select 1 from t order by explode(Array(1, 2, 3))""").collect() diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 1468f325b651c..86e6b01cb6cae 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -160,28 +160,6 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd | SELECT key FROM gen_tmp ORDER BY key ASC; """.stripMargin) - test("multiple generators in projection") { - checkError( - exception = intercept[AnalysisException] { - sql("SELECT explode(array(key, key)), explode(array(key, key)) FROM src").collect() - }, - errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", - parameters = Map( - "clause" -> "SELECT", - "num" -> "2", - "generators" -> "\"explode(array(key, key))\", \"explode(array(key, key))\"")) - - checkError( - exception = intercept[AnalysisException] { - sql("SELECT explode(array(key, key)) as k1, explode(array(key, key)) FROM src").collect() - }, - errorClass = "UNSUPPORTED_GENERATOR.MULTI_GENERATOR", - parameters = Map( - "clause" -> "SELECT", - "num" -> "2", - "generators" -> "\"explode(array(key, key))\", \"explode(array(key, key))\"")) - } - createQueryTest("! 
operator", """ |SELECT a FROM ( From 942a3d37acb9a6b3ec82bd15bd5704d75df8df1c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 7 Mar 2024 09:58:28 +0800 Subject: [PATCH 2/2] Update sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala --- .../scala/org/apache/spark/sql/GeneratorFunctionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index 96c030258d36e..97a56bdea7be7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -567,7 +567,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSparkSession { testTwoGenerators(needImplicitCast = false) } - test("SPARK-47241: generator function after SELECT *") { + test("SPARK-47241: generator function after wildcard in SELECT") { val df = sql( s""" |SELECT *, explode(array('a', 'b')) as c1