From 5ab688fe752e8da100003a7c0b973653ad2b0e6f Mon Sep 17 00:00:00 2001 From: Mihailo Timotic Date: Tue, 18 Mar 2025 16:46:13 +0100 Subject: [PATCH] fix --- .../sql/catalyst/analysis/Analyzer.scala | 25 +++------ .../analysis/ColumnResolutionHelper.scala | 9 +++- .../spark/sql/LateralColumnAliasSuite.scala | 51 ++++++------------- 3 files changed, 31 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 93e6fb6746a1..5609961616f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2983,20 +2983,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } - // We must wait until all expressions except for generator functions are resolved before - // rewriting generator functions in Project/Aggregate. This is necessary to make this rule - // stable for different execution orders of analyzer rules. See also SPARK-47241. - private def canRewriteGenerator(namedExprs: Seq[NamedExpression]): Boolean = { - namedExprs.forall { ne => - ne.resolved || { - trimNonTopLevelAliases(ne) match { - case AliasedGenerator(_, _, _) => true - case _ => false - } - } - } - } - def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( _.containsPattern(GENERATOR), ruleId) { case p @ Project(Seq(UnresolvedStarWithColumns(_, _, _)), _) => @@ -3015,8 +3001,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor val generators = aggList.filter(hasGenerator).map(trimAlias) throw QueryCompilationErrors.moreThanOneGeneratorError(generators) - case Aggregate(groupList, aggList, child, _) if canRewriteGenerator(aggList) && - aggList.exists(hasGenerator) => + case Aggregate(groupList, aggList, child, _) if + aggList.forall { + case AliasedGenerator(_, _, _) => true + case other => other.resolved + } && aggList.exists(hasGenerator) => // If generator in the aggregate list was visited, set the boolean flag true. var generatorVisited = false @@ -3061,8 +3050,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // first for replacing `Project` with `Aggregate`. p - case p @ Project(projectList, child) if canRewriteGenerator(projectList) && - projectList.exists(hasGenerator) => + // The star will be expanded differently if we insert `Generate` under `Project` too early. + case p @ Project(projectList, child) if !projectList.exists(_.exists(_.isInstanceOf[Star])) => val (resolvedGenerator, newProjectList) = projectList .map(trimNonTopLevelAliases) .foldLeft((None: Option[Generate], Nil: Seq[NamedExpression])) { (res, e) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index e778342d0837..b2e068fd990b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -406,7 +406,14 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase { // Lateral column alias does not have qualifiers. We always use the first name part to // look up lateral column aliases. val lowerCasedName = u.nameParts.head.toLowerCase(Locale.ROOT) - aliasMap.get(lowerCasedName).map { + aliasMap.get(lowerCasedName).filter { + // Do not resolve LCA with aliased `Generator`, as it will be rewritten by the rule + // `ExtractGenerator` with fresh output attribute IDs. The `Generator` will be pulled + // out and put in a `Generate` node below `Project`, so that we can resolve the column + // normally without LCA resolution. + case scala.util.Left(alias) => !alias.child.isInstanceOf[Generator] + case _ => true + }.map { case scala.util.Left(alias) => if (alias.resolved) { val resolvedAttr = resolveExpressionByPlanOutput( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 3def42cd7ee5..ae8d84aadb03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -1366,40 +1366,21 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { sql("select 1 as a, a").queryExecution.assertAnalyzed() } - test("SPARK-49349: Improve error message for LCA with Generate") { - checkError( - exception = intercept[AnalysisException] { - sql( - s""" - |SELECT - | explode(split(name , ',')) AS new_name, - | new_name like 'a%' - |FROM $testTable - |""".stripMargin) - }, - condition = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_GENERATOR", - sqlState = "0A000", - parameters = Map( - "lca" -> "`new_name`", - "generatorExpr" -> "\"unresolvedalias(lateralAliasReference(new_name) LIKE a%)\"")) - - checkError( - exception = intercept[AnalysisException] { - sql( - s""" - |SELECT - | explode_outer(from_json(name,'array>')) as newName, - | size(from_json(newName.values,'array')) + - | size(array(from_json(newName.values,'map'))) as size - |FROM $testTable - |""".stripMargin) - }, - condition = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_GENERATOR", - sqlState = "0A000", - parameters = Map( - "lca" -> "`newName.values`", - "generatorExpr" -> ("\"(size(from_json(lateralAliasReference(newName.values), " + - "array)) + size(array(from_json(lateralAliasReference(newName.values), " + - "map)))) AS size\""))) + test("LateralColumnAlias with Generate") { + checkAnswer( + sql("WITH cte AS (SELECT EXPLODE(ARRAY(1, 2, 3)) AS c1, c1) SELECT * FROM cte"), + Row(1, 1) :: Row(2, 2) :: Row(3, 3) :: Nil + ) + checkAnswer( + sql( + s""" + |SELECT + | explode(split(name , ',')) AS new_name, + | new_name like 'a%' + |FROM $testTable + |""".stripMargin), + Row("alex", true) :: Row("amy", true) :: Row("cathy", false) :: + Row("david", false) :: Row("jen", false) :: Nil + ) } }