From f5f44a7df263bcda13095059eb46b63b95bd5811 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 26 Mar 2024 15:18:00 +0800 Subject: [PATCH] fix analyzer rule order issues about Alias --- .../expressions/complexTypeCreator.scala | 1 + .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../catalyst/parser/UnpivotParserSuite.scala | 39 +++++++++---------- 3 files changed, 20 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 332a49f78ab98..993684f2c1ed4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -374,6 +374,7 @@ object CreateStruct { // alias name inside CreateNamedStruct. case (u: UnresolvedAttribute, _) => Seq(Literal(u.nameParts.last), u) case (u @ UnresolvedExtractValue(_, e: Literal), _) if e.dataType == StringType => Seq(e, u) + case (a: Alias, _) => Seq(Literal(a.name), a) case (e: NamedExpression, _) if e.resolved => Seq(Literal(e.name), e) case (e: NamedExpression, _) => Seq(NamePlaceholder, e) case (g @ GetStructField(_, _, Some(name)), _) => Seq(Literal(name), g) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 131ea2222a3dc..170dcc37f0a56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1346,7 +1346,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { * Create an Unpivot column. */ override def visitUnpivotColumn(ctx: UnpivotColumnContext): NamedExpression = withOrigin(ctx) { - UnresolvedAlias(UnresolvedAttribute(visitMultipartIdentifier(ctx.multipartIdentifier))) + UnresolvedAttribute(visitMultipartIdentifier(ctx.multipartIdentifier)) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala index c680e08c1c832..3012ef6f1544d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/UnpivotParserSuite.scala @@ -39,7 +39,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t UNPIVOT (val FOR col in (a, b))", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -59,7 +59,7 @@ class UnpivotParserSuite extends AnalysisTest { sql, Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), Some(Seq(Some("A"), None)), "col", Seq("val"), @@ -76,7 +76,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t UNPIVOT ((val1, val2) FOR col in ((a, b), (c, d)))", Unpivot( None, - Some(Seq(Seq($"a", $"b").map(UnresolvedAlias(_)), Seq($"c", $"d").map(UnresolvedAlias(_)))), + Some(Seq(Seq($"a", $"b"), Seq($"c", $"d"))), None, "col", Seq("val1", "val2"), @@ -96,10 +96,7 @@ class UnpivotParserSuite extends AnalysisTest { sql, Unpivot( None, - Some(Seq( - Seq($"a", $"b").map(UnresolvedAlias(_)), - Seq($"c", $"d").map(UnresolvedAlias(_)) - )), + Some(Seq(Seq($"a", $"b"), Seq($"c", $"d"))), Some(Seq(Some("first"), None)), "col", Seq("val1", "val2"), @@ -132,7 +129,7 @@ class UnpivotParserSuite extends AnalysisTest { sql, Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -169,7 +166,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t UNPIVOT EXCLUDE NULLS (val FOR col in (a, b))", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -184,7 +181,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t UNPIVOT INCLUDE NULLS (val FOR col in (a, b))", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -199,7 +196,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)) JOIN t2", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -211,7 +208,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t1 JOIN t2 UNPIVOT (val FOR col in (a, b))", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -224,7 +221,7 @@ class UnpivotParserSuite extends AnalysisTest { table("t1").join( Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -239,7 +236,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)), t2", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -251,7 +248,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t1, t2 UNPIVOT (val FOR col in (a, b))", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -267,7 +264,7 @@ class UnpivotParserSuite extends AnalysisTest { table("t1").join( Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -282,7 +279,7 @@ class UnpivotParserSuite extends AnalysisTest { table("t1").join( Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -296,7 +293,7 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t1, t2 JOIN t3 UNPIVOT (val FOR col in (a, b))", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -311,7 +308,7 @@ class UnpivotParserSuite extends AnalysisTest { table("t1").join( Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), @@ -326,13 +323,13 @@ class UnpivotParserSuite extends AnalysisTest { "SELECT * FROM t1 UNPIVOT (val FOR col in (a, b)) UNPIVOT (val FOR col in (a, b))", Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"), Unpivot( None, - Some(Seq(Seq(UnresolvedAlias($"a")), Seq(UnresolvedAlias($"b")))), + Some(Seq(Seq($"a"), Seq($"b"))), None, "col", Seq("val"),